Coupled MIMO matrix-lattice filtering

Tutorial goal

Apply a matrix-lattice all-pass filter to a coupled, complex-valued MIMO signal block and verify that the streaming output preserves energy.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

This tutorial moves from static frequency-response diagnostics to a signal-processing use case. A coupled, complex-valued multichannel signal is transformed by the causal OnlineMatrixLatticeAllPass runtime, and a finite-record time-domain adjoint then checks that the input can be reconstructed from the transformed block. The example shows that the matrix-lattice response preserves energy while still mixing channels in a frequency-dependent way.

Key idea and equations

The matrix-lattice response \(G(z)\) is designed as an all-pass multichannel transform, i.e. it is unitary at every frequency:

\[G(e^{j\omega})^H G(e^{j\omega}) = I \quad \text{for all } \omega.\]
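As a minimal sketch of what the unitarity condition means numerically, the code below builds a small FIR paraunitary matrix filter from degree-one stages \(I - P + z^{-1}P\) (with \(P = vv^H\) a rank-one projector) applied to a random unitary, evaluates \(G(e^{j\omega}) = \sum_k H_k e^{-j\omega k}\) on a frequency grid, and reports \(\max_\omega \lvert G^H G - I\rvert\). The FIR construction is a stand-in chosen for brevity, not the matrix-lattice parametrization used by lattice_dsp, and the helper names are hypothetical.

```python
import numpy as np


def random_unitary(n: int, rng: np.random.Generator) -> np.ndarray:
    # The Q factor of a complex Gaussian matrix is unitary.
    q, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))
    return q


def paraunitary_fir(channels: int, stages: int, seed: int = 0) -> np.ndarray:
    """Return coefficients H_k, shape (stages + 1, c, c), of a paraunitary FIR G(z)."""
    rng = np.random.default_rng(seed)
    coeffs = [random_unitary(channels, rng)]
    for _ in range(stages):
        v = rng.normal(size=channels) + 1j * rng.normal(size=channels)
        v /= np.linalg.norm(v)
        p = np.outer(v, v.conj())                 # Hermitian rank-one projector
        block = [np.eye(channels) - p, p]         # V(z) = (I - P) + z^{-1} P
        new = [np.zeros((channels, channels), dtype=np.complex128)
               for _ in range(len(coeffs) + 1)]
        for i, a in enumerate(coeffs):            # polynomial product V(z) G(z)
            for j, b in enumerate(block):
                new[i + j] += b @ a
        coeffs = new
    return np.array(coeffs)


def frequency_response(h: np.ndarray, omega: np.ndarray) -> np.ndarray:
    # G(e^{jw}) = sum_k H_k e^{-jwk}, evaluated for every w in omega.
    phase = np.exp(-1j * np.outer(omega, np.arange(h.shape[0])))
    return np.einsum("wk,kij->wij", phase, h)


def unitarity_error(h: np.ndarray, omega: np.ndarray) -> float:
    g = frequency_response(h, omega)
    gram = np.conj(np.swapaxes(g, 1, 2)) @ g     # G^H G at each frequency
    return float(np.max(np.abs(gram - np.eye(h.shape[1]))))


h = paraunitary_fir(channels=3, stages=5, seed=202)
omega = np.linspace(0.0, np.pi, 512)
print(h.shape, unitarity_error(h, omega))
```

Scaling the same coefficients by any factor other than one breaks the condition, which is a quick way to see that the check is sensitive.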

The forward online runtime applies the causal convolution

\[y[n] = \sum_{k\ge 0} H_k x[n-k],\]

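where \(H_k\in\mathbb{C}^{c\times c}\) are the matrix impulse-response (Markov) coefficients of \(G(z) = \sum_{k\ge 0} H_k z^{-k}\) and \(c\) is the number of channels. Energy preservation holds on the full output stream, including the decaying all-pass tail that continues after the last input sample; because the runtime drains only a finite tail, the equality is approximate in practice: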

\[\sum_n \lVert y[n]\rVert_2^2 \approx \sum_n \lVert x[n]\rVert_2^2.\]
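The role of the drained tail is easiest to see in the scalar case \(c = 1\). The sketch below uses a first-order all-pass \(A(z) = (z^{-1} - \bar a)/(1 - a z^{-1})\) with \(\lvert a\rvert < 1\) as a stand-in for the matrix lattice (an assumption for illustration) and reports the relative energy error as the drain length grows. Once the input stops, the output decays as \(a^k\), so the missing energy shrinks geometrically like \(\lvert a\rvert^{2\,\mathrm{drain}}\).

```python
import numpy as np


def first_order_allpass(x: np.ndarray, a: complex, drain: int = 0) -> np.ndarray:
    """Causal y[n] = a*y[n-1] - conj(a)*x[n] + x[n-1], i.e. A(z) = (z^-1 - conj(a)) / (1 - a z^-1)."""
    x = np.concatenate([np.asarray(x, dtype=np.complex128), np.zeros(drain, dtype=np.complex128)])
    y = np.zeros_like(x)
    y_prev = x_prev = 0j
    for n, x_n in enumerate(x):
        y[n] = a * y_prev - np.conj(a) * x_n + x_prev
        y_prev, x_prev = y[n], x_n
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=256) + 1j * rng.normal(size=256)
a = 0.9 * np.exp(0.4j)                      # pole inside the unit circle
energy_in = np.vdot(x, x).real
errors = []
for drain in (0, 32, 256):
    y = first_order_allpass(x, a, drain=drain)
    errors.append(abs(np.vdot(y, y).real - energy_in) / energy_in)
    print(f"drain={drain:4d}  relative energy error={errors[-1]:.3e}")
```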

The finite-record synthesis diagnostic applies the time-domain adjoint

\[x_{adj}[n] = \sum_{k\ge 0} H_k^H y[n+k].\]

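This adjoint is noncausal as a streaming inverse because each \(x_{adj}[n]\) needs future transformed samples \(y[n+k]\), but it is useful when the whole record is available. Because \(G\) is all-pass, its adjoint \(\tilde G(z) = \sum_{k\ge 0} H_k^H z^{k}\) satisfies \(\tilde G(z) G(z) = I\), so \(x_{adj}\) reconstructs \(x\) up to the error introduced by truncating the impulse response and the output tail.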
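A minimal sketch of the finite-record adjoint, assuming coefficients stored as an array of shape (taps, c, c) and signals as (samples, c), like the arrays in the script. The helpers forward_convolution and finite_adjoint are hypothetical stand-ins, not the lattice_dsp functions. For the one-stage paraunitary FIR \(G(z) = (I - P + z^{-1}P)U\), the adjoint reconstructs the input exactly once the full output, including its tail, is kept.

```python
import numpy as np


def forward_convolution(x: np.ndarray, h: np.ndarray, drain: int) -> np.ndarray:
    """Causal y[n] = sum_k H_k x[n-k] for 0 <= n < len(x) + drain."""
    n_out = x.shape[0] + drain
    y = np.zeros((n_out, h.shape[1]), dtype=np.complex128)
    for k, h_k in enumerate(h):
        m = min(x.shape[0], n_out - k)
        if m > 0:
            y[k:k + m] += x[:m] @ h_k.T           # row form of H_k x[n-k]
    return y


def finite_adjoint(y: np.ndarray, h: np.ndarray, output_length: int) -> np.ndarray:
    """Noncausal x_adj[n] = sum_k H_k^H y[n+k] for 0 <= n < output_length."""
    x_adj = np.zeros((output_length, h.shape[1]), dtype=np.complex128)
    for k, h_k in enumerate(h):
        m = min(output_length, y.shape[0] - k)
        if m > 0:
            x_adj[:m] += y[k:k + m] @ h_k.conj()  # row form of H_k^H y[n+k]
    return x_adj


rng = np.random.default_rng(1)
c, samples = 3, 64
u, _ = np.linalg.qr(rng.normal(size=(c, c)) + 1j * rng.normal(size=(c, c)))
v = rng.normal(size=c) + 1j * rng.normal(size=c)
v /= np.linalg.norm(v)
p = np.outer(v, v.conj())
h = np.array([(np.eye(c) - p) @ u, p @ u])       # G(z) = (I - P + z^{-1} P) U

x = rng.normal(size=(samples, c)) + 1j * rng.normal(size=(samples, c))
y = forward_convolution(x, h, drain=h.shape[0] - 1)
x_adj = finite_adjoint(y, h, output_length=samples)
print("reconstruction error:", np.linalg.norm(x_adj - x) / np.linalg.norm(x))
```

For arbitrary (not all-pass) coefficients the same pair still satisfies the inner-product identity \(\langle y, Gx\rangle = \langle G^{*}y, x\rangle\); only the exact reconstruction depends on the all-pass property.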

Causality and data use

The forward analysis path is causal and sample-by-sample. The reconstruction check is a finite-record time-domain adjoint, so it is noncausal/transductive by design and should not be confused with a causal stable inverse.
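The causal, sample-by-sample structure can be sketched with a delay line: each step consumes one input vector \(x[n]\) and returns \(y[n]\) using only \(x[n], x[n-1], \dots\). The class below is a hypothetical FIR stand-in whose process(x, drain=...) method merely mirrors the call used in the script; it is not OnlineMatrixLatticeAllPass, whose internal lattice recursion is not shown here. The final check perturbs later inputs and confirms that earlier outputs do not change.

```python
from collections import deque

import numpy as np


class StreamingMatrixFIR:
    """Hypothetical causal runtime computing y[n] = sum_k H_k x[n-k] one sample at a time."""

    def __init__(self, h: np.ndarray) -> None:
        self.h = np.asarray(h, dtype=np.complex128)     # shape (taps, c, c)
        taps, c, _ = self.h.shape
        # history[k] holds x[n-k]; the deque drops the oldest sample automatically.
        self.history = deque([np.zeros(c, dtype=np.complex128)] * taps, maxlen=taps)

    def step(self, x_n: np.ndarray) -> np.ndarray:
        self.history.appendleft(np.asarray(x_n, dtype=np.complex128))
        return sum(h_k @ x_k for h_k, x_k in zip(self.h, self.history))

    def process(self, x: np.ndarray, drain: int = 0) -> np.ndarray:
        c = self.h.shape[1]
        out = [self.step(x_n) for x_n in x]
        out += [self.step(np.zeros(c)) for _ in range(drain)]   # flush the tail
        return np.array(out)


rng = np.random.default_rng(2)
h = rng.normal(size=(4, 3, 3)) + 1j * rng.normal(size=(4, 3, 3))
x = rng.normal(size=(50, 3)) + 1j * rng.normal(size=(50, 3))
y = StreamingMatrixFIR(h).process(x, drain=3)

x_future = x.copy()
x_future[30:] += 10.0                           # change only samples n >= 30
y_future = StreamingMatrixFIR(h).process(x_future, drain=3)
print("outputs before n=30 unchanged:", np.allclose(y[:30], y_future[:30]))
```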

What this example verifies

This example verifies coupled forward filtering in streaming mode: the output is produced by the online matrix-lattice runtime, off-diagonal energy in the impulse (Markov) coefficients shows that channels are coupled, and the finite-record adjoint is reported separately as a noncausal reconstruction diagnostic.
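One way to quantify off-diagonal impulse (Markov) energy is the fraction of \(\sum_k \lVert H_k\rVert_F^2\) that lies outside the diagonals of the \(H_k\). The function below is a hypothetical metric for illustration, not one the script reports: it is 0 for a channel-by-channel (diagonal) filter and positive when channels are mixed.

```python
import numpy as np


def offdiag_impulse_energy_fraction(h: np.ndarray) -> float:
    """Share of sum_k ||H_k||_F^2 carried by off-diagonal entries (0 means no coupling)."""
    h = np.asarray(h)
    total = np.sum(np.abs(h) ** 2)
    diagonal = np.sum(np.abs(np.diagonal(h, axis1=1, axis2=2)) ** 2)
    return float((total - diagonal) / max(total, 1e-30))


rng = np.random.default_rng(3)
diagonal_filter = np.array([np.diag(rng.normal(size=3)) for _ in range(4)])
coupled_filter = rng.normal(size=(4, 3, 3)) + 1j * rng.normal(size=(4, 3, 3))
print(offdiag_impulse_energy_fraction(diagonal_filter))   # channel-by-channel filter
print(offdiag_impulse_energy_fraction(coupled_filter))    # filter that mixes channels
```

Applied to the coefficients returned by impulse_response, the same ratio would give a single coupling number alongside the covariance plots.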

How to read the result

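Look for near-zero unitarity, energy, streaming-vs-impulse, and finite-adjoint reconstruction errors. The covariance plots compare input and streaming-output channel covariance: the cross-channel structure changes (mean off-diagonal magnitude 0.508 at the input versus 0.162 at the output) even though the transform preserves norm.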
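Norm preservation and cross-channel structure are separate properties. A minimal sketch, using a constant 3×3 unitary in place of the frequency-dependent lattice (an assumption for brevity): two nearly identical input channels become almost uncorrelated after a 45° rotation, total energy is unchanged, and the mean off-diagonal normalized covariance, computed with the same normalized_covariance_magnitude definition as the script, drops.

```python
import numpy as np


def normalized_covariance_magnitude(x: np.ndarray) -> np.ndarray:
    # Same definition as in the tutorial script.
    centered = x - np.mean(x, axis=0, keepdims=True)
    cov = centered.conj().T @ centered / max(x.shape[0] - 1, 1)
    scale = np.sqrt(np.outer(np.real(np.diag(cov)), np.real(np.diag(cov)))) + 1e-30
    return np.abs(cov) / scale


def mean_offdiag(cov: np.ndarray) -> float:
    n = cov.shape[0]
    return float((np.sum(cov) - np.trace(cov)) / (n * (n - 1)))


rng = np.random.default_rng(4)
samples = 4096
s = rng.normal(size=samples) + 1j * rng.normal(size=samples)
noise = rng.normal(size=(samples, 2)) + 1j * rng.normal(size=(samples, 2))
x = np.column_stack([s, s + 0.1 * noise[:, 0], noise[:, 1]])  # channels 0 and 1 nearly equal

u = np.array([[1.0, 1.0, 0.0], [1.0, -1.0, 0.0], [0.0, 0.0, np.sqrt(2.0)]]) / np.sqrt(2.0)
y = x @ u.T                                       # y[n] = U x[n], U unitary

print("energy in/out:", np.vdot(x, x).real, np.vdot(y, y).real)
print("mean off-diagonal |cov| in/out:",
      mean_offdiag(normalized_covariance_magnitude(x)),
      mean_offdiag(normalized_covariance_magnitude(y)))
```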

Run command

python examples/coupled_mimo_lattice_filter.py

Run status

Return code: 0

Captured stdout

channels: 3
matrix-lattice order: 5
samples: 2048
tail samples for energy/reconstruction: 768
max reflection singular value: 0.7944
real scalar parameter count: 108
max unitarity error: 2.784e-14
streaming vs truncated impulse error: 3.796e-09
energy relative error with tail: 3.145e-16
finite-adjoint reconstruction error: 2.906e-09
input/output mean off-diagonal covariance: 0.508 0.162
causal analysis: y[n] is produced by OnlineMatrixLatticeAllPass before future x samples are seen
finite adjoint: reconstruction uses the whole transformed block and is noncausal

Figures

coupled mimo lattice covariance

coupled_mimo_lattice_covariance.png

coupled mimo lattice singular values

coupled_mimo_lattice_singular_values.png

coupled mimo lattice streaming trace

coupled_mimo_lattice_streaming_trace.png

Generated data files

coupled_mimo_lattice_filter_summary.csv

Source code

```python
"""Tutorial: coupled MIMO matrix-lattice filtering with streaming analysis.

A :class:`lattice_dsp.MatrixLatticeAllPass` is a square, multichannel,
frequency-dependent all-pass mixing system.  This example applies the forward
analysis transform with the causal online runtime, then uses a finite-record
noncausal adjoint in the time domain to check reconstruction.  The example is
about the matrix-lattice runtime and diagnostics, not about model reduction.
"""

from __future__ import annotations

import csv
import os
from pathlib import Path

import numpy as np

import lattice_dsp as ld


def artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def make_coupled_lattice(
    channels: int = 3, order: int = 5, seed: int = 202
) -> ld.MatrixLatticeAllPass:
    """Return a deterministic coupled matrix-lattice all-pass filter."""

    rng = np.random.default_rng(seed)
    reflections = []
    for stage in range(order):
        raw = rng.normal(size=(channels, channels)) + 1j * rng.normal(size=(channels, channels))
        reflections.append(ld.contractive_matrix_from_raw((0.18 + 0.03 * stage) * raw, margin=1e-6))
    residue = ld.unitary_polar_factor(
        rng.normal(size=(channels, channels)) + 1j * rng.normal(size=(channels, channels))
    )
    return ld.MatrixLatticeAllPass(reflections, residue=residue)


def coupled_complex_signal(samples: int = 1024, channels: int = 3, seed: int = 203) -> np.ndarray:
    """Generate a correlated complex multichannel input block."""

    rng = np.random.default_rng(seed)
    latent = rng.normal(size=(samples, 2)) + 1j * rng.normal(size=(samples, 2))
    mixing = np.array(
        [
            [1.0 + 0.0j, 0.35 - 0.10j],
            [0.55 + 0.20j, -0.65 + 0.30j],
            [-0.20 + 0.45j, 0.85 + 0.05j],
        ],
        dtype=np.complex128,
    )[:channels, :]
    x = latent @ mixing.T
    x += 0.08 * (rng.normal(size=(samples, channels)) + 1j * rng.normal(size=(samples, channels)))
    return np.ascontiguousarray(x)


def apply_matrix_lattice_streaming(
    x: np.ndarray, filt: ld.MatrixLatticeAllPass, *, tail: int = 256
) -> np.ndarray:
    """Apply the forward matrix-lattice all-pass with the causal online runtime."""

    x = np.asarray(x, dtype=np.complex128)
    if x.ndim != 2 or x.shape[1] != filt.dimension:
        raise ValueError("x must have shape (samples, filter.dimension)")
    return filt.to_online_filter().process(x, drain=tail)


def apply_matrix_lattice_finite_adjoint_time_domain(
    y: np.ndarray,
    filt: ld.MatrixLatticeAllPass,
    *,
    tail: int = 256,
    output_length: int | None = None,
) -> np.ndarray:
    """Apply the finite-record time-domain adjoint used for reconstruction checks."""

    y = np.asarray(y, dtype=np.complex128)
    if y.ndim != 2 or y.shape[1] != filt.dimension:
        raise ValueError("y must have shape (samples, filter.dimension)")
    h = filt.impulse_response(tail)
    return ld.matrix_lattice_finite_adjoint(y, h, output_length=output_length)


def normalized_covariance_magnitude(x: np.ndarray) -> np.ndarray:
    """Return absolute normalized channel covariance."""

    x = np.asarray(x, dtype=np.complex128)
    centered = x - np.mean(x, axis=0, keepdims=True)
    cov = centered.conj().T @ centered / max(x.shape[0] - 1, 1)
    scale = np.sqrt(np.outer(np.real(np.diag(cov)), np.real(np.diag(cov)))) + 1e-30
    return np.abs(cov) / scale


def main() -> None:
    out_dir = artifact_dir()
    channels = 3
    order = 5
    samples = 2048
    tail = 768

    filt = make_coupled_lattice(channels=channels, order=order)
    x = coupled_complex_signal(samples=samples, channels=channels)

    y = apply_matrix_lattice_streaming(x, filt, tail=tail)
    h = filt.impulse_response(tail)
    y_truncated = ld.matrix_lattice_impulse_response_convolution(x, h, drain=tail)
    x_hat = apply_matrix_lattice_finite_adjoint_time_domain(
        y, filt, tail=tail, output_length=samples
    )

    omega = np.linspace(0.0, np.pi, 512)
    response = filt.frequency_response(
        omega, n_threads=int(os.environ.get("LATTICE_DSP_N_THREADS", "1"))
    )
    singular_values = np.linalg.svd(response, compute_uv=False)
    unitarity_error = filt.unitarity_error(omega)
    energy_error = abs(float(np.vdot(y, y).real) - float(np.vdot(x, x).real)) / max(
        float(np.vdot(x, x).real), 1e-30
    )
    reconstruction_error = float(np.linalg.norm(x_hat - x) / max(np.linalg.norm(x), 1e-30))
    streaming_vs_truncated_error = float(
        np.linalg.norm(y - y_truncated) / max(np.linalg.norm(y), 1e-30)
    )

    cov_in = normalized_covariance_magnitude(x)
    cov_out = normalized_covariance_magnitude(y[:samples])

    summary = {
        "channels": channels,
        "order": order,
        "samples": samples,
        "tail_samples": tail,
        "max_reflection_singular_value": filt.max_reflection_singular_value(),
        "real_scalar_parameter_count": filt.parameter_count(),
        "unitarity_error": unitarity_error,
        "streaming_vs_truncated_impulse_error": streaming_vs_truncated_error,
        "energy_relative_error_with_tail": energy_error,
        "finite_adjoint_reconstruction_error": reconstruction_error,
        "input_mean_offdiag_cov": float(
            (np.sum(cov_in) - np.trace(cov_in)) / (channels * (channels - 1))
        ),
        "output_mean_offdiag_cov": float(
            (np.sum(cov_out) - np.trace(cov_out)) / (channels * (channels - 1))
        ),
    }

    csv_path = out_dir / "coupled_mimo_lattice_filter_summary.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(summary))
        writer.writeheader()
        writer.writerow(summary)

    print("channels:", channels)
    print("matrix-lattice order:", order)
    print("samples:", samples)
    print("tail samples for energy/reconstruction:", tail)
    print("max reflection singular value:", f"{summary['max_reflection_singular_value']:.4f}")
    print("real scalar parameter count:", summary["real_scalar_parameter_count"])
    print("max unitarity error:", f"{unitarity_error:.3e}")
    print("streaming vs truncated impulse error:", f"{streaming_vs_truncated_error:.3e}")
    print("energy relative error with tail:", f"{energy_error:.3e}")
    print("finite-adjoint reconstruction error:", f"{reconstruction_error:.3e}")
    print(
        "input/output mean off-diagonal covariance:",
        f"{summary['input_mean_offdiag_cov']:.3f}",
        f"{summary['output_mean_offdiag_cov']:.3f}",
    )
    print(
        "causal analysis: y[n] is produced by OnlineMatrixLatticeAllPass before future x samples are seen"
    )
    print("finite adjoint: reconstruction uses the whole transformed block and is noncausal")
    print(f"wrote {csv_path}")

    try:
        import matplotlib.pyplot as plt
    except Exception:
        print("matplotlib is not installed; skipped figures")
        return

    fig, ax = plt.subplots(figsize=(8.0, 4.5))
    for i in range(channels):
        ax.plot(omega, singular_values[:, i], label=f"s{i + 1}")
    ax.set_title("Matrix-lattice singular values over frequency")
    ax.set_xlabel("radian frequency")
    ax.set_ylabel("singular value")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig_path = out_dir / "coupled_mimo_lattice_singular_values.png"
    fig.savefig(fig_path, dpi=160)
    print(f"wrote {fig_path}")

    fig2, axes = plt.subplots(1, 2, figsize=(9.0, 4.0))
    axes[0].imshow(cov_in, vmin=0.0, vmax=1.0)
    axes[0].set_title("input |normalized covariance|")
    axes[0].set_xlabel("channel")
    axes[0].set_ylabel("channel")
    im1 = axes[1].imshow(cov_out, vmin=0.0, vmax=1.0)
    axes[1].set_title("streaming output |normalized covariance|")
    axes[1].set_xlabel("channel")
    fig2.colorbar(im1, ax=axes.ravel().tolist(), shrink=0.82)
    fig2_path = out_dir / "coupled_mimo_lattice_covariance.png"
    fig2.savefig(fig2_path, dpi=160, bbox_inches="tight")
    print(f"wrote {fig2_path}")

    fig3, ax3 = plt.subplots(figsize=(8.0, 4.5))
    span = min(320, samples)
    ax3.plot(np.real(x[:span, 0]), label="input ch0 real")
    ax3.plot(np.real(y[:span, 0]), label="streaming output ch0 real", alpha=0.8)
    ax3.set_title("Causal matrix-lattice analysis on one channel")
    ax3.set_xlabel("sample")
    ax3.set_ylabel("amplitude")
    ax3.grid(True, alpha=0.3)
    ax3.legend()
    fig3.tight_layout()
    fig3_path = out_dir / "coupled_mimo_lattice_streaming_trace.png"
    fig3.savefig(fig3_path, dpi=160)
    print(f"wrote {fig3_path}")


if __name__ == "__main__":
    main()
```