Causal online MIMO lattice prediction

Tutorial goal

Use block-Levinson matrix reflections in a sample-by-sample causal MIMO lattice predictor.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

This tutorial separates coefficient estimation from runtime filtering. The matrix reflection coefficients are estimated once from a finite training record, then a stateful lattice predictor runs online. At each time step it predicts the incoming vector \(y[n]\) from stored backward-error states before that vector is observed.

Key idea and equations

With forward and backward matrix reflection coefficients \(K_m\) and \(L_m\), the online vector lattice recursion is

\[f_0[n] = b_0[n] = y[n],\]
\[f_m[n] = f_{m-1}[n] + K_m b_{m-1}[n-1],\]
\[b_m[n] = b_{m-1}[n-1] + L_m f_{m-1}[n].\]

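The one-step prediction is obtained before seeing \(y[n]\). Every stage passes \(y[n]\) to \(f_p[n]\) with unit gain, while the \(K_m b_{m-1}[n-1]\) terms involve only earlier samples, so \(f_p[n] = y[n] + (\text{terms in samples } < n)\). Evaluating the same recursion with \(y[n]=0\) and negating the result therefore gives the prediction: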

\[\hat y[n] = -f_p[n]\big|_{y[n]=0}.\]

After the true vector is observed, \(f_p[n]=y[n]-\hat y[n]\) is the forward prediction error and the backward-error states are updated. This is causal in the prediction sense: \(\hat y[n]\) depends only on stored states from samples \(< n\).

Causality and data use

The block-Levinson fit is a batch estimation step. The MIMOLatticePredictor object created from those reflection matrices is a runtime object: predict() uses only previous vectors, and update(y_n) consumes the current vector afterward.

What this example verifies

This exercises the online contract of MIMOLatticePredictor: each prediction is requested before the current vector is observed, and update(y_n) then consumes that vector and advances the backward-error state. The online residual is compared with the direct vector-autoregressive (VAR) residual computed from the same block-Levinson fit, and the relative difference between the two is printed.

How to read the result

Check that the printed "online lattice/direct AR residual difference" is close to zero, i.e. that the online lattice residual matches the direct VAR residual from the same block-Levinson fit. Then inspect the trace, reflection-norm, and residual-covariance figures.

Run command

python examples/causal_mimo_lattice_prediction.py

Source code

  1"""Online causal MIMO lattice prediction from block Levinson reflections."""
  2
  3from __future__ import annotations
  4
  5import os
  6from pathlib import Path
  7
  8import numpy as np
  9
 10import lattice_dsp as ld
 11
 12
 13def _artifact_dir() -> Path:
 14    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
 15    path.mkdir(parents=True, exist_ok=True)
 16    return path
 17
 18
 19def simulate_var(coefficients: np.ndarray, samples: int = 12000, seed: int = 31) -> np.ndarray:
 20    """Generate a stable coupled vector AR process."""
 21
 22    rng = np.random.default_rng(seed)
 23    order, channels, _ = coefficients.shape
 24    x = np.zeros((samples + 512, channels), dtype=np.float64)
 25    noise = rng.normal(scale=0.35, size=x.shape)
 26    for n in range(order, x.shape[0]):
 27        value = noise[n].copy()
 28        for lag in range(1, order + 1):
 29            value -= coefficients[lag - 1] @ x[n - lag]
 30        x[n] = value
 31    return x[512:]
 32
 33
 34def _normalized_covariance(x: np.ndarray) -> np.ndarray:
 35    centered = x - np.mean(x, axis=0, keepdims=True)
 36    cov = centered.T @ centered.conj() / max(x.shape[0] - 1, 1)
 37    scale = np.sqrt(np.outer(np.real(np.diag(cov)), np.real(np.diag(cov)))) + 1e-30
 38    return np.real(cov / scale)
 39
 40
 41def _save_figures(
 42    *,
 43    x: np.ndarray,
 44    prediction: np.ndarray,
 45    error: np.ndarray,
 46    forward_norms: np.ndarray,
 47    backward_norms: np.ndarray,
 48) -> None:
 49    try:
 50        import matplotlib.pyplot as plt
 51    except ImportError:  # pragma: no cover - optional plotting dependency
 52        print("matplotlib is not installed; skipped figures")
 53        return
 54
 55    out_dir = _artifact_dir()
 56    n_view = min(240, x.shape[0])
 57
 58    fig, ax = plt.subplots(figsize=(8.4, 4.0))
 59    ax.plot(np.arange(n_view), x[:n_view, 0], label="observed channel 0")
 60    ax.plot(np.arange(n_view), prediction[:n_view, 0].real, label="one-step prediction")
 61    ax.set_xlabel("sample")
 62    ax.set_ylabel("amplitude")
 63    ax.set_title("Causal MIMO lattice prediction uses only previous vectors")
 64    ax.legend(loc="best")
 65    fig.tight_layout()
 66    path = out_dir / "causal_mimo_lattice_prediction_trace.png"
 67    fig.savefig(path, dpi=160)
 68    plt.close(fig)
 69    print(f"wrote {path}")
 70
 71    fig, ax = plt.subplots(figsize=(7.0, 4.0))
 72    stages = np.arange(1, len(forward_norms) + 1)
 73    ax.plot(stages, forward_norms, marker="o", label="forward K")
 74    ax.plot(stages, backward_norms, marker="s", label="backward L")
 75    ax.axhline(1.0, linestyle="--", linewidth=1.0)
 76    ax.set_xlabel("lattice stage")
 77    ax.set_ylabel("spectral norm")
 78    ax.set_title("Matrix reflection norms for online MIMO prediction")
 79    ax.set_xticks(stages)
 80    ax.legend(loc="best")
 81    fig.tight_layout()
 82    path = out_dir / "causal_mimo_lattice_reflection_norms.png"
 83    fig.savefig(path, dpi=160)
 84    plt.close(fig)
 85    print(f"wrote {path}")
 86
 87    input_cov = _normalized_covariance(x)
 88    error_cov = _normalized_covariance(error)
 89    fig, axes = plt.subplots(1, 2, figsize=(8.5, 3.8))
 90    for ax, title, matrix in (
 91        (axes[0], "input normalized covariance", input_cov),
 92        (axes[1], "prediction-error normalized covariance", error_cov),
 93    ):
 94        im = ax.imshow(matrix, vmin=-1.0, vmax=1.0)
 95        ax.set_title(title)
 96        ax.set_xlabel("channel")
 97        ax.set_ylabel("channel")
 98    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.82)
 99    path = out_dir / "causal_mimo_lattice_residual_covariance.png"
100    fig.savefig(path, dpi=160, bbox_inches="tight")
101    plt.close(fig)
102    print(f"wrote {path}")
103
104
105def main() -> None:
106    true_coefficients = np.asarray(
107        [
108            [[0.34, 0.08, -0.03], [-0.05, 0.30, 0.06], [0.02, -0.06, 0.27]],
109            [[-0.12, 0.03, 0.01], [0.02, -0.10, -0.02], [0.00, 0.04, -0.08]],
110        ],
111        dtype=np.float64,
112    )
113    x = simulate_var(true_coefficients)
114    order = true_coefficients.shape[0]
115
116    # Offline/batch estimation step: obtain matrix reflection coefficients from a finite record.
117    r = ld.multichannel_autocorrelation(x, order=order)
118    levinson = ld.block_levinson_durbin(r, order=order)
119
120    # Online/runtime step: predict each vector before updating the state with that vector.
121    predictor = ld.MIMOLatticePredictor.from_levinson(levinson)
122    prediction, error = predictor.process(x)
123    direct_error = ld.multichannel_prediction_error(x, levinson.coefficients)
124    online_direct_difference = np.linalg.norm(error[order:] - direct_error) / max(
125        np.linalg.norm(direct_error), 1e-30
126    )
127
128    input_cov = np.cov(x.T)
129    error_cov = np.cov(error[order:].real.T)
130    input_offdiag = (np.sum(np.abs(input_cov)) - np.trace(np.abs(input_cov))) / (
131        x.shape[1] * (x.shape[1] - 1)
132    )
133    error_offdiag = (np.sum(np.abs(error_cov)) - np.trace(np.abs(error_cov))) / (
134        x.shape[1] * (x.shape[1] - 1)
135    )
136
137    print("channels:", x.shape[1])
138    print("order:", order)
139    print(
140        "companion spectral radius:", f"{ld.companion_spectral_radius(levinson.coefficients):.6f}"
141    )
142    print("forward reflection norms:", np.round(levinson.reflection_spectral_norms, 6))
143    print("backward reflection norms:", np.round(levinson.backward_reflection_spectral_norms, 6))
144    print("online lattice/direct AR residual difference:", f"{online_direct_difference:.3e}")
145    print(
146        "input/error mean absolute off-diagonal covariance:",
147        f"{input_offdiag:.4f}",
148        f"{error_offdiag:.4f}",
149    )
150    print(
151        "takeaway: after batch coefficient estimation, the MIMO lattice predictor is causal and online"
152    )
153
154    _save_figures(
155        x=x,
156        prediction=prediction,
157        error=error[order:],
158        forward_norms=levinson.reflection_spectral_norms,
159        backward_norms=levinson.backward_reflection_spectral_norms,
160    )
161
162
if __name__ == "__main__":
    # Run the demo only when executed as a script; importing this module does not run it.
    main()