Matrix lattice all-pass response

Tutorial goal

Build a matrix all-pass lattice and verify unitary response behavior.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

Matrix lattice filters generalize scalar reflection coefficients to matrix reflection coefficients. They are useful for compact multichannel unitary or paraunitary transforms.

Key idea and equations

A matrix-lattice section replaces a scalar reflection coefficient with a reflection matrix \(K_i\). Each stage is contractive (strictly passive) when its reflection matrix has spectral norm below one:

\[\lVert K_i\rVert_2 < 1.\]

Under the all-pass/scattering construction used here, the resulting transfer matrix \(G(z)\) should satisfy, on the unit circle,

\[G(e^{j\omega})^H G(e^{j\omega}) = I.\]

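Equivalently, every singular value of \(G(e^{j\omega})\) should be one. The entry-magnitude plot shows \(|G|\) entry by entry at three frequencies (\(\omega = 0\), roughly \(\pi/4\), and roughly \(\pi/2\)), so it shows how individual channel responses vary with frequency. The singular-value and residual plots check the stronger property that the matrix as a whole is unitary.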

Causality and data use

This example verifies the all-pass transfer function on a frequency grid and also uses MatrixLatticeAllPass.to_online_filter() to realize the same object as a sample-by-sample causal runtime. The streaming impulse-response check uses only current input and previous section states.

What this example verifies

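This verifies that the causal online all-pass runtime matches the frequency response represented by MatrixLatticeAllPass. The discrete-time Fourier transform of the truncated streaming impulse response agrees with the frequency-grid evaluator. The example evaluates this transform directly on a 32-point grid rather than with an FFT. The singular values also remain close to one on the unit circle.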

How to read the result

If the generated matrix reflections are contractive, two things should hold. The singular-value figure should be flat at one. The unitarity-residual figure should sit near machine precision at every frequency.

Run command

python examples/matrix_lattice_allpass.py

Source code

  1"""Matrix-valued lattice/all-pass response demo."""
  2
  3from __future__ import annotations
  4
  5import os
  6from pathlib import Path
  7
  8import numpy as np
  9
 10from lattice_dsp import MatrixLatticeAllPass, contractive_matrix_from_raw, unitary_polar_factor
 11
 12
 13def _artifact_dir() -> Path:
 14    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
 15    path.mkdir(parents=True, exist_ok=True)
 16    return path
 17
 18
 19def _save_figures(
 20    omega: np.ndarray, response: np.ndarray, filt: MatrixLatticeAllPass, impulse: np.ndarray
 21) -> None:
 22    try:
 23        import matplotlib.pyplot as plt
 24    except ImportError:  # pragma: no cover - optional plotting dependency
 25        print("matplotlib is not installed; skipped figures")
 26        return
 27
 28    out_dir = _artifact_dir()
 29    singular_values = np.linalg.svd(response, compute_uv=False)
 30    eye = np.eye(filt.dimension)
 31    unitarity_error = np.array([np.linalg.norm(hi.conj().T @ hi - eye) for hi in response])
 32
 33    fig, ax = plt.subplots(figsize=(7.2, 4.0))
 34    for idx in range(singular_values.shape[1]):
 35        ax.plot(omega, singular_values[:, idx], label=f{idx + 1}")
 36    ax.set_xlabel("rad/sample")
 37    ax.set_ylabel("singular value")
 38    ax.set_title("All singular values stay at one for an all-pass response")
 39    ax.legend(loc="best")
 40    fig.tight_layout()
 41    path = out_dir / "matrix_lattice_allpass_singular_values.png"
 42    fig.savefig(path, dpi=160)
 43    plt.close(fig)
 44    print(f"wrote {path}")
 45
 46    fig, ax = plt.subplots(figsize=(7.0, 3.8))
 47    ax.semilogy(omega, np.maximum(unitarity_error, 1e-18))
 48    ax.set_xlabel("rad/sample")
 49    ax.set_ylabel("||HᴴH - I||₂")
 50    ax.set_title("Frequency-by-frequency unitarity residual")
 51    fig.tight_layout()
 52    path = out_dir / "matrix_lattice_allpass_unitarity_error.png"
 53    fig.savefig(path, dpi=160)
 54    plt.close(fig)
 55    print(f"wrote {path}")
 56
 57    fig, ax = plt.subplots(figsize=(7.0, 3.8))
 58    ax.semilogy(np.maximum(np.linalg.norm(impulse, axis=(1, 2)), 1e-18))
 59    ax.set_xlabel("sample")
 60    ax.set_ylabel("impulse-response block Frobenius norm")
 61    ax.set_title("Causal online realization has a decaying all-pass tail")
 62    fig.tight_layout()
 63    path = out_dir / "matrix_lattice_allpass_streaming_impulse.png"
 64    fig.savefig(path, dpi=160)
 65    plt.close(fig)
 66    print(f"wrote {path}")
 67
 68    freq_indices = [0, len(omega) // 4, len(omega) // 2]
 69    fig, axes = plt.subplots(1, len(freq_indices), figsize=(9.2, 3.2))
 70    for ax, idx in zip(axes, freq_indices, strict=True):
 71        im = ax.imshow(np.abs(response[idx]))
 72        ax.set_title(f"|H(e^{{}})| at ω={omega[idx]:.2f}")
 73        ax.set_xlabel("input")
 74        ax.set_ylabel("output")
 75        fig.colorbar(im, ax=ax, shrink=0.75)
 76    fig.tight_layout()
 77    path = out_dir / "matrix_lattice_allpass_entry_magnitudes.png"
 78    fig.savefig(path, dpi=160)
 79    plt.close(fig)
 80    print(f"wrote {path}")
 81
 82
 83rng = np.random.default_rng(123)
 84dim = 3
 85order = 4
 86
 87reflections = [
 88    contractive_matrix_from_raw(
 89        0.35 * (rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)))
 90    )
 91    for _ in range(order)
 92]
 93residue = unitary_polar_factor(rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)))
 94
 95filt = MatrixLatticeAllPass(reflections, residue=residue)
 96omega = np.linspace(0.0, np.pi, 256)
 97response = filt.frequency_response(omega)
 98
 99# Streaming realization check: process one impulse per input channel and compare
100# the truncated impulse-response frequency response with the direct evaluator.
101n_impulse = 512
102impulse_response = np.empty((n_impulse, dim, dim), dtype=np.complex128)
103for input_channel in range(dim):
104    runtime = filt.to_online_filter()
105    impulse = np.zeros((n_impulse, dim), dtype=np.complex128)
106    impulse[0, input_channel] = 1.0
107    y = runtime.process(impulse)
108    impulse_response[:, :, input_channel] = y
109streaming_probe = np.linspace(0.0, np.pi, 32)
110powers = np.exp(-1j * np.outer(streaming_probe, np.arange(n_impulse)))
111streaming_response = np.einsum("wn,nij->wij", powers, impulse_response)
112streaming_relative_error = np.linalg.norm(
113    streaming_response - filt.frequency_response(streaming_probe)
114) / np.linalg.norm(filt.frequency_response(streaming_probe))
115
116print("dimension:", filt.dimension)
117print("order:", filt.order)
118print("max reflection singular value:", round(filt.max_reflection_singular_value(), 6))
119print("real scalar parameter count:", filt.parameter_count())
120print("max unitarity error:", f"{filt.unitarity_error(omega):.3e}")
121print("streaming impulse/frequency relative error:", f"{streaming_relative_error:.3e}")
122print("response shape:", response.shape)
123
124_save_figures(omega, response, filt, impulse_response)