Unitary convolution block for ML-style stability

Tutorial goal

Show a streaming norm-preserving convolution-like block motivated by stable ML layers.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

Orthogonal/unitary transforms can improve numerical stability in learned models. This demo connects matrix-lattice ideas to norm-preserving convolution blocks as a DSP demonstration, not a full ML framework. Unlike a circular FFT layer, the forward map here is run by the causal online matrix-lattice runtime.

Key idea and equations

The streaming block applies a causal multichannel convolution

\[y[n] = \sum_{k\ge 0} H_k x[n-k].\]
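The sum above can be sketched directly in NumPy. This is a minimal illustration for a finite list of taps, not the lattice runtime the demo uses; the function name `causal_mimo_convolve` and the `(taps, channels, channels)` / `(samples, channels)` array layouts are assumptions made for this sketch.

```python
import numpy as np


def causal_mimo_convolve(taps, x, tail=0):
    """Evaluate y[n] = sum_k H_k x[n-k] with x[n] = 0 outside the record.

    taps: (K, C_out, C_in) matrix taps H_0..H_{K-1} (hypothetical layout)
    x:    (N, C_in) input record
    Returns an array of shape (N + tail, C_out).
    """
    taps = np.asarray(taps)
    x = np.asarray(x)
    n_samples = x.shape[0]
    y = np.zeros((n_samples + tail, taps.shape[1]), dtype=np.result_type(taps, x))
    for k, h in enumerate(taps):
        # Tap k sees the input delayed by k samples, then mixes channels with H_k.
        stop = min(n_samples, n_samples + tail - k)
        if stop <= 0:
            break
        y[k : k + stop] += x[:stop] @ h.T
    return y


# Two taps: H_0 passes channels through, H_1 swaps the delayed channels.
taps = np.array([[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]])
x = np.array([[1.0, 2.0], [3.0, 4.0]])
y = causal_mimo_convolve(taps, x, tail=1)
print(y)
```

With `tail=1` the output has one extra row, which holds the contribution of the last input sample through `H_1`.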

The all-pass condition

\[H(e^{j\omega})^H H(e^{j\omega}) = I\]

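makes the map an isometry on \(\ell_2\), so the signal norm is preserved over the full stream:

\[\lVert y\rVert_2 \approx \lVert x\rVert_2,\]

after appending enough zero-input samples to include the tail. The approximation sign reflects that the tail of the response is cut off after a finite number of samples. The finite-record adjoint diagnostic uses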

\[x_{adj}[n] = \sum_{k\ge 0} H_k^H y[n+k],\]

which is useful for reconstruction checks but cannot serve as an online inverse: it is noncausal, because \(x_{adj}[n]\) depends on the future outputs \(y[n+k]\).
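The adjoint formula can be tried on a small toy filter. The sketch below assumes a two-tap all-pass built as \(H(z) = U_1\,\mathrm{diag}(1, z^{-1}, z^{-1})\,U_0\) with random unitary \(U_0, U_1\); this toy filter and the helper names `forward` and `finite_adjoint` are illustrative stand-ins, not the library's `matrix_lattice_finite_adjoint`.

```python
import numpy as np

rng = np.random.default_rng(1)


def random_unitary(n):
    # The Q factor of a complex Gaussian matrix is unitary.
    q, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))
    return q


# Toy all-pass (assumption for this sketch): channels 1 and 2 are delayed by one sample.
u0, u1 = random_unitary(3), random_unitary(3)
delayed = np.array([0.0, 1.0, 1.0])
taps = np.stack([u1 @ np.diag(1.0 - delayed) @ u0, u1 @ np.diag(delayed) @ u0])


def forward(taps, x):
    # y[n] = sum_k H_k x[n-k], keeping the full tail of len(taps) - 1 samples.
    y = np.zeros((x.shape[0] + taps.shape[0] - 1, x.shape[1]), dtype=complex)
    for k, h in enumerate(taps):
        y[k : k + x.shape[0]] += x @ h.T
    return y


def finite_adjoint(taps, y, output_length):
    # x_adj[n] = sum_k H_k^H y[n+k]; it reads future samples, so it is offline only.
    x_adj = np.zeros((output_length, y.shape[1]), dtype=complex)
    for k, h in enumerate(taps):
        stop = min(output_length, y.shape[0] - k)
        if stop > 0:
            x_adj[:stop] += y[k : k + stop] @ np.conj(h)
    return x_adj


x = rng.normal(size=(128, 3)) + 1j * rng.normal(size=(128, 3))
y = forward(taps, x)
x_adj = finite_adjoint(taps, y, output_length=x.shape[0])
reconstruction_error = np.linalg.norm(x_adj - x) / np.linalg.norm(x)

# Adjoint identity <T x, z> = <x, T^H z> for an arbitrary z of the output shape.
z = rng.normal(size=y.shape) + 1j * rng.normal(size=y.shape)
lhs = np.vdot(forward(taps, x), z)
rhs = np.vdot(x, finite_adjoint(taps, z, output_length=x.shape[0]))

print(f"relative reconstruction error: {reconstruction_error:.3e}")
print(f"adjoint identity mismatch: {abs(lhs - rhs):.3e}")
```

Because the toy filter is FIR, its tail is finite and fully retained, so the adjoint recovers the input up to rounding. For a filter with a long or infinite tail, the reconstruction error also depends on how much of the tail is kept.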

Causality and data use

The forward map is causal and streaming. The adjoint reconstruction check is time-domain but finite-block/noncausal, which matches how adjoints are used in offline ML-style diagnostics.
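The causal/noncausal distinction can be tested numerically: perturb only "future" samples and see which earlier results move. This is a sketch with a fixed three-tap filter chosen purely for illustration (causality does not require the all-pass property); `forward` and `finite_adjoint` are hypothetical helpers, not library calls.

```python
import numpy as np

# Illustrative taps (assumption): H_0 = I, H_1 = 0.5 I, H_2 = 0.25 I.
taps = np.stack([np.eye(2), 0.5 * np.eye(2), 0.25 * np.eye(2)])


def forward(taps, x):
    # y[n] = sum_k H_k x[n-k]: each output uses current and past inputs only.
    y = np.zeros((x.shape[0] + taps.shape[0] - 1, x.shape[1]))
    for k, h in enumerate(taps):
        y[k : k + x.shape[0]] += x @ h.T
    return y


def finite_adjoint(taps, y, output_length):
    # x_adj[n] = sum_k H_k^H y[n+k]: each result uses current and future samples.
    x_adj = np.zeros((output_length, y.shape[1]))
    for k, h in enumerate(taps):
        stop = min(output_length, y.shape[0] - k)
        if stop > 0:
            x_adj[:stop] += y[k : k + stop] @ np.conj(h)
    return x_adj


rng = np.random.default_rng(2)
x = rng.normal(size=(32, 2))
cut = 20

# Forward map: changing x[cut:] must leave y[:cut] untouched.
x_future = x.copy()
x_future[cut:] += 1.0
forward_is_causal = np.allclose(forward(taps, x)[:cut], forward(taps, x_future)[:cut])

# Adjoint: changing y[cut:] does change x_adj[:cut], so it cannot run online.
y = forward(taps, x)
y_future = y.copy()
y_future[cut:] += 1.0
adjoint_before = finite_adjoint(taps, y, output_length=cut)
adjoint_after = finite_adjoint(taps, y_future, output_length=cut)
adjoint_reads_future = not np.allclose(adjoint_before, adjoint_after)

print("forward is causal:", forward_is_causal)
print("adjoint reads future samples:", adjoint_reads_future)
```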

What this example verifies

This verifies a DSP analogue of a norm-preserving convolution block. The forward map is causal and streaming; norm preservation is checked on the full stream with tail padding, while the adjoint reconstruction diagnostic is finite-record and noncausal.
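To see why the norm check needs tail padding, a scalar stand-in is enough. The sketch below assumes a first-order real all-pass \(H(z) = (z^{-1} - a)/(1 - a z^{-1})\), which is not the matrix lattice from the demo; the function name `first_order_allpass` and the value `a = 0.5` are choices made for this illustration.

```python
import numpy as np


def first_order_allpass(x, a, tail):
    """Run y[n] = a*y[n-1] - a*x[n] + x[n-1] over x followed by `tail` zeros."""
    padded = np.concatenate([x, np.zeros(tail)])
    y = np.zeros_like(padded)
    prev_x = prev_y = 0.0
    for n, sample in enumerate(padded):
        y[n] = a * prev_y - a * sample + prev_x
        prev_x, prev_y = sample, y[n]
    return y


rng = np.random.default_rng(3)
x = rng.normal(size=256)

errors = {}
for tail in (0, 4, 16, 64):
    y = first_order_allpass(x, a=0.5, tail=tail)
    # Energy still stored in the filter state is missing until the tail is drained.
    errors[tail] = abs(np.linalg.norm(y) - np.linalg.norm(x)) / np.linalg.norm(x)
    print(f"tail={tail:3d}  relative norm error={errors[tail]:.3e}")
```

Without padding, part of the input energy is still held in the filter state when the record ends. As the tail grows, the relative error should fall towards rounding level.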

How to read the result

Check the input/output norm figure, singular-value plot, streaming trace, and finite-adjoint error plot; a streaming unitary convolution block should preserve each batch-item norm after its tail is included.
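The singular-value plot shows the singular values of \(H(e^{j\omega})\) on a frequency grid. A minimal sketch of that computation follows, assuming a toy two-tap all-pass and a deliberately damaged copy for contrast; the helper `singular_value_range` and both filters are illustrative and are not taken from `lattice_dsp`.

```python
import numpy as np

rng = np.random.default_rng(4)


def random_unitary(n):
    # The Q factor of a complex Gaussian matrix is unitary.
    q, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))
    return q


# Toy all-pass (assumption): H(z) = U1 @ diag(1, z^-1, z^-1) @ U0.
u0, u1 = random_unitary(3), random_unitary(3)
delayed = np.array([0.0, 1.0, 1.0])
allpass_taps = np.stack([u1 @ np.diag(1.0 - delayed) @ u0, u1 @ np.diag(delayed) @ u0])

# Halving the second tap breaks the all-pass property on purpose.
generic_taps = allpass_taps.copy()
generic_taps[1] *= 0.5


def singular_value_range(taps, omega):
    # H(e^{j omega}) = sum_k H_k e^{-j omega k}, stacked over the grid: (F, C, C).
    k = np.arange(taps.shape[0])
    response = np.einsum("fk,kij->fij", np.exp(-1j * np.outer(omega, k)), taps)
    s = np.linalg.svd(response, compute_uv=False)  # shape (F, C)
    return float(s.min()), float(s.max())


omega = np.linspace(0.0, np.pi, 64)
allpass_range = singular_value_range(allpass_taps, omega)
generic_range = singular_value_range(generic_taps, omega)

print("all-pass singular value range:", allpass_range)
print("damaged filter singular value range:", generic_range)
```

A flat range at 1 across the grid is what the demo's `singular value range` line reports. A range that dips below 1, as for the damaged copy, means some input directions lose energy.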

Run command

python examples/ml_unitary_convolution_demo.py

Run status

Return code: 0

Captured stdout

batch size: 8
sequence length: 1024
channels: 6
order: 4
tail samples: 1024
real scalar parameters: 360
max streaming norm-preservation error: 5.157e-16
max finite-adjoint reconstruction error: 4.311e-14
singular value range: [1.000000, 1.000000]
causal forward: output at n uses current input and previous lattice state
finite adjoint: reconstruction is time-domain but noncausal over the block
takeaway: matrix lattice filters can parameterize streaming norm-preserving convolution blocks
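The `real scalar parameters: 360` line is consistent with counting the raw real and imaginary entries of every complex matrix passed to the filter. This accounting is an assumption; the source does not show how `parameter_count()` is implemented.

```python
channels = 6
order = 4

# Each complex channels-by-channels matrix has a real and an imaginary part per entry.
reals_per_complex_matrix = 2 * channels * channels

# `order` reflection matrices plus one residue matrix, as built in _make_filter.
n_matrices = order + 1

raw_real_parameters = n_matrices * reals_per_complex_matrix
print(raw_real_parameters)  # 360
```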

Figures

ml unitary convolution adjoint error

ml_unitary_convolution_adjoint_error.png

ml unitary convolution batch norms

ml_unitary_convolution_batch_norms.png

ml unitary convolution channel energy

ml_unitary_convolution_channel_energy.png

ml unitary convolution singular values

ml_unitary_convolution_singular_values.png

ml unitary convolution streaming trace

ml_unitary_convolution_streaming_trace.png

Source code

"""ML-adjacent streaming unitary convolution demo.

Orthogonal/unitary convolutions are useful in ML because they preserve signal
norms and keep forward/adjoint maps well conditioned.  This example uses the
causal online matrix-lattice all-pass runtime as a streaming multichannel
unitary convolution block.  A finite-record time-domain adjoint is used only for
an offline reconstruction diagnostic.
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np

from lattice_dsp import (
    MatrixLatticeAllPass,
    contractive_matrix_from_raw,
    matrix_lattice_finite_adjoint,
    unitary_polar_factor,
)


def _artifact_dir() -> Path:
    # Figures go to LATTICE_DSP_ARTIFACT_DIR when it is set, else to a default folder.
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def _make_filter(rng: np.random.Generator, channels: int, order: int) -> MatrixLatticeAllPass:
    # One scaled complex Gaussian draw per lattice stage, passed through
    # contractive_matrix_from_raw to obtain the reflection matrices.
    reflections = [
        contractive_matrix_from_raw(
            0.25
            * (rng.normal(size=(channels, channels)) + 1j * rng.normal(size=(channels, channels)))
        )
        for _ in range(order)
    ]
    residue = unitary_polar_factor(
        rng.normal(size=(channels, channels)) + 1j * rng.normal(size=(channels, channels))
    )
    return MatrixLatticeAllPass(reflections, residue=residue)


def _forward_streaming(batch: np.ndarray, filt: MatrixLatticeAllPass, *, tail: int) -> np.ndarray:
    # Output holds sequence_length + tail samples per batch item.
    out = np.empty((batch.shape[0], batch.shape[1] + tail, batch.shape[2]), dtype=np.complex128)
    for item in range(batch.shape[0]):
        # A new online filter per item, so no state carries over between batch items.
        out[item] = filt.to_online_filter().process(batch[item], drain=tail)
    return out


def _finite_adjoint(
    batch: np.ndarray, filt: MatrixLatticeAllPass, *, tail: int, output_length: int
) -> np.ndarray:
    h = filt.impulse_response(tail)
    out = np.empty((batch.shape[0], output_length, batch.shape[2]), dtype=np.complex128)
    for item in range(batch.shape[0]):
        out[item] = matrix_lattice_finite_adjoint(batch[item], h, output_length=output_length)
    return out


def _save_figures(
    *,
    input_norms: np.ndarray,
    output_norms: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
    x_hat: np.ndarray,
    omega_probe: np.ndarray,
    singular_values: np.ndarray,
) -> None:
    try:
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - optional plotting dependency
        print("matplotlib is not installed; skipped figures")
        return

    out_dir = _artifact_dir()

    # Figure 1: per-item input norm against tail-padded output norm.
    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    batch = np.arange(len(input_norms))
    ax.plot(batch, input_norms, marker="o", label="input")
    ax.plot(batch, output_norms, marker="x", linestyle="--", label="streaming output with tail")
    ax.set_xlabel("batch item")
    ax.set_ylabel("flattened signal norm")
    ax.set_title("Streaming unitary convolution preserves each batch-item norm")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_batch_norms.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    # Figure 2: singular values of the frequency response on the probe grid.
    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    for idx in range(singular_values.shape[1]):
        ax.plot(omega_probe, singular_values[:, idx], label=f"{idx + 1}")
    ax.set_xlabel("rad/sample")
    ax.set_ylabel("singular value")
    ax.set_title("Frequency response stays unitary")
    ax.legend(loc="best", ncol=2)
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_singular_values.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    # Figure 3: mean energy per channel, input against the untailed output prefix.
    channel_energy_in = np.mean(np.abs(x) ** 2, axis=(0, 1))
    channel_energy_out = np.mean(np.abs(y[:, : x.shape[1]]) ** 2, axis=(0, 1))
    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    idx = np.arange(len(channel_energy_in))
    width = 0.36
    ax.bar(idx - width / 2, channel_energy_in, width=width, label="input")
    ax.bar(idx + width / 2, channel_energy_out, width=width, label="streaming output prefix")
    ax.set_xlabel("channel")
    ax.set_ylabel("mean energy")
    ax.set_title("Energy may move across channels while total norm is preserved")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_channel_energy.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    # Figure 4: relative finite-adjoint reconstruction error per batch item.
    reconstruction = np.linalg.norm((x_hat - x).reshape(x.shape[0], -1), axis=1) / input_norms
    fig, ax = plt.subplots(figsize=(7.0, 3.8))
    ax.semilogy(np.maximum(reconstruction, 1e-18), marker="o")
    ax.set_xlabel("batch item")
    ax.set_ylabel("relative finite-adjoint reconstruction error")
    ax.set_title("Time-domain adjoint recovers the input")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_adjoint_error.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    # Figure 5: real part of channel 0 for the first batch item, input against output.
    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    span = min(256, x.shape[1])
    ax.plot(np.real(x[0, :span, 0]), label="input ch0 real")
    ax.plot(np.real(y[0, :span, 0]), label="streaming output ch0 real", alpha=0.8)
    ax.set_xlabel("sample")
    ax.set_ylabel("amplitude")
    ax.set_title("Causal online convolution trace")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_streaming_trace.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")


rng = np.random.default_rng(314)
batch_size = 8
sequence_length = 1024
channels = 6
order = 4
tail = 1024

filt = _make_filter(rng, channels, order)
x = rng.normal(size=(batch_size, sequence_length, channels)) + 1j * rng.normal(
    size=(batch_size, sequence_length, channels)
)
y = _forward_streaming(x, filt, tail=tail)
x_hat = _finite_adjoint(y, filt, tail=tail, output_length=sequence_length)

# Per-item norms over all samples and channels; the output norm includes the tail.
input_norms = np.linalg.norm(x.reshape(batch_size, -1), axis=1)
output_norms = np.linalg.norm(y.reshape(batch_size, -1), axis=1)
max_norm_error = float(np.max(np.abs(output_norms - input_norms) / input_norms))
max_adjoint_error = float(
    np.max(np.linalg.norm((x_hat - x).reshape(batch_size, -1), axis=1) / input_norms)
)

# Singular values of the frequency response on a grid over [0, pi].
omega_probe = np.linspace(0.0, np.pi, 64)
response = filt.frequency_response(omega_probe)
singular_values = np.linalg.svd(response, compute_uv=False)

print("batch size:", batch_size)
print("sequence length:", sequence_length)
print("channels:", channels)
print("order:", order)
print("tail samples:", tail)
print("real scalar parameters:", filt.parameter_count())
print("max streaming norm-preservation error:", f"{max_norm_error:.3e}")
print("max finite-adjoint reconstruction error:", f"{max_adjoint_error:.3e}")
print("singular value range:", f"[{singular_values.min():.6f}, {singular_values.max():.6f}]")
print("causal forward: output at n uses current input and previous lattice state")
print("finite adjoint: reconstruction is time-domain but noncausal over the block")
print(
    "takeaway: matrix lattice filters can parameterize streaming norm-preserving convolution blocks"
)

_save_figures(
    input_norms=input_norms,
    output_norms=output_norms,
    x=x,
    y=y,
    x_hat=x_hat,
    omega_probe=omega_probe,
    singular_values=singular_values,
)