Online coupled MIMO prediction versus independent SISO

Tutorial goal

Show that full online MIMO lattice prediction captures cross-channel dynamics that independent SISO predictors miss.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

The diagonal-MIMO tutorial is a reduction test: when all reflection matrices are diagonal, MIMO prediction equals independent SISO prediction. This tutorial tests the complementary case. A synthetic vector autoregressive (VAR) process has lag matrices with nonzero off-diagonal entries, so past samples of one channel help predict another channel. A full MIMO lattice predictor can use that cross-channel history; per-channel SISO predictors cannot.
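The reduction described above can be checked directly on the vector AR recursion. This is a minimal NumPy sketch, not the lattice_dsp implementation. It works with direct-form lag matrices \(A_k\) rather than reflection matrices, and var_filter is a hypothetical helper. With diagonal \(A_k\), the three-channel recursion matches three scalar recursions exactly. A single off-diagonal entry breaks the match for the channel it couples into.

```python
import numpy as np


def var_filter(coefficients: np.ndarray, noise: np.ndarray) -> np.ndarray:
    """Hypothetical helper: run y[n] = e[n] - sum_k A_k y[n-k] from zero initial conditions."""
    order = coefficients.shape[0]
    y = np.zeros_like(noise)
    for n in range(noise.shape[0]):
        value = noise[n].copy()
        # Only lags that exist yet (n - lag >= 0) contribute.
        for lag in range(1, min(order, n) + 1):
            value -= coefficients[lag - 1] @ y[n - lag]
        y[n] = value
    return y


rng = np.random.default_rng(0)
noise = rng.normal(scale=0.3, size=(500, 3))

# Illustrative diagonal lag matrices: each channel depends only on its own past.
diagonal = np.stack([np.diag([0.5, -0.3, 0.2]), np.diag([-0.1, 0.05, 0.1])])
y_mimo = var_filter(diagonal, noise)

# The same recursion run one channel at a time with 1x1 "matrices".
y_siso = np.column_stack(
    [
        var_filter(diagonal[:, ch : ch + 1, ch : ch + 1], noise[:, ch : ch + 1])[:, 0]
        for ch in range(3)
    ]
)
print(np.allclose(y_mimo, y_siso))  # True

# One cross-channel term: channel 1's previous sample now drives channel 0.
coupled = diagonal.copy()
coupled[0, 0, 1] = 0.3
y_coupled = var_filter(coupled, noise)
print(np.allclose(y_coupled[:, 0], y_siso[:, 0]))  # False
```

Only the row of \(A_1\) that was changed is affected. Channels 1 and 2 in y_coupled still match their scalar recursions.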

Key idea and equations

The training signal follows a coupled vector AR model

\[y[n] + \sum_{k=1}^{p} A_k y[n-k] = e[n], \qquad A_k \in \mathbb{R}^{c\times c}.\]
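Stability of this model is usually checked through the block companion matrix. Rewriting the model as \(y[n] = e[n] - \sum_k A_k y[n-k]\), the stacked state \([y[n]; \ldots; y[n-p+1]]\) evolves through a matrix whose first block row is \([-A_1, \ldots, -A_p]\). The VAR is stable when every eigenvalue of that matrix lies inside the unit circle.

The sketch below builds the companion matrix in NumPy using the sign convention of the equation above. companion_matrix and spectral_radius are illustrative names, not lattice_dsp API. It is an assumption that lattice_dsp's companion_spectral_radius uses the same convention. If it does, the first printed value should agree with the true companion spectral radius in the captured output, but this sketch does not check that.

```python
import numpy as np


def companion_matrix(coefficients: np.ndarray) -> np.ndarray:
    """Block companion matrix of y[n] = e[n] - sum_k A_k y[n-k] (illustrative helper)."""
    order, channels, _ = coefficients.shape
    size = order * channels
    companion = np.zeros((size, size))
    # First block row: [-A_1, -A_2, ..., -A_p]; the minus sign follows the model equation.
    companion[:channels, :] = -np.concatenate(list(coefficients), axis=1)
    # Identity blocks below the first row shift the stacked state back by one lag.
    companion[channels:, : size - channels] = np.eye(size - channels)
    return companion


def spectral_radius(coefficients: np.ndarray) -> float:
    """Largest eigenvalue magnitude; the VAR is stable when this is below 1."""
    return float(np.max(np.abs(np.linalg.eigvals(companion_matrix(coefficients)))))


# The lag matrices used by the tutorial script (order 2, three channels).
A = np.array(
    [
        [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
        [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
    ]
)
print(f"companion spectral radius: {spectral_radius(A):.6f}")

# Scalar sanity check: y[n] = e[n] - 0.55 y[n-1] + 0.18 y[n-2] has poles solving
# z^2 + 0.55 z - 0.18 = 0, so the largest magnitude is (0.55 + sqrt(1.0225)) / 2.
scalar = np.array([[[0.55]], [[-0.18]]])
print(np.isclose(spectral_radius(scalar), (0.55 + np.sqrt(1.0225)) / 2))  # True
```

The scalar case confirms the sign convention. With \(+A_k\) in the first block row instead, the characteristic polynomial would become \(z^2 - 0.55z + 0.18\) and the radius would change.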

Off-diagonal entries of \(A_k\) encode cross-channel dynamics. The full MIMO predictor estimates matrix reflection coefficients and predicts

\[\widehat y[n] = g\bigl(y[n-1], y[n-2], \ldots\bigr),\]

where \(g\) may mix channels through matrix coefficients. Independent SISO baselines instead fit one predictor per channel,

\[\widehat y_i[n] = g_i\bigl(y_i[n-1], y_i[n-2], \ldots\bigr),\]

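so they cannot use \(y_j[n-k]\) for \(j\ne i\). The example also includes a diagonal ablation that zeroes the off-diagonal entries of the learned full-MIMO reflection matrices, to show what is lost when only the per-channel entries are kept. Because the ablation keeps the diagonal entries of the joint fit rather than refitting each channel on its own, it is not identical to the independent SISO baseline.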
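The three predictor families above can be compared without the lattice machinery. This is a minimal sketch that substitutes ordinary least squares in direct form for the block Levinson-Durbin lattice fit used by the script. The helpers simulate_var, lagged, fit, residuals, and rms are illustrative. The "diagonal ablation" here zeroes the cross-channel direct-form weights rather than the off-diagonal reflection entries, so it is an analogue of the script's ablation, not the same operation. The noise realization also differs from the script's, so the printed numbers will not match the captured output exactly.

```python
import numpy as np


def simulate_var(coefficients: np.ndarray, samples: int, seed: int = 71) -> np.ndarray:
    """Simulate y[n] = e[n] - sum_k A_k y[n-k] (noise std 0.3, 500-sample burn-in)."""
    rng = np.random.default_rng(seed)
    order, channels, _ = coefficients.shape
    burn_in = 500
    y = np.zeros((samples + burn_in, channels))
    e = rng.normal(scale=0.3, size=y.shape)
    for n in range(order, y.shape[0]):
        y[n] = e[n] - sum(coefficients[k] @ y[n - k - 1] for k in range(order))
    return y[burn_in:]


def lagged(y: np.ndarray, order: int) -> tuple[np.ndarray, np.ndarray]:
    """Regressor rows [y[n-1], ..., y[n-p]] and targets y[n] for every n >= p."""
    n = y.shape[0]
    rows = np.hstack([y[order - k - 1 : n - k - 1] for k in range(order)])
    return rows, y[order:]


def fit(y: np.ndarray, order: int) -> np.ndarray:
    """Least-squares one-step predictor weights, shape (order * channels, channels)."""
    rows, targets = lagged(y, order)
    return np.linalg.lstsq(rows, targets, rcond=None)[0]


def residuals(y: np.ndarray, order: int, weights: np.ndarray) -> np.ndarray:
    rows, targets = lagged(y, order)
    return targets - rows @ weights


def rms(x: np.ndarray) -> float:
    return float(np.sqrt(np.mean(x**2)))


A = np.array(
    [
        [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
        [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
    ]
)
order, channels, _ = A.shape
y = simulate_var(A, 8200)
train, test = y[:6000], y[6000:]

# Full model: each channel's prediction uses every channel's past.
w_full = fit(train, order)
# Direct-form diagonal ablation: keep only own-channel weights of the full fit.
w_diag = w_full * np.tile(np.eye(channels), (order, 1))
# Independent SISO: one scalar AR(p) least-squares fit per channel.
siso_res = np.column_stack(
    [residuals(test[:, [c]], order, fit(train[:, [c]], order))[:, 0] for c in range(channels)]
)

full_rms = rms(residuals(test, order, w_full))
diag_rms = rms(residuals(test, order, w_diag))
siso_rms = rms(siso_res)
print(f"full {full_rms:.4f}  diagonal ablation {diag_rms:.4f}  SISO {siso_rms:.4f}")
# The lag-1 block of w_full estimates -A_1 transposed (rows: source, columns: target).
print(np.round(-w_full[:channels].T, 2))
```

The full fit's residual RMS should sit near the 0.3 noise level. Both baselines stay above it because they discard the off-diagonal structure in \(A_1\) and \(A_2\).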

Causality and data use

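The coefficients are estimated from a finite training record (the first 6000 samples). Prediction on the held-out test record (the remaining 2200 samples) is online: each model calls predict() before update(y_n). As a result, \(\widehat y[n]\) uses only earlier test vectors for the full MIMO and diagonal-ablation predictors, and only earlier samples of the same channel for each SISO predictor. The first \(p\) test residuals (here \(p = 2\)) are treated as warm-up and excluded from the residual statistics.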
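The predict-before-update contract can be illustrated with a small stand-in predictor. This is a sketch under stated assumptions. OnlineVARPredictor and run_online are hypothetical names, and the class keeps a plain direct-form history instead of lattice forward/backward errors. It does, however, expose the same predict()/update() ordering described above. The final check perturbs one test vector and confirms that no earlier or same-step prediction changes.

```python
import numpy as np


class OnlineVARPredictor:
    """Hypothetical direct-form stand-in for a predict()/update() predictor."""

    def __init__(self, coefficients: np.ndarray) -> None:
        self.coefficients = np.asarray(coefficients, dtype=float)  # shape (p, c, c)
        order, channels, _ = self.coefficients.shape
        self.history = np.zeros((order, channels))  # history[k] holds y[n-1-k]

    def predict(self) -> np.ndarray:
        # yhat[n] = -sum_k A_k y[n-k], built only from samples already passed to update().
        return -np.einsum("kij,kj->i", self.coefficients, self.history)

    def update(self, sample: np.ndarray) -> np.ndarray:
        """Reveal y[n], return its prediction error, then shift the history by one lag."""
        error = np.asarray(sample, dtype=float) - self.predict()
        self.history = np.roll(self.history, 1, axis=0)
        self.history[0] = sample
        return error


def run_online(coefficients: np.ndarray, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    predictor = OnlineVARPredictor(coefficients)
    predictions = np.empty_like(x)
    errors = np.empty_like(x)
    for n, sample in enumerate(x):
        predictions[n] = predictor.predict()  # y[n] has not been seen yet
        errors[n] = predictor.update(sample)  # only now is y[n] revealed
    return predictions, errors


A = np.array(
    [
        [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
        [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
    ]
)
rng = np.random.default_rng(1)
x = rng.normal(size=(50, 3))
pred, err = run_online(A, x)

# Causality check: perturbing y[30] must leave predictions 0..30 unchanged.
x_perturbed = x.copy()
x_perturbed[30] += 10.0
pred_perturbed, _ = run_online(A, x_perturbed)
print(np.array_equal(pred[:31], pred_perturbed[:31]))  # True
print(np.allclose(pred[31], pred_perturbed[31]))  # False: y[30] enters yhat[31]
```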

What this example verifies

This example demonstrates why a MIMO predictor can be preferable to three independent SISO predictors. On the held-out segment of the coupled VAR stream, the full matrix predictor should reduce residual RMS and, more importantly, reduce cross-channel correlation in the residuals relative to the diagonal-ablation and independent SISO baselines.

How to read the result

Compare the residual RMS and residual-correlation plots. When the data really contain cross-channel dynamics, the full MIMO residual should be both smaller and less cross-correlated than the independent SISO residual. The off-diagonal residual-correlation reduction is often a clearer MIMO diagnostic than scalar RMS alone. In the captured run below, RMS improves by 5.57% while the mean absolute off-diagonal residual correlation drops by 68.42%.
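The two diagnostics can be reproduced with plain NumPy. In this sketch, mean_abs_offdiag_corr and relative_reduction are illustrative names. np.corrcoef with rowvar=False gives the same matrix as the script's normalized_covariance helper, apart from that helper's 1e-30 guard. The first part recomputes the reported percentages from the printed values in this example's captured stdout. The second part contrasts independent channels with linearly mixed ones to show what the correlation metric responds to.

```python
import numpy as np


def mean_abs_offdiag_corr(residuals: np.ndarray) -> float:
    """Mean |correlation| over distinct channel pairs; columns are channels."""
    corr = np.corrcoef(residuals, rowvar=False)
    off_diagonal = ~np.eye(corr.shape[0], dtype=bool)
    return float(np.mean(np.abs(corr[off_diagonal])))


def relative_reduction(baseline: float, model: float) -> float:
    """Fraction by which `model` falls below `baseline` (positive means better)."""
    return (baseline - model) / baseline


# Values copied from this example's captured stdout.
rms_gain = relative_reduction(baseline=0.319416, model=0.301616)
corr_gain = relative_reduction(baseline=0.053008, model=0.016742)
print(f"RMS improvement: {100 * rms_gain:.2f}%")  # 5.57%
print(f"off-diagonal correlation reduction: {100 * corr_gain:.2f}%")  # 68.42%

# Synthetic illustration: independent channels versus linearly mixed channels.
rng = np.random.default_rng(0)
white = rng.normal(size=(5000, 3))
mixed = white @ np.array([[1.0, 0.8, 0.0], [0.0, 1.0, 0.8], [0.0, 0.0, 1.0]])
print(round(mean_abs_offdiag_corr(white), 3), round(mean_abs_offdiag_corr(mixed), 3))
```

Independent residuals give a value near zero. Residuals that still share a common component, as the SISO residuals do when cross-channel terms are left unmodelled, push the metric up.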

Run command

python examples/online_coupled_mimo_vs_siso.py

Run status

Return code: 0

Captured stdout

channels: 3
order: 2
training samples: 6000
test samples: 2200
true companion spectral radius: 0.749380
fitted companion spectral radius: 0.748557
full MIMO reflection norms: [0.710496 0.24016 ]
full MIMO residual RMS: 0.301616
diagonal-ablation residual RMS: 0.320645
independent SISO residual RMS: 0.319416
relative RMS improvement vs independent SISO: 5.57%
mean abs off-diagonal residual correlation, full MIMO: 0.016742
mean abs off-diagonal residual correlation, independent SISO: 0.053008
off-diagonal residual correlation reduction vs independent SISO: 68.42%
causal contract: prediction is requested before update(y_n) for every test vector

Figures

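True lag matrices A[1] and A[2] (source channel versus target channel): online_coupled_mimo_coefficient_matrices.png

Online prediction trace for channel 0, observed versus full MIMO and independent SISO predictions: online_coupled_mimo_prediction_trace.png

Normalized residual covariance (correlation) for full MIMO and independent SISO: online_coupled_mimo_residual_covariance.png

Residual RMS for full MIMO, diagonal ablation, and independent SISO: online_coupled_mimo_rms_comparison.png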

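Generated data files

online_coupled_mimo_vs_siso_summary.csv: per-channel residual RMS and the model-level mean absolute off-diagonal residual correlation for full MIMO, diagonal ablation, and independent SISO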

Source code

  1"""Online coupled MIMO prediction versus independent SISO baselines.
  2
  3The diagonal-MIMO tutorial checks that matrix lattice prediction reduces to
  4independent one-channel prediction when all reflection matrices are diagonal.
  5This tutorial checks the complementary case: when the training signal has true
  6cross-channel dynamics, a full online MIMO lattice predictor can use those
  7off-diagonal terms while independent SISO predictors cannot.
  8"""
  9
 10from __future__ import annotations
 11
 12import csv
 13import os
 14from pathlib import Path
 15
 16import numpy as np
 17
 18import lattice_dsp as ld
 19
 20
 21def artifact_dir() -> Path:
 22    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
 23    path.mkdir(parents=True, exist_ok=True)
 24    return path
 25
 26
 27def simulate_coupled_var(coefficients: np.ndarray, samples: int, seed: int = 71) -> np.ndarray:
 28    """Generate a stable coupled vector AR process."""
 29
 30    rng = np.random.default_rng(seed)
 31    order, channels, _ = coefficients.shape
 32    burn_in = 512
 33    x = np.zeros((samples + burn_in, channels), dtype=np.float64)
 34    noise = rng.normal(scale=0.30, size=x.shape)
 35    for n in range(order, x.shape[0]):
 36        value = noise[n].copy()
 37        for lag in range(1, order + 1):
 38            value -= coefficients[lag - 1] @ x[n - lag]
 39        x[n] = value
 40    return x[burn_in:]
 41
 42
 43def normalized_covariance(x: np.ndarray) -> np.ndarray:
 44    centered = np.asarray(x, dtype=np.float64) - np.mean(x, axis=0, keepdims=True)
 45    cov = centered.T @ centered / max(centered.shape[0] - 1, 1)
 46    scale = np.sqrt(np.outer(np.diag(cov), np.diag(cov))) + 1e-30
 47    return cov / scale
 48
 49
 50def mean_abs_offdiag(matrix: np.ndarray) -> float:
 51    mask = ~np.eye(matrix.shape[0], dtype=bool)
 52    return float(np.mean(np.abs(matrix[mask])))
 53
 54
 55def fit_full_mimo_predictor(
 56    train: np.ndarray, order: int
 57) -> tuple[ld.MultichannelARResult, ld.MIMOLatticePredictor]:
 58    r = ld.multichannel_autocorrelation(train, order=order)
 59    result = ld.block_levinson_durbin(r, order=order)
 60    return result, ld.MIMOLatticePredictor.from_levinson(result)
 61
 62
 63def fit_independent_siso_predictors(train: np.ndarray, order: int) -> list[ld.MIMOLatticePredictor]:
 64    predictors: list[ld.MIMOLatticePredictor] = []
 65    for ch in range(train.shape[1]):
 66        r = ld.multichannel_autocorrelation(train[:, [ch]], order=order)
 67        result = ld.block_levinson_durbin(r, order=order)
 68        predictors.append(ld.MIMOLatticePredictor.from_levinson(result))
 69    return predictors
 70
 71
 72def process_independent_siso(
 73    predictors: list[ld.MIMOLatticePredictor], x: np.ndarray
 74) -> tuple[np.ndarray, np.ndarray]:
 75    prediction = np.empty_like(x, dtype=np.float64)
 76    error = np.empty_like(x, dtype=np.float64)
 77    for n, sample in enumerate(x):
 78        for ch, predictor in enumerate(predictors):
 79            prediction[n, ch] = predictor.predict()[0].real
 80            error[n, ch] = predictor.update(np.array([sample[ch]]))[0].real
 81    return prediction, error
 82
 83
 84def diagonal_ablation_predictor(result: ld.MultichannelARResult) -> ld.MIMOLatticePredictor:
 85    """Keep only per-channel reflection entries from a full MIMO fit."""
 86
 87    kf = np.asarray([np.diag(np.diag(stage)) for stage in result.reflection], dtype=np.complex128)
 88    if result.backward_reflection is None:
 89        raise ValueError("block-Levinson result does not contain backward reflections")
 90    kb = np.asarray(
 91        [np.diag(np.diag(stage)) for stage in result.backward_reflection], dtype=np.complex128
 92    )
 93    return ld.MIMOLatticePredictor(kf, kb)
 94
 95
 96def residual_rms(error: np.ndarray, warmup: int) -> float:
 97    return float(np.sqrt(np.mean(np.asarray(error[warmup:]).real ** 2)))
 98
 99
100def residual_rms_by_channel(error: np.ndarray, warmup: int) -> np.ndarray:
101    return np.sqrt(np.mean(np.asarray(error[warmup:]).real ** 2, axis=0))
102
103
104def save_summary_csv(
105    out_dir: Path,
106    *,
107    warmup: int,
108    full_error: np.ndarray,
109    diagonal_error: np.ndarray,
110    siso_error: np.ndarray,
111) -> Path:
112    rows: list[dict[str, object]] = []
113    errors = {
114        "full_mimo": full_error,
115        "diagonal_ablation": diagonal_error,
116        "independent_siso": siso_error,
117    }
118    for name, err in errors.items():
119        by_channel = residual_rms_by_channel(err, warmup)
120        cov = normalized_covariance(err[warmup:].real)
121        for ch, rms in enumerate(by_channel):
122            rows.append(
123                {
124                    "model": name,
125                    "channel": ch,
126                    "residual_rms": float(rms),
127                    "mean_abs_offdiag_residual_correlation": mean_abs_offdiag(cov),
128                }
129            )
130    path = out_dir / "online_coupled_mimo_vs_siso_summary.csv"
131    with path.open("w", newline="", encoding="utf-8") as f:
132        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
133        writer.writeheader()
134        writer.writerows(rows)
135    return path
136
137
138def save_figures(
139    out_dir: Path,
140    *,
141    true_coefficients: np.ndarray,
142    test: np.ndarray,
143    full_prediction: np.ndarray,
144    siso_prediction: np.ndarray,
145    full_error: np.ndarray,
146    diagonal_error: np.ndarray,
147    siso_error: np.ndarray,
148    warmup: int,
149) -> None:
150    try:
151        import matplotlib.pyplot as plt
152    except ImportError:  # pragma: no cover - optional plotting dependency
153        print("matplotlib is not installed; skipped figures")
154        return
155
156    n_view = min(260, test.shape[0])
157    t = np.arange(n_view)
158    fig, ax = plt.subplots(figsize=(9.0, 4.3))
159    ax.plot(t, test[:n_view, 0], label="observed channel 0", linewidth=1.6)
160    ax.plot(t, full_prediction[:n_view, 0].real, label="full MIMO prediction", linewidth=1.3)
161    ax.plot(
162        t,
163        siso_prediction[:n_view, 0].real,
164        "--",
165        label="independent SISO prediction",
166        linewidth=1.2,
167    )
168    ax.set_xlabel("test sample")
169    ax.set_ylabel("amplitude")
170    ax.set_title("Online prediction: full MIMO can use cross-channel history")
171    ax.legend(loc="best")
172    ax.grid(True, alpha=0.25)
173    fig.tight_layout()
174    path = out_dir / "online_coupled_mimo_prediction_trace.png"
175    fig.savefig(path, dpi=160)
176    plt.close(fig)
177    print(f"wrote {path}")
178
179    labels = ["full MIMO", "diagonal\nablation", "independent\nSISO"]
180    values = [
181        residual_rms(full_error, warmup),
182        residual_rms(diagonal_error, warmup),
183        residual_rms(siso_error, warmup),
184    ]
185    fig, ax = plt.subplots(figsize=(7.2, 4.2))
186    ax.bar(np.arange(len(values)), values)
187    ax.set_xticks(np.arange(len(values)), labels)
188    ax.set_ylabel("residual RMS")
189    ax.set_title("Coupled online MIMO prediction reduces residual energy")
190    for idx, value in enumerate(values):
191        ax.text(idx, value, f"{value:.3f}", ha="center", va="bottom")
192    fig.tight_layout()
193    path = out_dir / "online_coupled_mimo_rms_comparison.png"
194    fig.savefig(path, dpi=160)
195    plt.close(fig)
196    print(f"wrote {path}")
197
198    matrices = [
199        ("full MIMO residual", normalized_covariance(full_error[warmup:].real)),
200        ("independent SISO residual", normalized_covariance(siso_error[warmup:].real)),
201    ]
202    fig, axes = plt.subplots(1, 2, figsize=(8.8, 3.8))
203    for ax, (title, matrix) in zip(axes, matrices, strict=True):
204        im = ax.imshow(matrix, vmin=-1.0, vmax=1.0)
205        ax.set_title(title)
206        ax.set_xlabel("channel")
207        ax.set_ylabel("channel")
208    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.84)
209    path = out_dir / "online_coupled_mimo_residual_covariance.png"
210    fig.savefig(path, dpi=160, bbox_inches="tight")
211    plt.close(fig)
212    print(f"wrote {path}")
213
214    order = true_coefficients.shape[0]
215    fig, axes = plt.subplots(1, order, figsize=(4.2 * order, 3.8))
216    if order == 1:
217        axes = [axes]
218    max_abs = float(np.max(np.abs(true_coefficients)))
219    for lag, ax in enumerate(axes):
220        im = ax.imshow(true_coefficients[lag], vmin=-max_abs, vmax=max_abs)
221        ax.set_title(f"true A[{lag + 1}]")
222        ax.set_xlabel("source channel")
223        ax.set_ylabel("target channel")
224    fig.colorbar(im, ax=np.ravel(axes).tolist(), shrink=0.84)
225    path = out_dir / "online_coupled_mimo_coefficient_matrices.png"
226    fig.savefig(path, dpi=160, bbox_inches="tight")
227    plt.close(fig)
228    print(f"wrote {path}")
229
230
231def main() -> None:
232    out_dir = artifact_dir()
233
234    # Stable coupled VAR(2).  The off-diagonal entries are the part independent
235    # SISO predictors cannot represent.
236    true_coefficients = np.asarray(
237        [
238            [[0.55, 0.30, 0.00], [-0.25, 0.45, 0.22], [0.18, -0.12, 0.40]],
239            [[-0.18, 0.08, 0.02], [0.05, -0.14, -0.05], [-0.03, 0.07, -0.10]],
240        ],
241        dtype=np.float64,
242    )
243    order, channels, _ = true_coefficients.shape
244    train_samples = 6000
245    test_samples = 2200
246    x = simulate_coupled_var(true_coefficients, train_samples + test_samples)
247    train = x[:train_samples]
248    test = x[train_samples:]
249    warmup = order
250
251    full_result, full_predictor = fit_full_mimo_predictor(train, order)
252    full_prediction, full_error = full_predictor.process(test)
253
254    diagonal_predictor = diagonal_ablation_predictor(full_result)
255    diagonal_prediction, diagonal_error = diagonal_predictor.process(test)
256
257    siso_predictors = fit_independent_siso_predictors(train, order)
258    siso_prediction, siso_error = process_independent_siso(siso_predictors, test)
259
260    full_rms = residual_rms(full_error, warmup)
261    diagonal_rms = residual_rms(diagonal_error, warmup)
262    siso_rms = residual_rms(siso_error, warmup)
263    relative_improvement = (siso_rms - full_rms) / max(siso_rms, 1e-30)
264
265    full_cov = normalized_covariance(full_error[warmup:].real)
266    siso_cov = normalized_covariance(siso_error[warmup:].real)
267    full_offdiag = mean_abs_offdiag(full_cov)
268    siso_offdiag = mean_abs_offdiag(siso_cov)
269    offdiag_reduction = (siso_offdiag - full_offdiag) / max(siso_offdiag, 1e-30)
270
271    csv_path = save_summary_csv(
272        out_dir,
273        warmup=warmup,
274        full_error=full_error,
275        diagonal_error=diagonal_error,
276        siso_error=siso_error,
277    )
278
279    print("channels:", channels)
280    print("order:", order)
281    print("training samples:", train_samples)
282    print("test samples:", test_samples)
283    print(
284        "true companion spectral radius:", f"{ld.companion_spectral_radius(true_coefficients):.6f}"
285    )
286    print(
287        "fitted companion spectral radius:",
288        f"{ld.companion_spectral_radius(full_result.coefficients):.6f}",
289    )
290    print("full MIMO reflection norms:", np.round(full_result.reflection_spectral_norms, 6))
291    print("full MIMO residual RMS:", f"{full_rms:.6f}")
292    print("diagonal-ablation residual RMS:", f"{diagonal_rms:.6f}")
293    print("independent SISO residual RMS:", f"{siso_rms:.6f}")
294    print("relative RMS improvement vs independent SISO:", f"{100.0 * relative_improvement:.2f}%")
295    print("mean abs off-diagonal residual correlation, full MIMO:", f"{full_offdiag:.6f}")
296    print("mean abs off-diagonal residual correlation, independent SISO:", f"{siso_offdiag:.6f}")
297    print(
298        "off-diagonal residual correlation reduction vs independent SISO:",
299        f"{100.0 * offdiag_reduction:.2f}%",
300    )
301    print("causal contract: prediction is requested before update(y_n) for every test vector")
302    print(f"wrote {csv_path}")
303
304    save_figures(
305        out_dir,
306        true_coefficients=true_coefficients,
307        test=test,
308        full_prediction=full_prediction,
309        siso_prediction=siso_prediction,
310        full_error=full_error,
311        diagonal_error=diagonal_error,
312        siso_error=siso_error,
313        warmup=warmup,
314    )
315
316
317if __name__ == "__main__":
318    main()