Unitary convolution block for ML-style stability
Tutorial goal
Show a streaming norm-preserving convolution-like block motivated by stable ML layers.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context
Orthogonal and unitary transforms preserve signal norms and keep forward and adjoint maps well conditioned, which can improve numerical stability in learned models. This demo connects matrix-lattice ideas to norm-preserving convolution blocks as a DSP demonstration, not a full ML framework. Unlike a circular FFT layer, the forward map here is run by the causal online matrix-lattice runtime.
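As a small illustration of the norm-preservation claim for a single (memoryless) unitary matrix, the sketch below builds a unitary matrix from a random complex matrix through its SVD and checks that it leaves a vector norm unchanged. The helper name `polar_unitary` is hypothetical and is not the library's `unitary_polar_factor`; it only assumes that a polar factor can be formed as the product of the SVD's two unitary factors.

```python
import numpy as np


def polar_unitary(a: np.ndarray) -> np.ndarray:
    """Return the unitary polar factor of a square matrix via its SVD (illustrative helper)."""
    u, _, vh = np.linalg.svd(a)
    return u @ vh


rng = np.random.default_rng(0)
raw = rng.normal(size=(6, 6)) + 1j * rng.normal(size=(6, 6))
q = polar_unitary(raw)

x = rng.normal(size=6) + 1j * rng.normal(size=6)

# Q^H Q should be the identity, and ||Q x|| should equal ||x||.
unitarity_error = np.linalg.norm(q.conj().T @ q - np.eye(6))
norm_error = abs(np.linalg.norm(q @ x) - np.linalg.norm(x))
print(f"unitarity error: {unitarity_error:.1e}, norm error: {norm_error:.1e}")
```

The convolution block in this demo extends the same property from one matrix to a filter with memory.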
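Key idea and equations
The streaming block applies a causal multichannel convolution
\[ y[n] = \sum_{k \ge 0} H_k \, x[n-k], \]
where \(x[n]\) and \(y[n]\) are the input and output channel vectors and \(H_k\) are the matrix impulse-response taps. The all-pass condition
\[ H(e^{j\omega})^{\mathsf{H}} H(e^{j\omega}) = I \quad \text{for all } \omega \]
keeps the induced \(\ell_2\) norm controlled on the full stream:
\[ \sum_n \lVert y[n] \rVert_2^2 = \sum_n \lVert x[n] \rVert_2^2 \]
after appending enough zero-input samples to include the tail. The finite-record adjoint diagnostic uses
\[ \hat{x}[n] = \sum_{k \ge 0} H_k^{\mathsf{H}} \, y[n+k], \]
which is useful for reconstruction checks but is noncausal as an online inverse, because each \(\hat{x}[n]\) depends on later output samples \(y[n+k]\).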
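The causal convolution and the full-stream norm statement can be sketched without the lattice runtime. The example below is only an assumption-laden stand-in: instead of the matrix-lattice all-pass filter, it uses a cascade of FIR stages of the form \((I - vv^{\mathsf{H}}) + z^{-1} vv^{\mathsf{H}}\) with unit-norm \(v\), a standard construction whose frequency response is unitary. The helpers `paraunitary_taps` and `causal_convolve` are hypothetical names introduced here.

```python
import numpy as np


def paraunitary_taps(vectors: np.ndarray) -> np.ndarray:
    """Taps H_0..H_K of a cascade of stages (I - v v^H) + z^{-1} v v^H."""
    channels = vectors.shape[1]
    taps = np.eye(channels, dtype=complex)[None]  # start from H(z) = I
    for v in vectors:
        v = v / np.linalg.norm(v)
        p = np.outer(v, v.conj())  # rank-one orthogonal projector
        stage = np.stack([np.eye(channels) - p, p])
        # Matrix-polynomial product stage(z) * H(z).
        new = np.zeros((taps.shape[0] + 1, channels, channels), dtype=complex)
        for i, s in enumerate(stage):
            new[i : i + taps.shape[0]] += s @ taps
        taps = new
    return taps


def causal_convolve(x: np.ndarray, taps: np.ndarray) -> np.ndarray:
    """y[n] = sum_k H_k x[n-k], padded with len(taps) - 1 zero-input tail samples."""
    n = x.shape[0]
    y = np.zeros((n + taps.shape[0] - 1, x.shape[1]), dtype=complex)
    for lag, h in enumerate(taps):
        y[lag : lag + n] += x @ h.T  # rows of x are the channel vectors x[n]
    return y


rng = np.random.default_rng(1)
taps = paraunitary_taps(rng.normal(size=(4, 3)) + 1j * rng.normal(size=(4, 3)))
x = rng.normal(size=(64, 3)) + 1j * rng.normal(size=(64, 3))
y = causal_convolve(x, taps)

# Norm is preserved on the full stream, but not on the prefix without the tail.
full_error = abs(np.linalg.norm(y) - np.linalg.norm(x)) / np.linalg.norm(x)
prefix_deficit = 1.0 - np.linalg.norm(y[: len(x)]) / np.linalg.norm(x)
print(f"full-stream error: {full_error:.1e}, prefix deficit: {prefix_deficit:.3f}")
```

Because this stand-in filter is FIR, its tail is exactly `len(taps) - 1` samples. The lattice all-pass in the demo is recursive, which is presumably why the demo drains a much longer tail.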
Causality and data use
The forward map is causal and streaming. The adjoint reconstruction check is time-domain but finite-block/noncausal, which matches how adjoints are used in offline ML-style diagnostics.
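To make the causal/noncausal distinction concrete, here is a minimal sketch with a two-tap toy all-pass filter (an assumption for illustration, not the demo's lattice filter and not the library's `matrix_lattice_finite_adjoint`). The forward loop only reads current and past inputs, while the adjoint loop reads output samples one step ahead, which is why it needs the whole finite record including the tail.

```python
import numpy as np

rng = np.random.default_rng(2)
channels, length = 3, 32

# Assumed toy filter: H(z) = (I - v v^H) + z^{-1} v v^H, which is all-pass.
v = rng.normal(size=channels) + 1j * rng.normal(size=channels)
v /= np.linalg.norm(v)
p = np.outer(v, v.conj())
taps = np.stack([np.eye(channels) - p, p])

x = rng.normal(size=(length, channels)) + 1j * rng.normal(size=(length, channels))

# Causal forward pass with one tail sample: y[n] = H_0 x[n] + H_1 x[n-1].
y = np.zeros((length + 1, channels), dtype=complex)
for lag, h in enumerate(taps):
    y[lag : lag + length] += x @ h.T

# Finite adjoint: x_hat[n] = sum_k H_k^H y[n+k], which looks ahead in y.
x_hat = np.zeros_like(x)
for lag, h in enumerate(taps):
    x_hat += y[lag : lag + length] @ h.conj()

error = np.linalg.norm(x_hat - x) / np.linalg.norm(x)
print(f"relative reconstruction error: {error:.1e}")
```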
What this example verifies
This verifies a DSP analogue of a norm-preserving convolution block. The forward map is causal and streaming; norm preservation is checked on the full stream with tail padding, while the adjoint reconstruction diagnostic is finite-record and noncausal.
How to read the result
Check the input/output norm figure, singular-value plot, streaming trace, and finite-adjoint error plot; a streaming unitary convolution block should preserve each batch-item norm after its tail is included.
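When reading the singular-value plot, the thing to look for is that every singular value of the frequency response sits at 1 across the frequency grid. The sketch below shows that check on a toy two-tap all-pass filter and on a deliberately perturbed copy, so the contrast is visible. The filter and the helper `singular_values_on_grid` are assumptions made for this illustration; the demo itself obtains the response from `filt.frequency_response`.

```python
import numpy as np


def singular_values_on_grid(taps: np.ndarray, omega: np.ndarray) -> np.ndarray:
    """Singular values of H(e^{j omega}) = sum_k H_k exp(-j omega k) per frequency."""
    lags = np.arange(taps.shape[0])
    phases = np.exp(-1j * omega[:, None] * lags[None, :])
    response = np.einsum("fk,kij->fij", phases, taps)  # (frequency, row, column)
    return np.linalg.svd(response, compute_uv=False)


rng = np.random.default_rng(3)
channels = 3

# Assumed toy all-pass taps H_0 = I - v v^H and H_1 = v v^H.
v = rng.normal(size=channels) + 1j * rng.normal(size=channels)
v /= np.linalg.norm(v)
p = np.outer(v, v.conj())
taps = np.stack([np.eye(channels) - p, p])

# A perturbed copy that is no longer all-pass, for comparison.
noise = rng.normal(size=taps.shape) + 1j * rng.normal(size=taps.shape)
perturbed = taps + 0.1 * noise

omega = np.linspace(0.0, np.pi, 64)
singular_values = singular_values_on_grid(taps, omega)
perturbed_values = singular_values_on_grid(perturbed, omega)

print(f"all-pass range:  [{singular_values.min():.6f}, {singular_values.max():.6f}]")
print(f"perturbed range: [{perturbed_values.min():.6f}, {perturbed_values.max():.6f}]")
```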
Run command
python examples/ml_unitary_convolution_demo.py
Run status
Return code: 0
Captured stdout
batch size: 8
sequence length: 1024
channels: 6
order: 4
tail samples: 1024
real scalar parameters: 360
max streaming norm-preservation error: 5.157e-16
max finite-adjoint reconstruction error: 4.311e-14
singular value range: [1.000000, 1.000000]
causal forward: output at n uses current input and previous lattice state
finite adjoint: reconstruction is time-domain but noncausal over the block
takeaway: matrix lattice filters can parameterize streaming norm-preserving convolution blocks
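One accounting that reproduces the reported `real scalar parameters: 360` is sketched below. It assumes that each complex 6-by-6 matrix (four reflection matrices plus one residue matrix) is counted by its raw real and imaginary entries. The source does not state how `parameter_count()` counts, so treat this as a plausibility check rather than a description of the library.

```python
channels, order = 6, 4

# Assumption: every complex channels-by-channels matrix is counted as
# 2 * channels**2 raw real scalars (real and imaginary parts).
reals_per_matrix = 2 * channels * channels
reflection_reals = order * reals_per_matrix  # four reflection matrices
residue_reals = reals_per_matrix  # one residue matrix
total = reflection_reals + residue_reals
print(total)  # 360
```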
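Figures
ml_unitary_convolution_adjoint_error.png: Time-domain adjoint recovers the input (relative finite-adjoint reconstruction error per batch item, log scale).
ml_unitary_convolution_batch_norms.png: Streaming unitary convolution preserves each batch-item norm (input vs. streaming output with tail).
ml_unitary_convolution_channel_energy.png: Energy may move across channels while total norm is preserved (mean energy per channel, input vs. streaming output prefix).
ml_unitary_convolution_singular_values.png: Frequency response stays unitary (singular values vs. rad/sample).
ml_unitary_convolution_streaming_trace.png: Causal online convolution trace (real part of channel 0, input vs. streaming output).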
Source code
"""ML-adjacent streaming unitary convolution demo.

Orthogonal/unitary convolutions are useful in ML because they preserve signal
norms and keep forward/adjoint maps well conditioned. This example uses the
causal online matrix-lattice all-pass runtime as a streaming multichannel
unitary convolution block. A finite-record time-domain adjoint is used only for
an offline reconstruction diagnostic.
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np

from lattice_dsp import (
    MatrixLatticeAllPass,
    contractive_matrix_from_raw,
    matrix_lattice_finite_adjoint,
    unitary_polar_factor,
)


def _artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def _make_filter(rng: np.random.Generator, channels: int, order: int) -> MatrixLatticeAllPass:
    reflections = [
        contractive_matrix_from_raw(
            0.25
            * (rng.normal(size=(channels, channels)) + 1j * rng.normal(size=(channels, channels)))
        )
        for _ in range(order)
    ]
    residue = unitary_polar_factor(
        rng.normal(size=(channels, channels)) + 1j * rng.normal(size=(channels, channels))
    )
    return MatrixLatticeAllPass(reflections, residue=residue)


def _forward_streaming(batch: np.ndarray, filt: MatrixLatticeAllPass, *, tail: int) -> np.ndarray:
    out = np.empty((batch.shape[0], batch.shape[1] + tail, batch.shape[2]), dtype=np.complex128)
    for item in range(batch.shape[0]):
        out[item] = filt.to_online_filter().process(batch[item], drain=tail)
    return out


def _finite_adjoint(
    batch: np.ndarray, filt: MatrixLatticeAllPass, *, tail: int, output_length: int
) -> np.ndarray:
    h = filt.impulse_response(tail)
    out = np.empty((batch.shape[0], output_length, batch.shape[2]), dtype=np.complex128)
    for item in range(batch.shape[0]):
        out[item] = matrix_lattice_finite_adjoint(batch[item], h, output_length=output_length)
    return out


def _save_figures(
    *,
    input_norms: np.ndarray,
    output_norms: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
    x_hat: np.ndarray,
    omega_probe: np.ndarray,
    singular_values: np.ndarray,
) -> None:
    try:
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - optional plotting dependency
        print("matplotlib is not installed; skipped figures")
        return

    out_dir = _artifact_dir()

    fig, ax = plt.subplots(figsize=(7.0, 4.0))
    batch = np.arange(len(input_norms))
    ax.plot(batch, input_norms, marker="o", label="input")
    ax.plot(batch, output_norms, marker="x", linestyle="--", label="streaming output with tail")
    ax.set_xlabel("batch item")
    ax.set_ylabel("flattened signal norm")
    ax.set_title("Streaming unitary convolution preserves each batch-item norm")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_batch_norms.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    for idx in range(singular_values.shape[1]):
        ax.plot(omega_probe, singular_values[:, idx], label=f"σ{idx + 1}")
    ax.set_xlabel("rad/sample")
    ax.set_ylabel("singular value")
    ax.set_title("Frequency response stays unitary")
    ax.legend(loc="best", ncol=2)
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_singular_values.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    channel_energy_in = np.mean(np.abs(x) ** 2, axis=(0, 1))
    channel_energy_out = np.mean(np.abs(y[:, : x.shape[1]]) ** 2, axis=(0, 1))
    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    idx = np.arange(len(channel_energy_in))
    width = 0.36
    ax.bar(idx - width / 2, channel_energy_in, width=width, label="input")
    ax.bar(idx + width / 2, channel_energy_out, width=width, label="streaming output prefix")
    ax.set_xlabel("channel")
    ax.set_ylabel("mean energy")
    ax.set_title("Energy may move across channels while total norm is preserved")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_channel_energy.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    reconstruction = np.linalg.norm((x_hat - x).reshape(x.shape[0], -1), axis=1) / input_norms
    fig, ax = plt.subplots(figsize=(7.0, 3.8))
    ax.semilogy(np.maximum(reconstruction, 1e-18), marker="o")
    ax.set_xlabel("batch item")
    ax.set_ylabel("relative finite-adjoint reconstruction error")
    ax.set_title("Time-domain adjoint recovers the input")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_adjoint_error.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    span = min(256, x.shape[1])
    ax.plot(np.real(x[0, :span, 0]), label="input ch0 real")
    ax.plot(np.real(y[0, :span, 0]), label="streaming output ch0 real", alpha=0.8)
    ax.set_xlabel("sample")
    ax.set_ylabel("amplitude")
    ax.set_title("Causal online convolution trace")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "ml_unitary_convolution_streaming_trace.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")


rng = np.random.default_rng(314)
batch_size = 8
sequence_length = 1024
channels = 6
order = 4
tail = 1024

filt = _make_filter(rng, channels, order)
x = rng.normal(size=(batch_size, sequence_length, channels)) + 1j * rng.normal(
    size=(batch_size, sequence_length, channels)
)
y = _forward_streaming(x, filt, tail=tail)
x_hat = _finite_adjoint(y, filt, tail=tail, output_length=sequence_length)

input_norms = np.linalg.norm(x.reshape(batch_size, -1), axis=1)
output_norms = np.linalg.norm(y.reshape(batch_size, -1), axis=1)
max_norm_error = float(np.max(np.abs(output_norms - input_norms) / input_norms))
max_adjoint_error = float(
    np.max(np.linalg.norm((x_hat - x).reshape(batch_size, -1), axis=1) / input_norms)
)

omega_probe = np.linspace(0.0, np.pi, 64)
response = filt.frequency_response(omega_probe)
singular_values = np.linalg.svd(response, compute_uv=False)

print("batch size:", batch_size)
print("sequence length:", sequence_length)
print("channels:", channels)
print("order:", order)
print("tail samples:", tail)
print("real scalar parameters:", filt.parameter_count())
print("max streaming norm-preservation error:", f"{max_norm_error:.3e}")
print("max finite-adjoint reconstruction error:", f"{max_adjoint_error:.3e}")
print("singular value range:", f"[{singular_values.min():.6f}, {singular_values.max():.6f}]")
print("causal forward: output at n uses current input and previous lattice state")
print("finite adjoint: reconstruction is time-domain but noncausal over the block")
print(
    "takeaway: matrix lattice filters can parameterize streaming norm-preserving convolution blocks"
)

_save_figures(
    input_norms=input_norms,
    output_norms=output_norms,
    x=x,
    y=y,
    x_hat=x_hat,
    omega_probe=omega_probe,
    singular_values=singular_values,
)
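The demo builds its reflection matrices with `contractive_matrix_from_raw`, but the listing does not show how that mapping works. As a rough sketch of one way an unconstrained matrix could be mapped to a strict contraction (spectral norm below 1), the hypothetical helper `contraction_from_raw` below squashes each singular value \(s\) to \(s / \sqrt{1 + s^2}\). This is an assumption chosen for illustration and may differ from what the library does.

```python
import numpy as np


def contraction_from_raw(raw: np.ndarray) -> np.ndarray:
    """Squash singular values s to s / sqrt(1 + s^2), which lies in [0, 1)."""
    u, s, vh = np.linalg.svd(raw)
    return (u * (s / np.sqrt(1.0 + s**2))) @ vh  # U diag(squashed s) V^H


rng = np.random.default_rng(4)
raw = 0.25 * (rng.normal(size=(6, 6)) + 1j * rng.normal(size=(6, 6)))
k = contraction_from_raw(raw)

# The largest singular value of the result stays strictly below one.
spectral_norm = np.linalg.norm(k, 2)
print(f"spectral norm: {spectral_norm:.4f}")
```

A smooth map of this kind keeps the constraint satisfied for any raw input, which is convenient when the raw entries are free parameters.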