Matrix-lattice all-pass runtime benchmark¶
Tutorial goal
Compare compiled matrix-lattice frequency-response evaluation with the NumPy reference evaluator.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context¶
Matrix lattice filters are most often used as compact frequency-dependent multichannel all-pass/scattering responses. This benchmark measures the response-evaluation runtime for different channel dimensions and lattice orders, comparing the compiled C++ evaluator with the small NumPy reference implementation.
Key idea and equations¶
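The benchmark reports the speedup of the compiled evaluator over the NumPy reference,

$$\text{speedup} = \frac{t_{\text{NumPy}}}{t_{\text{compiled}}},$$

where each $t$ is the median wall-clock time over the requested repeats, along with the relative difference between implementations, $\lVert H_{\text{compiled}} - H_{\text{NumPy}} \rVert / \lVert H_{\text{NumPy}} \rVert$, and the maximum unitarity error over the frequency grid, $\max_k \lVert H(\omega_k)^{\mathsf{H}} H(\omega_k) - I \rVert_F$.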
How to read the result¶
Look for relative differences near numerical precision, small unitarity error, and speedups above one for larger frequency grids/orders.
Run command¶
python benchmarks/matrix_lattice_runtime.py --dims 2 3 4 --orders 2 4 8 --n-freq 1024 --repeats 2 --n-threads 1 --output docs/benchmarks/generated/_artifacts/matrix_lattice_runtime/matrix-lattice-runtime.json
Run status¶
Return code: 0
Visual and data readout¶
When the benchmark gallery is built with results, this page embeds PNG summaries generated from the same JSON/CSV artifacts. The raw data stay available below as downloads so exact numbers remain reproducible without making the public page read like console output.
Figures¶
matrix_lattice_runtime_error_summary.png¶
matrix_lattice_runtime_quality_summary.png¶
matrix_lattice_runtime_speedup_summary.png¶
matrix_lattice_runtime_timing_comparison.png¶
Generated data files¶
Source code¶
1"""Benchmark matrix-lattice frequency-response evaluation runtime."""
2
3from __future__ import annotations
4
5import argparse
6import json
7import platform
8import statistics
9import time
10from pathlib import Path
11
12import numpy as np
13
14import lattice_dsp as ld
15from lattice_dsp import matrix_lattice as matrix_lattice_module
16
17
def make_filter(dim: int, order: int, seed: int) -> ld.MatrixLatticeAllPass:
    """Build a reproducible random matrix-lattice all-pass filter.

    A generator seeded with ``seed`` draws ``order`` complex Gaussian
    ``dim x dim`` matrices. Each one is scaled by 0.22 and passed through
    ``ld.contractive_matrix_from_raw`` (``margin=1e-6``) to give the
    reflection matrices. One further complex Gaussian draw is passed through
    ``ld.unitary_polar_factor`` and handed to the filter as ``residue``.
    """
    rng = np.random.default_rng(seed)
    shape = (dim, dim)

    def complex_gaussian() -> np.ndarray:
        # The real part is drawn before the imaginary part so the random
        # stream is always consumed in the same order.
        real_part = rng.normal(size=shape)
        imag_part = rng.normal(size=shape)
        return real_part + 1j * imag_part

    reflections = []
    for _ in range(order):
        raw = 0.22 * complex_gaussian()
        reflections.append(ld.contractive_matrix_from_raw(raw, margin=1e-6))

    # The residue is drawn only after every reflection matrix has been drawn.
    residue = ld.unitary_polar_factor(complex_gaussian())
    return ld.MatrixLatticeAllPass(reflections, residue=residue)
30
31
def median_runtime(fn, repeats: int) -> float:
    """Return the median wall-clock duration, in seconds, of calling ``fn``.

    ``fn`` is invoked with no arguments ``repeats`` times; each call is timed
    individually with ``time.perf_counter`` and the median of the samples is
    returned as a plain ``float``.
    """
    clock = time.perf_counter
    durations = []
    for _ in range(repeats):
        began = clock()
        fn()
        ended = clock()
        durations.append(ended - began)
    return float(statistics.median(durations))
39
40
def response_unitarity_error(response: np.ndarray) -> float:
    """Return the worst Frobenius-norm deviation of ``G^H G`` from identity.

    ``response`` is iterated along its first axis; for every matrix ``G``
    found there the Frobenius norm of ``G^H G - I`` is computed, and the
    largest of those norms is returned. The identity is sized from
    ``response.shape[1]``.
    """
    identity = np.eye(response.shape[1], dtype=np.complex128)
    deviations = [
        np.linalg.norm(block.conj().T @ block - identity, ord="fro")
        for block in response
    ]
    return float(max(deviations))
44
45
def run_case(
    dim: int, order: int, n_freq: int, repeats: int, n_threads: int, seed: int
) -> dict[str, float | int | bool]:
    """Benchmark one (dim, order) configuration and return its result row.

    A filter is built with ``make_filter`` and evaluated on ``n_freq``
    frequencies spanning ``[0, pi]``. Both evaluators are run once up front;
    those warm-up responses are also the ones used for the accuracy metrics.
    Each evaluator is then timed with ``median_runtime`` over ``repeats``
    calls.

    The returned dictionary holds the case parameters, both median runtimes,
    their ratio (``python_s / compiled_s``, or ``inf`` when the compiled time
    is not positive), the relative difference between the two responses, the
    unitarity error of the compiled response, and two filter statistics.
    """
    filt = make_filter(dim, order, seed)
    omega = np.linspace(0.0, np.pi, n_freq)

    def evaluate_compiled():
        return filt.frequency_response(omega, n_threads=n_threads)

    def evaluate_reference():
        # The NumPy reference evaluator is a private helper of the module.
        return matrix_lattice_module._frequency_response_numpy(  # noqa: SLF001
            filt.stage_blocks, filt.residue, omega
        )

    # Warm both paths before timing.
    compiled_response = evaluate_compiled()
    python_response = evaluate_reference()

    compiled_s = median_runtime(evaluate_compiled, repeats)
    python_s = median_runtime(evaluate_reference, repeats)

    # The denominator is floored so an all-zero reference cannot divide by zero.
    mismatch_norm = np.linalg.norm(compiled_response - python_response)
    reference_norm = np.linalg.norm(python_response)
    rel_diff = float(mismatch_norm / max(reference_norm, 1e-30))

    if compiled_s > 0.0:
        speedup = python_s / compiled_s
    else:
        speedup = float("inf")

    return {
        "dim": dim,
        "order": order,
        "n_freq": n_freq,
        "n_threads": n_threads,
        "compiled_s": compiled_s,
        "python_s": python_s,
        "speedup": speedup,
        "relative_difference": rel_diff,
        "unitarity_error": response_unitarity_error(compiled_response),
        "max_reflection_singular_value": filt.max_reflection_singular_value(),
        "real_scalar_parameter_count": filt.parameter_count(),
    }
86
87
def parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    ``--dims`` and ``--orders`` each take one or more integers;
    ``--n-freq``, ``--repeats``, ``--n-threads`` and ``--seed`` take a single
    integer; ``--output`` is converted to a ``Path``. The module docstring is
    used as the parser description.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # Sweep axes: one or more integers per option.
    for flag, values in (("--dims", [2, 3, 4]), ("--orders", [2, 4, 8])):
        parser.add_argument(flag, type=int, nargs="+", default=values)

    # Single-integer options.
    scalar_defaults = (
        ("--n-freq", 1024),
        ("--repeats", 3),
        ("--n-threads", 1),
        ("--seed", 707),
    )
    for flag, value in scalar_defaults:
        parser.add_argument(flag, type=int, default=value)

    default_output = Path("reports/matrix-lattice-runtime.json")
    parser.add_argument("--output", type=Path, default=default_output)
    return parser.parse_args()
98
99
def _json_safe(value):
    """Return ``None`` in place of a non-finite float; pass anything else through.

    ``json.dumps`` renders ``inf`` and ``nan`` as the bare tokens ``Infinity``
    and ``NaN``, which are not valid JSON and are rejected by strict parsers.
    """
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


def main() -> None:
    """Run every (dim, order) case, write the JSON report and print a summary.

    Raises ``SystemExit`` with a message when ``--n-freq`` or ``--repeats`` is
    not positive. The report is written to ``--output`` (parent directories
    are created as needed) and a fixed-width table is printed to stdout.
    """
    args = parse_args()
    if args.n_freq <= 0:
        raise SystemExit("--n-freq must be positive")
    if args.repeats <= 0:
        raise SystemExit("--repeats must be positive")

    # Each case gets its own seed derived from the base seed, dim and order.
    rows = [
        run_case(
            dim,
            order,
            args.n_freq,
            args.repeats,
            args.n_threads,
            args.seed + 100 * dim + order,
        )
        for dim in args.dims
        for order in args.orders
    ]

    payload = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "has_openmp": bool(getattr(ld, "HAS_OPENMP", False)),
        "n_freq": args.n_freq,
        "repeats": args.repeats,
        "n_threads": args.n_threads,
        "description": "MatrixLatticeAllPass frequency-response benchmark. The compiled C++ path is compared with the NumPy reference evaluator.",
        "results": rows,
    }

    # A row may carry a non-finite float (for example an infinite speedup when
    # the compiled time is measured as zero). Those are written as null so the
    # report stays strict JSON; allow_nan=False turns any that slip through
    # into an error instead of silently emitting an invalid file.
    report = {
        **payload,
        "results": [{key: _json_safe(val) for key, val in row.items()} for row in rows],
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(
        json.dumps(report, indent=2, allow_nan=False), encoding="utf-8"
    )

    print(json.dumps({k: v for k, v in payload.items() if k != "results"}, indent=2))
    print()
    header = (
        f"{'dim':>4} {'order':>5} {'params':>8} {'compiled_s':>11} {'python_s':>10} "
        f"{'speedup':>8} {'unitarity':>11} {'rel_diff':>10}"
    )
    print(header)
    # The rule is sized from the header so it always matches the table width.
    print("-" * len(header))
    for row in rows:
        print(
            f"{row['dim']:4d} {row['order']:5d} {row['real_scalar_parameter_count']:8d} "
            f"{row['compiled_s']:11.5f} {row['python_s']:10.5f} {row['speedup']:8.2f} "
            f"{row['unitarity_error']:11.2e} {row['relative_difference']:10.2e}"
        )
    print(f"\nWrote {args.output}")
147
148
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()