Coupled MIMO matrix-lattice filtering
Tutorial goal
Apply a matrix-lattice all-pass to a coupled complex MIMO signal block and verify streaming energy preservation.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context
This tutorial moves from static frequency-response diagnostics to a signal-processing use
case. A coupled complex multichannel signal is transformed by the causal
OnlineMatrixLatticeAllPass runtime. A finite-record time-domain adjoint then checks
reconstruction. The example verifies that the matrix-lattice response preserves energy
while still mixing channels in a frequency-dependent way.
Key idea and equations
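The matrix-lattice response \(G(z)\) is designed as an all-pass multichannel transform, i.e. it is unitary at every frequency:
\[
G(e^{j\omega})^{H}\, G(e^{j\omega}) = I_c \qquad \text{for all } \omega .
\]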
The forward online runtime applies the causal convolution
\[
y[n] = \sum_{k \ge 0} H_k\, x[n-k],
\]
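where \(H_k\in\mathbb{C}^{c\times c}\) are matrix impulse-response coefficients and \(c\) is the number of channels. Energy preservation holds on the full stream, including the decaying all-pass tail:
\[
\sum_{n} \lVert y[n] \rVert_2^{2} = \sum_{n} \lVert x[n] \rVert_2^{2} .
\]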
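The finite-record synthesis diagnostic applies the time-domain adjoint
\[
\hat{x}[n] = \sum_{k=0}^{K-1} H_k^{H}\, y[n+k],
\]
where \(K\) is the number of retained impulse-response taps.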
This adjoint is noncausal as a streaming inverse because it needs future transformed samples, but it is useful when the whole record is available.
Causality and data use
The forward analysis path is causal and sample-by-sample. The reconstruction check is a finite-record time-domain adjoint, so it is noncausal/transductive by design and should not be confused with a causal stable inverse.
What this example verifies¶
This verifies streaming coupled forward filtering. The output is produced by the online matrix-lattice runtime. The mean off-diagonal normalized output covariance shows channel coupling. The finite-record adjoint is reported separately as a noncausal reconstruction diagnostic.
How to read the result
Look for near-zero unitarity, energy, streaming-vs-impulse, and finite-adjoint reconstruction errors. The covariance plots show that the streaming block is coupled even though it is norm preserving.
Run command
python examples/coupled_mimo_lattice_filter.py
Source code
1"""Tutorial: coupled MIMO matrix-lattice filtering with streaming analysis.
2
3A :class:`lattice_dsp.MatrixLatticeAllPass` is a square, multichannel,
4frequency-dependent all-pass mixing system. This example applies the forward
5analysis transform with the causal online runtime, then uses a finite-record
6noncausal adjoint in the time domain to check reconstruction. The example is
7about the matrix-lattice runtime and diagnostics, not about model reduction.
8"""
9
10from __future__ import annotations
11
12import csv
13import os
14from pathlib import Path
15
16import numpy as np
17
18import lattice_dsp as ld
19
20
def artifact_dir() -> Path:
    """Return the example artifact directory, creating it on demand.

    The location is read from the ``LATTICE_DSP_ARTIFACT_DIR`` environment
    variable. When that is unset, ``reports/example-artifacts`` is used. Missing
    parent directories are created as well, and an existing directory is reused.
    """
    configured = os.getenv("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts")
    target = Path(configured)
    target.mkdir(exist_ok=True, parents=True)
    return target
25
26
def make_coupled_lattice(
    channels: int = 3, order: int = 5, seed: int = 202
) -> ld.MatrixLatticeAllPass:
    """Build a reproducible, cross-coupled matrix-lattice all-pass filter.

    Each of the ``order`` stages draws a complex Gaussian ``channels x channels``
    matrix and scales it by ``0.18 + 0.03 * stage``. The scaled matrix is passed
    through ``ld.contractive_matrix_from_raw`` (``margin=1e-6``) to obtain that
    stage's reflection. One more complex Gaussian draw is passed through
    ``ld.unitary_polar_factor`` to form the residue.

    All draws come, in that order, from one generator seeded with ``seed``, so
    the same arguments always yield the same filter.
    """
    generator = np.random.default_rng(seed)
    shape = (channels, channels)

    def complex_gaussian() -> np.ndarray:
        # Real part before imaginary part: the draw order fixes the result.
        return generator.normal(size=shape) + 1j * generator.normal(size=shape)

    stages = [
        ld.contractive_matrix_from_raw((0.18 + 0.03 * k) * complex_gaussian(), margin=1e-6)
        for k in range(order)
    ]
    residue = ld.unitary_polar_factor(complex_gaussian())
    return ld.MatrixLatticeAllPass(stages, residue=residue)
41
42
def coupled_complex_signal(samples: int = 1024, channels: int = 3, seed: int = 203) -> np.ndarray:
    """Generate a correlated complex multichannel input block.

    Two complex Gaussian latent sources are mixed into ``channels`` outputs by a
    fixed 3 x 2 complex mixing matrix, using its first ``channels`` rows.
    Independent complex Gaussian noise scaled by 0.08 is then added.

    Parameters
    ----------
    samples:
        Number of time samples (rows) to generate.
    channels:
        Number of output channels. Must be between 1 and 3, because the fixed
        mixing matrix only has three rows.
    seed:
        Seed for :func:`numpy.random.default_rng`; the block is reproducible.

    Returns
    -------
    numpy.ndarray
        C-contiguous ``complex128`` array of shape ``(samples, channels)``.

    Raises
    ------
    ValueError
        If ``channels`` is outside ``1..3``. Without this check, larger values
        fail with an opaque broadcasting error and zero yields an empty block.
    """
    mixing = np.array(
        [
            [1.0 + 0.0j, 0.35 - 0.10j],
            [0.55 + 0.20j, -0.65 + 0.30j],
            [-0.20 + 0.45j, 0.85 + 0.05j],
        ],
        dtype=np.complex128,
    )
    if not 1 <= channels <= mixing.shape[0]:
        raise ValueError(f"channels must be between 1 and {mixing.shape[0]} (got {channels})")
    mixing = mixing[:channels, :]

    # The draw order (latent real, latent imag, noise real, noise imag) is what
    # makes a given seed reproducible; keep it unchanged.
    rng = np.random.default_rng(seed)
    latent = rng.normal(size=(samples, 2)) + 1j * rng.normal(size=(samples, 2))
    x = latent @ mixing.T
    x += 0.08 * (rng.normal(size=(samples, channels)) + 1j * rng.normal(size=(samples, channels)))
    return np.ascontiguousarray(x)
59
60
def apply_matrix_lattice_streaming(
    x: np.ndarray, filt: ld.MatrixLatticeAllPass, *, tail: int = 256
) -> np.ndarray:
    """Filter ``x`` through the causal online runtime of ``filt``.

    ``x`` is coerced to ``complex128`` and must be 2-D with ``filt.dimension``
    columns; otherwise ``ValueError`` is raised. ``tail`` is forwarded as the
    ``drain`` argument of the online filter's ``process`` call, and that call's
    result is returned unchanged.
    """
    signal = np.asarray(x, dtype=np.complex128)
    shape_ok = signal.ndim == 2 and signal.shape[1] == filt.dimension
    if not shape_ok:
        raise ValueError("x must have shape (samples, filter.dimension)")
    runtime = filt.to_online_filter()
    return runtime.process(signal, drain=tail)
70
71
def apply_matrix_lattice_finite_adjoint_time_domain(
    y: np.ndarray,
    filt: ld.MatrixLatticeAllPass,
    *,
    tail: int = 256,
    output_length: int | None = None,
) -> np.ndarray:
    """Reconstruct a block with the finite-record time-domain adjoint of ``filt``.

    ``y`` is coerced to ``complex128`` and must be 2-D with ``filt.dimension``
    columns; otherwise ``ValueError`` is raised. The first ``tail`` impulse-response
    taps of ``filt`` are passed together with ``y`` to
    ``ld.matrix_lattice_finite_adjoint``. ``output_length`` is forwarded unchanged.
    This is a whole-record (noncausal) diagnostic, not a streaming inverse.
    """
    transformed = np.asarray(y, dtype=np.complex128)
    if not (transformed.ndim == 2 and transformed.shape[1] == filt.dimension):
        raise ValueError("y must have shape (samples, filter.dimension)")
    return ld.matrix_lattice_finite_adjoint(
        transformed, filt.impulse_response(tail), output_length=output_length
    )
86
87
def normalized_covariance_magnitude(x: np.ndarray) -> np.ndarray:
    """Return ``|cov[i, j]| / sqrt(var_i * var_j)`` for the columns of ``x``.

    Rows are samples and columns are channels. The covariance is taken after
    removing each column's mean and dividing by ``max(rows - 1, 1)``. A ``1e-30``
    guard keeps zero-variance channels from dividing by zero; such channels
    produce zeros in their off-diagonal entries.
    """
    data = np.asarray(x, dtype=np.complex128)
    deviations = data - data.mean(axis=0, keepdims=True)
    denominator = max(data.shape[0] - 1, 1)
    covariance = (deviations.conj().T @ deviations) / denominator
    variances = np.real(np.diag(covariance))
    normalizer = np.sqrt(np.outer(variances, variances)) + 1e-30
    return np.abs(covariance) / normalizer
96
97
def _mean_offdiagonal(matrix: np.ndarray) -> float:
    """Return the mean of the off-diagonal entries of a square matrix.

    Returns NaN for matrices smaller than 2x2, which have no off-diagonal
    entries; this avoids a 0/0 division.
    """
    size = matrix.shape[0]
    if size < 2:
        return float("nan")
    return float((np.sum(matrix) - np.trace(matrix)) / (size * (size - 1)))


def main() -> None:
    """Run the coupled MIMO matrix-lattice tutorial end to end.

    The tutorial:

    - builds a deterministic coupled all-pass filter;
    - filters a correlated complex block with the causal online runtime;
    - compares that output with an impulse-response convolution;
    - reconstructs the input with the noncausal finite-record adjoint;
    - writes a one-row CSV summary to ``artifact_dir()`` and prints the diagnostics;
    - saves three PNG figures when matplotlib can be imported.
    """
    out_dir = artifact_dir()
    channels = 3
    order = 5
    samples = 2048
    tail = 768

    filt = make_coupled_lattice(channels=channels, order=order)
    x = coupled_complex_signal(samples=samples, channels=channels)

    # Causal forward path; ``tail`` extra samples drain the all-pass state.
    y = apply_matrix_lattice_streaming(x, filt, tail=tail)
    h = filt.impulse_response(tail)
    y_truncated = ld.matrix_lattice_impulse_response_convolution(x, h, drain=tail)
    # Noncausal diagnostic: the adjoint consumes the whole transformed record.
    x_hat = apply_matrix_lattice_finite_adjoint_time_domain(
        y, filt, tail=tail, output_length=samples
    )

    omega = np.linspace(0.0, np.pi, 512)
    response = filt.frequency_response(
        omega, n_threads=int(os.environ.get("LATTICE_DSP_N_THREADS", "1"))
    )
    singular_values = np.linalg.svd(response, compute_uv=False)
    unitarity_error = filt.unitarity_error(omega)

    input_energy = float(np.vdot(x, x).real)
    output_energy = float(np.vdot(y, y).real)
    energy_error = abs(output_energy - input_energy) / max(input_energy, 1e-30)
    reconstruction_error = float(np.linalg.norm(x_hat - x) / max(np.linalg.norm(x), 1e-30))
    streaming_vs_truncated_error = float(
        np.linalg.norm(y - y_truncated) / max(np.linalg.norm(y), 1e-30)
    )

    cov_in = normalized_covariance_magnitude(x)
    # Only the first ``samples`` outputs are used; the drained tail is excluded.
    cov_out = normalized_covariance_magnitude(y[:samples])

    summary = {
        "channels": channels,
        "order": order,
        "samples": samples,
        "tail_samples": tail,
        "max_reflection_singular_value": filt.max_reflection_singular_value(),
        "real_scalar_parameter_count": filt.parameter_count(),
        "unitarity_error": unitarity_error,
        "streaming_vs_truncated_impulse_error": streaming_vs_truncated_error,
        "energy_relative_error_with_tail": energy_error,
        "finite_adjoint_reconstruction_error": reconstruction_error,
        "input_mean_offdiag_cov": _mean_offdiagonal(cov_in),
        "output_mean_offdiag_cov": _mean_offdiagonal(cov_out),
    }

    csv_path = out_dir / "coupled_mimo_lattice_filter_summary.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(summary))
        writer.writeheader()
        writer.writerow(summary)

    print("channels:", channels)
    print("matrix-lattice order:", order)
    print("samples:", samples)
    print("tail samples for energy/reconstruction:", tail)
    print("max reflection singular value:", f"{summary['max_reflection_singular_value']:.4f}")
    print("real scalar parameter count:", summary["real_scalar_parameter_count"])
    print("max unitarity error:", f"{unitarity_error:.3e}")
    print("streaming vs truncated impulse error:", f"{streaming_vs_truncated_error:.3e}")
    print("energy relative error with tail:", f"{energy_error:.3e}")
    print("finite-adjoint reconstruction error:", f"{reconstruction_error:.3e}")
    print(
        "input/output mean off-diagonal covariance:",
        f"{summary['input_mean_offdiag_cov']:.3f}",
        f"{summary['output_mean_offdiag_cov']:.3f}",
    )
    print(
        "causal analysis: y[n] is produced by OnlineMatrixLatticeAllPass before future x samples are seen"
    )
    print("finite adjoint: reconstruction uses the whole transformed block and is noncausal")
    print(f"wrote {csv_path}")

    # Figures are optional. A missing package and any other import-time failure
    # (for example a broken backend) both skip plotting, with an accurate message.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed; skipped figures")
        return
    except Exception as exc:
        print(f"matplotlib could not be imported ({exc!r}); skipped figures")
        return

    fig, ax = plt.subplots(figsize=(8.0, 4.5))
    for i in range(channels):
        ax.plot(omega, singular_values[:, i], label=f"s{i + 1}")
    ax.set_title("Matrix-lattice singular values over frequency")
    ax.set_xlabel("radian frequency")
    ax.set_ylabel("singular value")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig_path = out_dir / "coupled_mimo_lattice_singular_values.png"
    fig.savefig(fig_path, dpi=160)
    # Release the figure: pyplot keeps every open figure alive until closed.
    plt.close(fig)
    print(f"wrote {fig_path}")

    fig2, axes = plt.subplots(1, 2, figsize=(9.0, 4.0))
    axes[0].imshow(cov_in, vmin=0.0, vmax=1.0)
    axes[0].set_title("input |normalized covariance|")
    axes[0].set_xlabel("channel")
    axes[0].set_ylabel("channel")
    im1 = axes[1].imshow(cov_out, vmin=0.0, vmax=1.0)
    axes[1].set_title("streaming output |normalized covariance|")
    axes[1].set_xlabel("channel")
    fig2.colorbar(im1, ax=axes.ravel().tolist(), shrink=0.82)
    fig2_path = out_dir / "coupled_mimo_lattice_covariance.png"
    fig2.savefig(fig2_path, dpi=160, bbox_inches="tight")
    plt.close(fig2)
    print(f"wrote {fig2_path}")

    fig3, ax3 = plt.subplots(figsize=(8.0, 4.5))
    span = min(320, samples)
    ax3.plot(np.real(x[:span, 0]), label="input ch0 real")
    ax3.plot(np.real(y[:span, 0]), label="streaming output ch0 real", alpha=0.8)
    ax3.set_title("Causal matrix-lattice analysis on one channel")
    ax3.set_xlabel("sample")
    ax3.set_ylabel("amplitude")
    ax3.grid(True, alpha=0.3)
    ax3.legend()
    fig3.tight_layout()
    fig3_path = out_dir / "coupled_mimo_lattice_streaming_trace.png"
    fig3.savefig(fig3_path, dpi=160)
    plt.close(fig3)
    print(f"wrote {fig3_path}")


if __name__ == "__main__":
    main()