Causal online MIMO lattice prediction
Tutorial goal
Use the matrix reflection coefficients from a block-Levinson fit to drive a sample-by-sample, causal MIMO lattice predictor.
Note
New to the terminology? The lattice DSP concept map and the causality/data-use guide explain how to read the online, offline, block, and MIMO examples.
Context
This tutorial separates coefficient estimation from runtime filtering. The matrix reflection coefficients are estimated once from a finite training record (in this example, the same record that is later filtered); a stateful lattice predictor then runs online. At each time step it predicts the next vector from the stored backward-error states before that vector is observed.
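Key idea and equations
With forward and backward matrix reflection coefficients \(K_m\) and \(L_m\), the online vector lattice recursion of order \(p\) is
\[
\begin{aligned}
f_0[n] &= b_0[n] = y[n],\\
f_m[n] &= f_{m-1}[n] + K_m\, b_{m-1}[n-1], \qquad m = 1, \dots, p,\\
b_m[n] &= b_{m-1}[n-1] + L_m\, f_{m-1}[n], \qquad m = 1, \dots, p.
\end{aligned}
\]
The one-step prediction is obtained before seeing \(y[n]\) by evaluating the same recursion with \(y[n]=0\) and negating the result. Because the stored states \(b_{m-1}[n-1]\) do not depend on \(y[n]\), this reduces to
\[
\hat y[n] = -\,f_p[n]\big|_{y[n]=0} = -\sum_{m=1}^{p} K_m\, b_{m-1}[n-1].
\]
After the true vector is observed, \(f_p[n]=y[n]-\hat y[n]\) is the forward prediction error, and the backward-error states are updated. This is causal in the prediction sense: \(\hat y[n]\) depends only on stored states from samples \(< n\).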
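Causality and data use
The block-Levinson fit is a batch estimation step. The MIMOLatticePredictor object created from those reflection matrices is a runtime object: predict() uses only previous vectors, and update(y_n) consumes the current vector afterward. The example drives the predictor with process(), which applies this predict-then-update cycle to each vector of the record in turn.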
What this example verifies
This checks the online contract of MIMOLatticePredictor. Each prediction is formed before the current vector is observed. The current vector is then consumed to advance the backward-error state (update(y_n) in the step-by-step API). The resulting online residual is compared with the direct vector-AR (VAR) residual computed from the same block-Levinson fit.
How to read the result
Check that the online lattice residual matches the direct AR residual from the same block-Levinson fit; the printed relative difference should be close to zero. Then inspect the prediction-trace, reflection-norm, and residual-covariance figures.
Run command
python examples/causal_mimo_lattice_prediction.py
Source code
1"""Online causal MIMO lattice prediction from block Levinson reflections."""
2
3from __future__ import annotations
4
5import os
6from pathlib import Path
7
8import numpy as np
9
10import lattice_dsp as ld
11
12
def _artifact_dir() -> Path:
    """Return the directory that receives generated figures, creating it on demand.

    The location is read from the ``LATTICE_DSP_ARTIFACT_DIR`` environment
    variable. When that variable is unset, it falls back to
    ``reports/example-artifacts``, which is resolved relative to the current
    working directory. Missing parent directories are created, and an existing
    directory is reused.
    """
    configured = os.getenv("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts")
    target = Path(configured)
    target.mkdir(parents=True, exist_ok=True)
    return target
17
18
def simulate_var(
    coefficients: np.ndarray,
    samples: int = 12000,
    seed: int = 31,
    *,
    noise_scale: float = 0.35,
    burn_in: int = 512,
) -> np.ndarray:
    """Simulate a coupled vector AR process driven by white Gaussian noise.

    Each sample satisfies ``x[n] + sum_k A_k @ x[n - k] = e[n]`` with
    ``A_k = coefficients[k - 1]``. Here ``e[n]`` is independent normal noise
    with standard deviation ``noise_scale`` on every channel. The recursion
    starts from a zero state, and the first ``burn_in`` samples are simulated
    and then discarded to suppress the start-up transient.

    Stability is NOT checked here. The caller must supply coefficients whose AR
    polynomial is stable; otherwise the output grows without bound.

    Args:
        coefficients: Array of shape ``(order, channels, channels)``.
        samples: Number of samples to return.
        seed: Seed passed to ``numpy.random.default_rng``.
        noise_scale: Standard deviation of the driving noise.
        burn_in: Number of leading samples to simulate and drop.

    Returns:
        Float64 array of shape ``(samples, channels)``.

    Raises:
        ValueError: If ``coefficients`` is not a stack of square matrices, or
            if ``samples`` or ``burn_in`` is negative.
    """
    coefficients = np.asarray(coefficients)
    if coefficients.ndim != 3 or coefficients.shape[1] != coefficients.shape[2]:
        raise ValueError(
            "coefficients must have shape (order, channels, channels), "
            f"got {coefficients.shape}"
        )
    if samples < 0 or burn_in < 0:
        raise ValueError("samples and burn_in must be non-negative")

    rng = np.random.default_rng(seed)
    order, channels, _ = coefficients.shape
    x = np.zeros((samples + burn_in, channels), dtype=np.float64)
    noise = rng.normal(scale=noise_scale, size=x.shape)
    # The first ``order`` rows stay zero (no full history yet); they fall
    # inside the burn-in when burn_in >= order.
    for n in range(order, x.shape[0]):
        value = noise[n].copy()
        for lag in range(1, order + 1):
            value -= coefficients[lag - 1] @ x[n - lag]
        x[n] = value
    return x[burn_in:]
32
33
def _normalized_covariance(x: np.ndarray) -> np.ndarray:
    """Return the channel-by-channel covariance of ``x`` scaled to unit diagonal.

    Rows of ``x`` are samples and columns are channels. The sample covariance
    uses an ``n - 1`` denominator, with a floor of 1 for single-row input. Each
    entry is divided by the square root of the product of the two channel
    variances. A tiny ``1e-30`` offset keeps zero-variance channels at 0
    instead of producing NaN. The real part of the result is returned.
    """
    deviations = x - np.mean(x, axis=0, keepdims=True)
    dof = max(x.shape[0] - 1, 1)
    covariance = (deviations.T @ deviations.conj()) / dof
    variances = np.real(np.diag(covariance))
    denominator = np.sqrt(np.outer(variances, variances)) + 1e-30
    return np.real(covariance / denominator)
39
40
def _save_figures(
    *,
    x: np.ndarray,
    prediction: np.ndarray,
    error: np.ndarray,
    forward_norms: np.ndarray,
    backward_norms: np.ndarray,
) -> None:
    """Write the prediction-trace, reflection-norm, and residual-covariance figures.

    Plotting is optional. If matplotlib cannot be imported, a message is
    printed and nothing is written. Otherwise three PNG files are saved in the
    directory returned by ``_artifact_dir()``, and each written path is printed.

    Args:
        x: Observed samples, one row per time step and one column per channel.
        prediction: One-step predictions, row-aligned with ``x``; the first
            samples of channel 0 are plotted.
        error: Prediction-error rows used for the residual covariance panel.
        forward_norms: Spectral norm of each forward reflection matrix, by stage.
        backward_norms: Spectral norm of each backward reflection matrix, by stage.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - optional plotting dependency
        print("matplotlib is not installed; skipped figures")
        return

    out_dir = _artifact_dir()
    n_view = min(240, x.shape[0])
    sample_index = np.arange(n_view)

    # Each figure is closed in a ``finally`` block so that a failure while
    # drawing or saving (e.g. an unwritable directory) does not leave open
    # pyplot figures behind.

    # Figure 1: observed channel 0 against its one-step prediction.
    fig, ax = plt.subplots(figsize=(8.4, 4.0))
    try:
        # Take the real part of both curves so complex-valued data is handled
        # consistently; this is a no-op for real input.
        ax.plot(sample_index, np.real(x[:n_view, 0]), label="observed channel 0")
        ax.plot(sample_index, np.real(prediction[:n_view, 0]), label="one-step prediction")
        ax.set_xlabel("sample")
        ax.set_ylabel("amplitude")
        ax.set_title("Causal MIMO lattice prediction uses only previous vectors")
        ax.legend(loc="best")
        fig.tight_layout()
        path = out_dir / "causal_mimo_lattice_prediction_trace.png"
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
    print(f"wrote {path}")

    # Figure 2: per-stage spectral norms of the forward and backward
    # reflections, with a dashed reference line at 1.0.
    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    try:
        stages = np.arange(1, len(forward_norms) + 1)
        ax.plot(stages, forward_norms, marker="o", label="forward K")
        ax.plot(stages, backward_norms, marker="s", label="backward L")
        ax.axhline(1.0, linestyle="--", linewidth=1.0)
        ax.set_xlabel("lattice stage")
        ax.set_ylabel("spectral norm")
        ax.set_title("Matrix reflection norms for online MIMO prediction")
        ax.set_xticks(stages)
        ax.legend(loc="best")
        fig.tight_layout()
        path = out_dir / "causal_mimo_lattice_reflection_norms.png"
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
    print(f"wrote {path}")

    # Figure 3: normalized covariance of the input next to that of the error.
    input_cov = _normalized_covariance(x)
    error_cov = _normalized_covariance(error)
    fig, axes = plt.subplots(1, 2, figsize=(8.5, 3.8))
    try:
        for ax, title, matrix in (
            (axes[0], "input normalized covariance", input_cov),
            (axes[1], "prediction-error normalized covariance", error_cov),
        ):
            im = ax.imshow(matrix, vmin=-1.0, vmax=1.0)
            ax.set_title(title)
            ax.set_xlabel("channel")
            ax.set_ylabel("channel")
            # Channels are discrete indices, so show integer ticks only.
            ticks = np.arange(matrix.shape[0])
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
        # Both panels use the same vmin/vmax, so a single shared colorbar fits both.
        fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.82)
        path = out_dir / "causal_mimo_lattice_residual_covariance.png"
        fig.savefig(path, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"wrote {path}")
103
104
def _mean_abs_offdiagonal(matrix: np.ndarray) -> float:
    """Return the mean absolute value of the off-diagonal entries of a square matrix.

    ``np.cov`` returns a 0-d array for a single channel, so the input is
    promoted to 2-D first. A 1x1 matrix has no off-diagonal entries and yields
    0.0 instead of dividing by zero.
    """
    magnitude = np.abs(np.atleast_2d(matrix))
    size = magnitude.shape[0]
    if size < 2:
        return 0.0
    return float((np.sum(magnitude) - np.trace(magnitude)) / (size * (size - 1)))


def main() -> None:
    """Fit block-Levinson reflections, run the online lattice predictor, and report.

    The batch step estimates matrix reflection coefficients from a simulated
    record. The runtime step feeds the same record through a
    ``MIMOLatticePredictor``. The printed diagnostics:

    - compare the online lattice residual with the direct AR residual from the
      same fit;
    - check that predictions do not change when future samples are altered;
    - summarize cross-channel coupling before and after prediction.

    Figures are then written by ``_save_figures``.
    """
    true_coefficients = np.asarray(
        [
            [[0.34, 0.08, -0.03], [-0.05, 0.30, 0.06], [0.02, -0.06, 0.27]],
            [[-0.12, 0.03, 0.01], [0.02, -0.10, -0.02], [0.00, 0.04, -0.08]],
        ],
        dtype=np.float64,
    )
    x = simulate_var(true_coefficients)
    order = true_coefficients.shape[0]

    # Offline/batch estimation step: obtain matrix reflection coefficients from a finite record.
    r = ld.multichannel_autocorrelation(x, order=order)
    levinson = ld.block_levinson_durbin(r, order=order)

    # Online/runtime step: predict each vector before updating the state with that vector.
    predictor = ld.MIMOLatticePredictor.from_levinson(levinson)
    prediction, error = predictor.process(x)
    # ``direct_error`` is compared with ``error[order:]``, i.e. it is expected
    # to omit the first ``order`` warm-up samples of the lattice residual.
    direct_error = ld.multichannel_prediction_error(x, levinson.coefficients)
    online_direct_difference = np.linalg.norm(error[order:] - direct_error) / max(
        np.linalg.norm(direct_error), 1e-30
    )

    # Causality check: overwrite every sample from ``split`` onward and run a
    # fresh predictor, because the first one is stateful and has already
    # consumed ``x``. A causal predictor forms the prediction for index n from
    # samples < n only, so predictions up to and including ``split`` must not
    # change.
    split = x.shape[0] // 2
    x_altered = x.copy()
    x_altered[split:] = 0.0
    altered_prediction, _ = ld.MIMOLatticePredictor.from_levinson(levinson).process(x_altered)
    future_leakage = float(
        np.max(np.abs(prediction[: split + 1] - altered_prediction[: split + 1]))
    )

    input_cov = np.cov(x.T)
    error_cov = np.cov(error[order:].real.T)
    input_offdiag = _mean_abs_offdiagonal(input_cov)
    error_offdiag = _mean_abs_offdiagonal(error_cov)

    print("channels:", x.shape[1])
    print("order:", order)
    print(
        "companion spectral radius:", f"{ld.companion_spectral_radius(levinson.coefficients):.6f}"
    )
    print("forward reflection norms:", np.round(levinson.reflection_spectral_norms, 6))
    print("backward reflection norms:", np.round(levinson.backward_reflection_spectral_norms, 6))
    print("online lattice/direct AR residual difference:", f"{online_direct_difference:.3e}")
    print("max prediction change when future samples are altered:", f"{future_leakage:.3e}")
    print(
        "input/error mean absolute off-diagonal covariance:",
        f"{input_offdiag:.4f}",
        f"{error_offdiag:.4f}",
    )
    print(
        "takeaway: after batch coefficient estimation, the MIMO lattice predictor is causal and online"
    )

    _save_figures(
        x=x,
        prediction=prediction,
        error=error[order:],
        forward_norms=levinson.reflection_spectral_norms,
        backward_norms=levinson.backward_reflection_spectral_norms,
    )
161
162
if __name__ == "__main__":
    # Run the tutorial only when this file is executed as a script, not on import.
    main()