Matrix lattice all-pass response
Tutorial goal
Build a matrix all-pass lattice and verify unitary response behavior.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context
Matrix lattice filters generalize scalar reflection coefficients to matrix reflection coefficients. They are useful for compact multichannel unitary or paraunitary transforms.
Key idea and equations
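A matrix-lattice section replaces a scalar reflection coefficient with a reflection matrix \(K_i\). The contractive-stage diagnostic is
\[
\|K_i\|_2 = \sigma_{\max}(K_i) < 1 \qquad \text{for every stage } i,
\]
that is, the largest singular value of each reflection matrix must be strictly below one.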
Under the all-pass/scattering construction used here, the resulting transfer matrix \(G(z)\) should satisfy, on the unit circle,
\[
G(e^{j\omega})^{H}\, G(e^{j\omega}) = I \qquad \text{for all } \omega.
\]
Equivalently, every singular value of \(G(e^{j\omega})\) should be one. The entry-magnitude plot shows how individual channel responses vary with frequency; the singular-value and residual plots check the stronger unitary matrix property.
Causality and data use
This example verifies the all-pass transfer function on a frequency grid. It also uses MatrixLatticeAllPass.to_online_filter() to realize the same object as a sample-by-sample causal runtime. The streaming impulse-response check uses only the current input sample and the previous section states.
What this example verifies¶
This verifies that the causal online all-pass runtime matches the frequency response represented by MatrixLatticeAllPass. The streaming impulse response is truncated and transformed to the frequency domain by a direct DTFT sum on a probe grid. That transform should agree with the frequency-grid evaluator, and the singular values should remain close to one on the unit circle.
How to read the result¶
The singular-value figure should be flat at one, and the unitarity-residual figure should stay near numerical precision, provided the generated matrix reflections are contractive.
Run command¶
python examples/matrix_lattice_allpass.py
Source code¶
1"""Matrix-valued lattice/all-pass response demo."""
2
3from __future__ import annotations
4
5import os
6from pathlib import Path
7
8import numpy as np
9
10from lattice_dsp import MatrixLatticeAllPass, contractive_matrix_from_raw, unitary_polar_factor
11
12
def _artifact_dir() -> Path:
    """Return the directory that receives example figures, creating it if needed.

    The location is read from the ``LATTICE_DSP_ARTIFACT_DIR`` environment
    variable. When that variable is unset, ``reports/example-artifacts`` is used
    instead, resolved relative to the current working directory.
    """
    configured = os.getenv("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts")
    target = Path(configured)
    # parents=True builds any missing intermediate directories;
    # exist_ok=True makes repeated runs harmless.
    target.mkdir(exist_ok=True, parents=True)
    return target
17
18
def _unitarity_residual(response: np.ndarray, dimension: int) -> np.ndarray:
    """Return the spectral norm ``||H^H H - I||_2`` for each matrix in ``response``.

    The last two axes of ``response`` hold one ``dimension x dimension`` matrix.
    All leading axes (here: frequency) are preserved in the result.
    """
    gram = np.conj(np.swapaxes(response, -1, -2)) @ response
    # ord=2 with a 2-tuple axis gives the largest singular value per matrix.
    # This matches the "||HᴴH - I||₂" axis label on the residual figure.
    return np.linalg.norm(gram - np.eye(dimension), ord=2, axis=(-2, -1))


def _write_figure(plt, fig, path: Path) -> None:
    """Lay out and save ``fig`` to ``path``, always closing it, then report the file."""
    try:
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Release the figure even if saving fails, so no open figures accumulate.
        plt.close(fig)
    print(f"wrote {path}")


def _save_figures(
    omega: np.ndarray, response: np.ndarray, filt: MatrixLatticeAllPass, impulse: np.ndarray
) -> None:
    """Write the diagnostic PNG figures for the matrix all-pass example.

    Four files are written to the artifact directory:

    * singular values of ``response`` at each frequency;
    * the spectral-norm unitarity residual ``||H^H H - I||_2`` at each frequency;
    * the Frobenius norm of each streaming impulse-response block;
    * entry magnitudes of ``response`` at three grid frequencies.

    Args:
        omega: Frequency grid in rad/sample. It is indexed in step with the
            leading axis of ``response``.
        response: Transfer matrices evaluated on ``omega``. Assumed shape is
            ``(len(omega), dim, dim)`` (TODO: confirm against ``frequency_response``).
        filt: Filter whose ``dimension`` sizes the identity used in the residual.
        impulse: Impulse-response blocks. Norms are taken over axes 1 and 2,
            so axis 0 is the sample index.

    If matplotlib is not installed, a message is printed and nothing is written.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - optional plotting dependency
        print("matplotlib is not installed; skipped figures")
        return

    out_dir = _artifact_dir()
    singular_values = np.linalg.svd(response, compute_uv=False)
    unitarity_error = _unitarity_residual(response, filt.dimension)

    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    for idx in range(singular_values.shape[1]):
        ax.plot(omega, singular_values[:, idx], label=f"σ{idx + 1}")
    ax.set_xlabel("rad/sample")
    ax.set_ylabel("singular value")
    ax.set_title("All singular values stay at one for an all-pass response")
    ax.legend(loc="best")
    _write_figure(plt, fig, out_dir / "matrix_lattice_allpass_singular_values.png")

    fig, ax = plt.subplots(figsize=(7.0, 3.8))
    # Clamp the residual so an exact zero still renders on the log scale.
    ax.semilogy(omega, np.maximum(unitarity_error, 1e-18))
    ax.set_xlabel("rad/sample")
    ax.set_ylabel("||HᴴH - I||₂")
    ax.set_title("Frequency-by-frequency unitarity residual")
    _write_figure(plt, fig, out_dir / "matrix_lattice_allpass_unitarity_error.png")

    fig, ax = plt.subplots(figsize=(7.0, 3.8))
    ax.semilogy(np.maximum(np.linalg.norm(impulse, axis=(1, 2)), 1e-18))
    ax.set_xlabel("sample")
    ax.set_ylabel("impulse-response block Frobenius norm")
    ax.set_title("Causal online realization has a decaying all-pass tail")
    _write_figure(plt, fig, out_dir / "matrix_lattice_allpass_streaming_impulse.png")

    freq_indices = [0, len(omega) // 4, len(omega) // 2]
    fig, axes = plt.subplots(1, len(freq_indices), figsize=(9.2, 3.2))
    for ax, idx in zip(axes, freq_indices, strict=True):
        im = ax.imshow(np.abs(response[idx]))
        ax.set_title(f"|H(e^{{jω}})| at ω={omega[idx]:.2f}")
        ax.set_xlabel("input")
        ax.set_ylabel("output")
        fig.colorbar(im, ax=ax, shrink=0.75)
    _write_figure(plt, fig, out_dir / "matrix_lattice_allpass_entry_magnitudes.png")
81
82
# Build a random dim x dim matrix lattice of the given order.
# Each scaled complex Gaussian matrix is passed through contractive_matrix_from_raw
# to get a reflection. The residue is the unitary polar factor of another
# random complex matrix.
rng = np.random.default_rng(123)
dim = 3
order = 4

reflections = [
    contractive_matrix_from_raw(
        0.35 * (rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)))
    )
    for _ in range(order)
]
residue = unitary_polar_factor(rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)))

filt = MatrixLatticeAllPass(reflections, residue=residue)
omega = np.linspace(0.0, np.pi, 256)
response = filt.frequency_response(omega)

# Streaming realization check: process one impulse per input channel and compare
# the truncated impulse-response frequency response with the direct evaluator.
n_impulse = 512
impulse_response = np.empty((n_impulse, dim, dim), dtype=np.complex128)
for input_channel in range(dim):
    # Use a new runtime per input channel so one impulse cannot leave state behind
    # for the next. This assumes to_online_filter() returns an independent runtime;
    # verify against the lattice_dsp API.
    runtime = filt.to_online_filter()
    impulse = np.zeros((n_impulse, dim), dtype=np.complex128)
    impulse[0, input_channel] = 1.0
    y = runtime.process(impulse)
    # Column `input_channel` of every impulse-response block is this channel's output.
    impulse_response[:, :, input_channel] = y
streaming_probe = np.linspace(0.0, np.pi, 32)
# Truncated DTFT: H(e^{jw}) ~= sum_{n < n_impulse} h[n] e^{-jwn}.
powers = np.exp(-1j * np.outer(streaming_probe, np.arange(n_impulse)))
streaming_response = np.einsum("wn,nij->wij", powers, impulse_response)
# Evaluate the reference response once; it serves as both the comparison target
# and the normalizer of the relative error.
reference_response = filt.frequency_response(streaming_probe)
streaming_relative_error = np.linalg.norm(
    streaming_response - reference_response
) / np.linalg.norm(reference_response)

print("dimension:", filt.dimension)
print("order:", filt.order)
print("max reflection singular value:", round(filt.max_reflection_singular_value(), 6))
print("real scalar parameter count:", filt.parameter_count())
print("max unitarity error:", f"{filt.unitarity_error(omega):.3e}")
print("streaming impulse/frequency relative error:", f"{streaming_relative_error:.3e}")
print("response shape:", response.shape)

_save_figures(omega, response, filt, impulse_response)