Online coupled MIMO prediction versus independent SISO
Tutorial goal
Show that full online MIMO lattice prediction captures cross-channel dynamics that independent SISO predictors miss.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context
The diagonal-MIMO tutorial is a reduction test: when the reflection matrices are diagonal, MIMO prediction equals independent SISO prediction. This tutorial tests the complementary case. A synthetic vector AR process has lag matrices with nonzero off-diagonal entries, so past samples of one channel help predict another channel. A full MIMO lattice predictor can use that cross-channel history; per-channel SISO predictors cannot.
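The reduction argument can be checked numerically in its simplest form. The text states it for reflection matrices; the sketch below checks the direct-form analogue with diagonal lag matrices \(A_k\) in plain NumPy, not the lattice_dsp lattice. The lag values, the `history` array, and both helper functions are hypothetical, and the sign convention y[n] = -sum_k A_k y[n-k] + e[n] is assumed from simulate_coupled_var() in the source listing.

```python
import numpy as np

# Hypothetical diagonal lag matrices for a 3-channel VAR(2), using the sign
# convention y[n] = -sum_k A_k y[n-k] + e[n] from simulate_coupled_var().
A = np.stack([np.diag([0.55, 0.45, 0.40]), np.diag([-0.18, -0.14, -0.10])])
rng = np.random.default_rng(0)
history = rng.normal(size=(2, 3))  # history[k] holds y[n-1-k]


def vector_prediction(coeffs: np.ndarray, hist: np.ndarray) -> np.ndarray:
    """Direct-form MIMO one-step prediction: -sum_k A_k @ y[n-1-k]."""
    return -sum(coeffs[k] @ hist[k] for k in range(coeffs.shape[0]))


def per_channel_prediction(coeffs: np.ndarray, hist: np.ndarray) -> np.ndarray:
    """Independent scalar predictions: channel i reads only its own past."""
    order, channels, _ = coeffs.shape
    return np.array(
        [-sum(coeffs[k, i, i] * hist[k, i] for k in range(order)) for i in range(channels)]
    )


# Diagonal lag matrices: MIMO and per-channel predictions coincide.
assert np.allclose(vector_prediction(A, history), per_channel_prediction(A, history))

# Add one cross-channel term (channel 1's past drives channel 0). The
# per-channel predictor cannot see it, so channel 0's predictions now differ
# by exactly that term.
A_coupled = A.copy()
A_coupled[0, 0, 1] = 0.30
gap = vector_prediction(A_coupled, history)[0] - per_channel_prediction(A_coupled, history)[0]
assert np.isclose(gap, -0.30 * history[0, 1])
```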
Key idea and equations
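The training signal follows a coupled vector AR model of order \(p\) (here \(p = 2\) with three channels), written in the same sign convention as simulate_coupled_var() in the source code:
\[
y[n] = -\sum_{k=1}^{p} A_k\, y[n-k] + e[n],
\]
where \(e[n]\) is white Gaussian noise with standard deviation 0.30 on each channel.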
Off-diagonal entries of \(A_k\) encode cross-channel dynamics. The full MIMO predictor estimates matrix reflection coefficients and predicts
\[
\widehat y[n] = g\bigl(y[n-1], \dots, y[n-p]\bigr),
\]
where \(g\) may mix channels through matrix coefficients. Independent SISO baselines instead fit one predictor per channel,
\[
\widehat y_i[n] = g_i\bigl(y_i[n-1], \dots, y_i[n-p]\bigr),
\]
so they cannot use \(y_j[n-k]\) for \(j \ne i\). The example also includes a diagonal ablation, which keeps only the diagonal entries of the full MIMO forward and backward reflection matrices, to show what happens when the learned off-diagonal entries are removed.
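The three-way comparison can be sanity-checked outside the lattice code. The sketch below swaps block Levinson-Durbin for ordinary least squares (`np.linalg.lstsq`) and uses the same VAR(2) coefficients, noise level, seed, and train/test split as main(); it is not the lattice_dsp estimator. The helper names (`simulate`, `lagged`, `fit_ls`, `fit_siso`, `residual_rms`) are hypothetical. SISO fits are packed as diagonal lag matrices, so one residual function serves all three models, and the diagonal ablation simply zeroes the off-diagonal entries of the full fit.

```python
import numpy as np

# Lag matrices copied from main(); sign convention y[n] = -sum_k A_k y[n-k] + e[n].
A_TRUE = np.array(
    [
        [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
        [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
    ]
)


def simulate(coeffs: np.ndarray, samples: int, seed: int = 71, burn_in: int = 512) -> np.ndarray:
    """Coupled VAR simulation mirroring simulate_coupled_var()."""
    rng = np.random.default_rng(seed)
    order, channels, _ = coeffs.shape
    x = np.zeros((samples + burn_in, channels))
    e = rng.normal(scale=0.30, size=x.shape)
    for n in range(order, x.shape[0]):
        x[n] = e[n] - sum(coeffs[k] @ x[n - 1 - k] for k in range(order))
    return x[burn_in:]


def lagged(x: np.ndarray, order: int) -> np.ndarray:
    """Row j holds [y[j+order-1], ..., y[j]], the past of target y[j+order]."""
    return np.hstack([x[order - k : len(x) - k] for k in range(1, order + 1)])


def fit_ls(x: np.ndarray, order: int) -> np.ndarray:
    """Least-squares VAR fit, returned as lag matrices A_k (same sign convention)."""
    w, *_ = np.linalg.lstsq(lagged(x, order), x[order:], rcond=None)
    ch = x.shape[1]
    # x[order:] ~= lagged @ w, so y[n] ~= sum_k w_k.T y[n-k] and A_k = -w_k.T.
    return np.stack([-w[k * ch : (k + 1) * ch].T for k in range(order)])


def fit_siso(x: np.ndarray, order: int) -> np.ndarray:
    """One scalar AR fit per channel, packed as diagonal lag matrices."""
    ch = x.shape[1]
    coeffs = np.zeros((order, ch, ch))
    for i in range(ch):
        coeffs[:, i, i] = fit_ls(x[:, [i]], order)[:, 0, 0]
    return coeffs


def residual_rms(coeffs: np.ndarray, x: np.ndarray) -> float:
    """RMS of causal one-step prediction errors (first `order` samples skipped)."""
    order = coeffs.shape[0]
    pred = -sum(x[order - 1 - k : len(x) - 1 - k] @ coeffs[k].T for k in range(order))
    return float(np.sqrt(np.mean((x[order:] - pred) ** 2)))


x = simulate(A_TRUE, 6000 + 2200)
train, test = x[:6000], x[6000:]
full = fit_ls(train, order=2)
rms = {
    "full": residual_rms(full, test),
    "diagonal ablation": residual_rms(full * np.eye(3), test),  # off-diagonals zeroed
    "independent SISO": residual_rms(fit_siso(train, order=2), test),
}
for name, value in rms.items():
    print(f"{name}: {value:.4f}")
```

The full least-squares fit should recover coefficients close to `A_TRUE` and a residual RMS close to the 0.30 noise level; the exact numbers will differ somewhat from the lattice example's stdout because the estimator and warm-up handling differ.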
Causality and data use
The coefficients are estimated from a finite training record. Prediction on the held-out test record is online: each model calls predict() before update(y_n), so \(\widehat y[n]\) depends only on earlier test vectors (full MIMO and diagonal ablation) or on earlier samples of the same channel (independent SISO).
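The predict-before-update contract can be illustrated with a minimal stand-in. `OnlineVARPredictor` below is hypothetical: a direct-form predictor with coefficients fixed at construction that mimics only the predict()/update() calling pattern used in the example (update() returning the prediction error, as process_independent_siso() assumes), not the internals of `ld.MIMOLatticePredictor`.

```python
from __future__ import annotations

import numpy as np


class OnlineVARPredictor:
    """Hypothetical direct-form stand-in with a predict()/update() interface.

    Coefficients are fixed at construction; the only state is the history of
    vectors already passed to update().
    """

    def __init__(self, coeffs: np.ndarray) -> None:
        self.coeffs = np.asarray(coeffs, dtype=np.float64)
        order, channels, _ = self.coeffs.shape
        self.history = np.zeros((order, channels))  # history[k] holds y[n-1-k]

    def predict(self) -> np.ndarray:
        # -sum_k A_k @ y[n-1-k]; reads only past vectors, so it is strictly causal.
        return -np.einsum("kij,kj->i", self.coeffs, self.history)

    def update(self, y: np.ndarray) -> np.ndarray:
        error = np.asarray(y, dtype=np.float64) - self.predict()
        self.history = np.roll(self.history, 1, axis=0)  # age the history by one step
        self.history[0] = y
        return error


def process(predictor: OnlineVARPredictor, stream: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    predictions = np.empty_like(stream, dtype=np.float64)
    errors = np.empty_like(stream, dtype=np.float64)
    for n, y_n in enumerate(stream):
        predictions[n] = predictor.predict()  # uses y[0], ..., y[n-1] only
        errors[n] = predictor.update(y_n)  # y[n] is revealed after the prediction
    return predictions, errors


A = np.array(
    [
        [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
        [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
    ]
)
stream = np.random.default_rng(3).normal(size=(20, 3))
predictions, errors = process(OnlineVARPredictor(A), stream)
```

A quick causality check: perturbing `stream[5]` and re-running leaves `predictions[0]` through `predictions[5]` unchanged and only alters later predictions.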
What this example verifies
This example demonstrates the reason to use a MIMO predictor rather than three independent SISO predictors. On a held-out coupled VAR stream, the full matrix predictor should reduce the residual RMS and, more importantly, the residual cross-channel correlation relative to the diagonal-ablation and independent SISO baselines.
How to read the result
Compare the residual RMS and residual-covariance plots. When the data really contain cross-channel dynamics, the full MIMO residual should be smaller and less cross-correlated than the independent SISO residual. The reduction in off-diagonal residual correlation is often a clearer MIMO diagnostic than scalar RMS alone: in the captured run below, the RMS improves by 5.57% while the mean absolute off-diagonal residual correlation falls by 68.42%.
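Why the off-diagonal correlation can be more telling than RMS is easy to see on synthetic residuals. The sketch below is a hypothetical illustration, not output from the example: a "leaky" residual contains one shared component that stands in for an unmodelled cross-channel term. `normalized_covariance` and `mean_abs_offdiag` are simplified copies of the helpers with the same names in the source listing.

```python
import numpy as np


def normalized_covariance(x: np.ndarray) -> np.ndarray:
    """Residual correlation matrix (simplified from the example's helper)."""
    centered = x - x.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / (x.shape[0] - 1)
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)


def mean_abs_offdiag(matrix: np.ndarray) -> float:
    mask = ~np.eye(matrix.shape[0], dtype=bool)
    return float(np.mean(np.abs(matrix[mask])))


def rms(x: np.ndarray) -> float:
    return float(np.sqrt(np.mean(x**2)))


rng = np.random.default_rng(1)
n = 20_000
white = rng.normal(scale=0.30, size=(n, 3))  # ideal residual: no shared structure
shared = rng.normal(scale=0.30, size=(n, 1))  # stand-in for an unmodelled common term
leaky = white + 0.30 * shared  # the same term leaks into every channel

for name, res in [("white", white), ("leaky", leaky)]:
    corr = mean_abs_offdiag(normalized_covariance(res))
    print(f"{name}: rms={rms(res):.4f}, mean |off-diag corr|={corr:.4f}")
```

Analytically, the leaked term adds variance \(0.30^2 \cdot 0.30^2 = 0.0081\) to each channel's 0.09, so the RMS grows only by a factor of \(\sqrt{0.0981/0.09} \approx 1.044\), while the off-diagonal correlation rises from about 0 to \(0.0081/0.0981 \approx 0.083\).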
Run command
python examples/online_coupled_mimo_vs_siso.py
Run status
Return code: 0
Captured stdout
channels: 3
order: 2
training samples: 6000
test samples: 2200
true companion spectral radius: 0.749380
fitted companion spectral radius: 0.748557
full MIMO reflection norms: [0.710496 0.24016 ]
full MIMO residual RMS: 0.301616
diagonal-ablation residual RMS: 0.320645
independent SISO residual RMS: 0.319416
relative RMS improvement vs independent SISO: 5.57%
mean abs off-diagonal residual correlation, full MIMO: 0.016742
mean abs off-diagonal residual correlation, independent SISO: 0.053008
off-diagonal residual correlation reduction vs independent SISO: 68.42%
causal contract: prediction is requested before update(y_n) for every test vector
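Two of the printed diagnostics can be re-checked by hand. The percentages follow directly from the printed RMS and correlation values. The companion spectral radius is reproduced under the assumption that lattice_dsp builds the standard block companion matrix for the simulator's sign convention; `companion_spectral_radius` below is a hypothetical re-implementation, not `ld.companion_spectral_radius`.

```python
import numpy as np

# Values copied from the captured stdout (rounded to 6 decimals there).
full_rms, siso_rms = 0.301616, 0.319416
full_offdiag, siso_offdiag = 0.016742, 0.053008

rms_gain = (siso_rms - full_rms) / siso_rms
offdiag_gain = (siso_offdiag - full_offdiag) / siso_offdiag
print(f"{100 * rms_gain:.2f}% {100 * offdiag_gain:.2f}%")  # 5.57% 68.42%


def companion_spectral_radius(coeffs: np.ndarray) -> float:
    """Hypothetical re-implementation: spectral radius of the block companion
    matrix for y[n] = -A_1 y[n-1] - ... - A_p y[n-p] + e[n]."""
    order, ch, _ = coeffs.shape
    top = np.hstack([-coeffs[k] for k in range(order)])
    # Shift block: moves y[n-1], ..., y[n-p+1] down one slot in the state vector.
    shift = np.hstack([np.eye(ch * (order - 1)), np.zeros((ch * (order - 1), ch))])
    companion = np.vstack([top, shift])
    return float(np.max(np.abs(np.linalg.eigvals(companion))))


A_TRUE = np.array(
    [
        [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
        [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
    ]
)
print(f"{companion_spectral_radius(A_TRUE):.4f}")  # close to the printed 0.749380
```

A radius below 1 means every root of \(\det(\lambda^2 I + \lambda A_1 + A_2)\) lies inside the unit circle, so the simulated VAR(2) is stable.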
Figures
online_coupled_mimo_coefficient_matrices.png
online_coupled_mimo_prediction_trace.png
online_coupled_mimo_residual_covariance.png
online_coupled_mimo_rms_comparison.png
Generated data files¶
Source code¶
"""Online coupled MIMO prediction versus independent SISO baselines.

The diagonal-MIMO tutorial checks that matrix lattice prediction reduces to
independent one-channel prediction when all reflection matrices are diagonal.
This tutorial checks the complementary case: when the training signal has true
cross-channel dynamics, a full online MIMO lattice predictor can use those
off-diagonal terms while independent SISO predictors cannot.
"""

from __future__ import annotations

import csv
import os
from pathlib import Path

import numpy as np

import lattice_dsp as ld


def artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def simulate_coupled_var(coefficients: np.ndarray, samples: int, seed: int = 71) -> np.ndarray:
    """Generate a stable coupled vector AR process."""

    rng = np.random.default_rng(seed)
    order, channels, _ = coefficients.shape
    burn_in = 512
    x = np.zeros((samples + burn_in, channels), dtype=np.float64)
    noise = rng.normal(scale=0.30, size=x.shape)
    for n in range(order, x.shape[0]):
        value = noise[n].copy()
        for lag in range(1, order + 1):
            value -= coefficients[lag - 1] @ x[n - lag]
        x[n] = value
    return x[burn_in:]


def normalized_covariance(x: np.ndarray) -> np.ndarray:
    centered = np.asarray(x, dtype=np.float64) - np.mean(x, axis=0, keepdims=True)
    cov = centered.T @ centered / max(centered.shape[0] - 1, 1)
    scale = np.sqrt(np.outer(np.diag(cov), np.diag(cov))) + 1e-30
    return cov / scale


def mean_abs_offdiag(matrix: np.ndarray) -> float:
    mask = ~np.eye(matrix.shape[0], dtype=bool)
    return float(np.mean(np.abs(matrix[mask])))


def fit_full_mimo_predictor(
    train: np.ndarray, order: int
) -> tuple[ld.MultichannelARResult, ld.MIMOLatticePredictor]:
    r = ld.multichannel_autocorrelation(train, order=order)
    result = ld.block_levinson_durbin(r, order=order)
    return result, ld.MIMOLatticePredictor.from_levinson(result)


def fit_independent_siso_predictors(train: np.ndarray, order: int) -> list[ld.MIMOLatticePredictor]:
    predictors: list[ld.MIMOLatticePredictor] = []
    for ch in range(train.shape[1]):
        r = ld.multichannel_autocorrelation(train[:, [ch]], order=order)
        result = ld.block_levinson_durbin(r, order=order)
        predictors.append(ld.MIMOLatticePredictor.from_levinson(result))
    return predictors


def process_independent_siso(
    predictors: list[ld.MIMOLatticePredictor], x: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    prediction = np.empty_like(x, dtype=np.float64)
    error = np.empty_like(x, dtype=np.float64)
    for n, sample in enumerate(x):
        for ch, predictor in enumerate(predictors):
            prediction[n, ch] = predictor.predict()[0].real
            error[n, ch] = predictor.update(np.array([sample[ch]]))[0].real
    return prediction, error


def diagonal_ablation_predictor(result: ld.MultichannelARResult) -> ld.MIMOLatticePredictor:
    """Keep only per-channel reflection entries from a full MIMO fit."""

    kf = np.asarray([np.diag(np.diag(stage)) for stage in result.reflection], dtype=np.complex128)
    if result.backward_reflection is None:
        raise ValueError("block-Levinson result does not contain backward reflections")
    kb = np.asarray(
        [np.diag(np.diag(stage)) for stage in result.backward_reflection], dtype=np.complex128
    )
    return ld.MIMOLatticePredictor(kf, kb)


def residual_rms(error: np.ndarray, warmup: int) -> float:
    return float(np.sqrt(np.mean(np.asarray(error[warmup:]).real ** 2)))


def residual_rms_by_channel(error: np.ndarray, warmup: int) -> np.ndarray:
    return np.sqrt(np.mean(np.asarray(error[warmup:]).real ** 2, axis=0))


def save_summary_csv(
    out_dir: Path,
    *,
    warmup: int,
    full_error: np.ndarray,
    diagonal_error: np.ndarray,
    siso_error: np.ndarray,
) -> Path:
    rows: list[dict[str, object]] = []
    errors = {
        "full_mimo": full_error,
        "diagonal_ablation": diagonal_error,
        "independent_siso": siso_error,
    }
    for name, err in errors.items():
        by_channel = residual_rms_by_channel(err, warmup)
        cov = normalized_covariance(err[warmup:].real)
        for ch, rms in enumerate(by_channel):
            rows.append(
                {
                    "model": name,
                    "channel": ch,
                    "residual_rms": float(rms),
                    "mean_abs_offdiag_residual_correlation": mean_abs_offdiag(cov),
                }
            )
    path = out_dir / "online_coupled_mimo_vs_siso_summary.csv"
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)
    return path


def save_figures(
    out_dir: Path,
    *,
    true_coefficients: np.ndarray,
    test: np.ndarray,
    full_prediction: np.ndarray,
    siso_prediction: np.ndarray,
    full_error: np.ndarray,
    diagonal_error: np.ndarray,
    siso_error: np.ndarray,
    warmup: int,
) -> None:
    try:
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - optional plotting dependency
        print("matplotlib is not installed; skipped figures")
        return

    n_view = min(260, test.shape[0])
    t = np.arange(n_view)
    fig, ax = plt.subplots(figsize=(9.0, 4.3))
    ax.plot(t, test[:n_view, 0], label="observed channel 0", linewidth=1.6)
    ax.plot(t, full_prediction[:n_view, 0].real, label="full MIMO prediction", linewidth=1.3)
    ax.plot(
        t,
        siso_prediction[:n_view, 0].real,
        "--",
        label="independent SISO prediction",
        linewidth=1.2,
    )
    ax.set_xlabel("test sample")
    ax.set_ylabel("amplitude")
    ax.set_title("Online prediction: full MIMO can use cross-channel history")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.25)
    fig.tight_layout()
    path = out_dir / "online_coupled_mimo_prediction_trace.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    labels = ["full MIMO", "diagonal\nablation", "independent\nSISO"]
    values = [
        residual_rms(full_error, warmup),
        residual_rms(diagonal_error, warmup),
        residual_rms(siso_error, warmup),
    ]
    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    ax.bar(np.arange(len(values)), values)
    ax.set_xticks(np.arange(len(values)), labels)
    ax.set_ylabel("residual RMS")
    ax.set_title("Coupled online MIMO prediction reduces residual energy")
    for idx, value in enumerate(values):
        ax.text(idx, value, f"{value:.3f}", ha="center", va="bottom")
    fig.tight_layout()
    path = out_dir / "online_coupled_mimo_rms_comparison.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    matrices = [
        ("full MIMO residual", normalized_covariance(full_error[warmup:].real)),
        ("independent SISO residual", normalized_covariance(siso_error[warmup:].real)),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(8.8, 3.8))
    for ax, (title, matrix) in zip(axes, matrices, strict=True):
        im = ax.imshow(matrix, vmin=-1.0, vmax=1.0)
        ax.set_title(title)
        ax.set_xlabel("channel")
        ax.set_ylabel("channel")
    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.84)
    path = out_dir / "online_coupled_mimo_residual_covariance.png"
    fig.savefig(path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {path}")

    order = true_coefficients.shape[0]
    fig, axes = plt.subplots(1, order, figsize=(4.2 * order, 3.8))
    if order == 1:
        axes = [axes]
    max_abs = float(np.max(np.abs(true_coefficients)))
    for lag, ax in enumerate(axes):
        im = ax.imshow(true_coefficients[lag], vmin=-max_abs, vmax=max_abs)
        ax.set_title(f"true A[{lag + 1}]")
        ax.set_xlabel("source channel")
        ax.set_ylabel("target channel")
    fig.colorbar(im, ax=np.ravel(axes).tolist(), shrink=0.84)
    path = out_dir / "online_coupled_mimo_coefficient_matrices.png"
    fig.savefig(path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {path}")


def main() -> None:
    out_dir = artifact_dir()

    # Stable coupled VAR(2). The off-diagonal entries are the part independent
    # SISO predictors cannot represent.
    true_coefficients = np.asarray(
        [
            [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
            [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
        ],
        dtype=np.float64,
    )
    order, channels, _ = true_coefficients.shape
    train_samples = 6000
    test_samples = 2200
    x = simulate_coupled_var(true_coefficients, train_samples + test_samples)
    train = x[:train_samples]
    test = x[train_samples:]
    warmup = order

    full_result, full_predictor = fit_full_mimo_predictor(train, order)
    full_prediction, full_error = full_predictor.process(test)

    diagonal_predictor = diagonal_ablation_predictor(full_result)
    diagonal_prediction, diagonal_error = diagonal_predictor.process(test)

    siso_predictors = fit_independent_siso_predictors(train, order)
    siso_prediction, siso_error = process_independent_siso(siso_predictors, test)

    full_rms = residual_rms(full_error, warmup)
    diagonal_rms = residual_rms(diagonal_error, warmup)
    siso_rms = residual_rms(siso_error, warmup)
    relative_improvement = (siso_rms - full_rms) / max(siso_rms, 1e-30)

    full_cov = normalized_covariance(full_error[warmup:].real)
    siso_cov = normalized_covariance(siso_error[warmup:].real)
    full_offdiag = mean_abs_offdiag(full_cov)
    siso_offdiag = mean_abs_offdiag(siso_cov)
    offdiag_reduction = (siso_offdiag - full_offdiag) / max(siso_offdiag, 1e-30)

    csv_path = save_summary_csv(
        out_dir,
        warmup=warmup,
        full_error=full_error,
        diagonal_error=diagonal_error,
        siso_error=siso_error,
    )

    print("channels:", channels)
    print("order:", order)
    print("training samples:", train_samples)
    print("test samples:", test_samples)
    print(
        "true companion spectral radius:", f"{ld.companion_spectral_radius(true_coefficients):.6f}"
    )
    print(
        "fitted companion spectral radius:",
        f"{ld.companion_spectral_radius(full_result.coefficients):.6f}",
    )
    print("full MIMO reflection norms:", np.round(full_result.reflection_spectral_norms, 6))
    print("full MIMO residual RMS:", f"{full_rms:.6f}")
    print("diagonal-ablation residual RMS:", f"{diagonal_rms:.6f}")
    print("independent SISO residual RMS:", f"{siso_rms:.6f}")
    print("relative RMS improvement vs independent SISO:", f"{100.0 * relative_improvement:.2f}%")
    print("mean abs off-diagonal residual correlation, full MIMO:", f"{full_offdiag:.6f}")
    print("mean abs off-diagonal residual correlation, independent SISO:", f"{siso_offdiag:.6f}")
    print(
        "off-diagonal residual correlation reduction vs independent SISO:",
        f"{100.0 * offdiag_reduction:.2f}%",
    )
    print("causal contract: prediction is requested before update(y_n) for every test vector")
    print(f"wrote {csv_path}")

    save_figures(
        out_dir,
        true_coefficients=true_coefficients,
        test=test,
        full_prediction=full_prediction,
        siso_prediction=siso_prediction,
        full_error=full_error,
        diagonal_error=diagonal_error,
        siso_error=siso_error,
        warmup=warmup,
    )


if __name__ == "__main__":
    main()