Finite-section AAK/Nehari IIR reduction benchmark
Tutorial goal
Compare finite-Hankel and finite-section AAK/Nehari candidate reductions on the same stable SISO IIR filters.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context
The finite-Hankel reducer and the finite-section AAK/Nehari candidate workflow are both useful baselines. This benchmark runs them side by side on compressible stable SISO IIR filters and measures the practical tradeoff: reduction cost, filtering speedup, end-to-end speedup including reduction, SNR, magnitude-response error, pole radius, and break-even samples per channel.
The benchmark is deliberately finite-section. It is not a claim of exact infinite-dimensional AAK/Nehari optimality; it is a reproducible comparison of the mature baselines currently implemented in the package.
Key idea and equations
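The end-to-end speedup includes the one-time reduction cost,
$$S_{\mathrm{e2e}} = \frac{T_{\mathrm{full}}}{T_{\mathrm{reduce}} + T_{\mathrm{reduced}}},$$
where $T_{\mathrm{full}}$ and $T_{\mathrm{reduced}}$ are the median batch filtering times of the full and reduced models and $T_{\mathrm{reduce}}$ is the time spent computing the reduction.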
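The break-even sample count estimates when preprocessing has paid for itself,
$$N_{\mathrm{be}} = \frac{T_{\mathrm{reduce}}}{C\,\bigl(t_{\mathrm{full}} - t_{\mathrm{reduced}}\bigr)}, \qquad t_{\mathrm{full}} = \frac{T_{\mathrm{full}}}{C\,S}, \quad t_{\mathrm{reduced}} = \frac{T_{\mathrm{reduced}}}{C\,S},$$
where $T_{\mathrm{reduce}}$ is the one-time reduction cost, $T_{\mathrm{full}}$ and $T_{\mathrm{reduced}}$ are the median batch filtering times, $C$ is the number of channels, and $S$ is the number of samples per channel. It is reported as n/a when the reduced filter is not faster than the full one ($t_{\mathrm{full}} \le t_{\mathrm{reduced}}$).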
How to read the result
Look for stable reduced models with useful SNR/magnitude error and end-to-end speedup above one for the intended signal length.
Run command
python benchmarks/finite_aak_iir_reduction_speedup.py --full-orders 8 16 --target-orders 3 4 6 8 --channels 16 --samples 12000 --repeats 2 --n-impulse 384 --hankel-rows 48 --hankel-cols 48 --output docs/benchmarks/generated/_artifacts/finite_aak_iir_reduction_speedup/finite-aak-iir-reduction-speedup.json
Visual and data readout
When the benchmark gallery is built with results, this page embeds PNG summaries generated from the same JSON/CSV artifacts. The raw data stay available below as downloads so exact numbers remain reproducible without making the public page read like console output.
Source code
1from __future__ import annotations
2
3import argparse
4import json
5import math
6import platform
7import statistics
8import time
9from pathlib import Path
10from typing import Any
11from collections.abc import Callable
12
13import numpy as np
14
15import lattice_dsp as ld
16
17
def median_time(fn: Callable[[], Any], repeats: int) -> tuple[float, Any]:
    """Call ``fn`` ``repeats`` times and return the median wall-clock duration.

    Args:
        fn: Zero-argument callable to time.
        repeats: Number of timed calls; must be at least one.

    Returns:
        A ``(median_seconds, result)`` pair, where ``result`` is the value
        returned by the last call to ``fn``.

    Raises:
        ValueError: If ``repeats`` is smaller than one.
    """
    if repeats < 1:
        # Without this guard statistics.median() fails on the empty list with
        # an opaque "no median for empty data" message after timing nothing.
        raise ValueError(f"repeats must be at least 1, got {repeats}")

    times: list[float] = []
    result: Any = None
    for _ in range(repeats):
        t0 = time.perf_counter()
        result = fn()
        times.append(time.perf_counter() - t0)
    return statistics.median(times), result
26
27
def impulse_from_poles(poles: np.ndarray, weights: np.ndarray, n_terms: int) -> np.ndarray:
    """Return the first ``n_terms`` samples of ``sum_k weights[k] * poles[k] ** n``.

    ``poles`` and ``weights`` are indexed along their first axis; the result
    has one entry per time index ``n = 0 .. n_terms - 1``.
    """
    exponents = np.arange(n_terms, dtype=np.float64)
    # One row per pole, one column per time index.
    modal_decay = np.power(poles[:, None], exponents[None, :])
    weighted_modes = weights[:, None] * modal_decay
    return weighted_modes.sum(axis=0)
31
32
def numerator_from_impulse_and_denominator(
    impulse: np.ndarray, denominator: np.ndarray
) -> np.ndarray:
    """Compute numerator taps as the truncated convolution of denominator and impulse.

    Tap ``i`` equals ``sum_{j <= i} denominator[j] * impulse[i - j]``; the
    result has as many taps as ``denominator`` has coefficients.
    """
    n_taps = denominator.size
    taps = np.zeros(n_taps, dtype=np.float64)
    for tap in range(n_taps):
        # Pair denominator[lag] with the impulse sample ``lag`` steps back.
        products = [
            float(coeff) * float(impulse[tap - lag])
            for lag, coeff in enumerate(denominator[: tap + 1])
        ]
        taps[tap] = sum(products)
    return taps
41
42
def compressible_iir(order: int, rng: np.random.Generator, n_impulse: int) -> dict[str, np.ndarray]:
    """Construct a real-pole IIR test model whose modal weights decay quickly.

    Up to four fixed "slow" poles carry the largest weights; any further poles
    alternate in sign with shrinking magnitude and carry small, exponentially
    decaying weights, so the model is meaningfully compressible. All poles are
    clipped to ``[-0.94, 0.94]`` after jitter.

    Returns a dict with keys ``poles``, ``weights``, ``denominator``,
    ``numerator``, ``reflection`` and ``impulse``.

    Raises:
        ValueError: If ``order`` is not greater than one.
    """

    if order <= 1:
        raise ValueError("order must be greater than one")

    dominant_poles = np.array([0.92, 0.78, -0.58, 0.42], dtype=np.float64)
    dominant_weights = np.array([1.0, 0.28, -0.17, 0.08], dtype=np.float64)
    n_dominant = min(4, order)
    n_tail = order - dominant_poles.size

    if n_tail <= 0:
        nominal_poles = dominant_poles[:order]
    else:
        magnitudes = np.linspace(0.30, 0.04, n_tail)
        alternating = np.where(np.arange(n_tail) % 2 == 0, 1.0, -1.0)
        nominal_poles = np.concatenate([dominant_poles, alternating * magnitudes])

    # Seeded jitter keeps benchmark cases from being exact repeats; the clip
    # bounds every pole magnitude by 0.94.
    perturbation = rng.uniform(-0.008, 0.008, size=nominal_poles.size)
    poles = np.clip(nominal_poles + perturbation, -0.94, 0.94)

    weights = np.zeros(order, dtype=np.float64)
    weights[:n_dominant] = dominant_weights[:n_dominant]
    if order > 4:
        tail_index = np.arange(order - 4)
        weights[4:] = (
            0.035
            * np.exp(-tail_index / 4.0)
            * np.where(tail_index % 2 == 0, 1.0, -1.0)
        )

    denominator = np.asarray(np.poly(poles), dtype=np.float64)
    impulse = impulse_from_poles(poles, weights, n_impulse)
    numerator = numerator_from_impulse_and_denominator(impulse, denominator)
    reflection = np.asarray(ld.denominator_to_reflection(denominator.tolist()), dtype=np.float64)

    model: dict[str, np.ndarray] = {}
    model["poles"] = poles
    model["weights"] = weights
    model["denominator"] = denominator
    model["numerator"] = numerator
    model["reflection"] = reflection
    model["impulse"] = impulse
    return model
89
90
def process(
    reflection: np.ndarray | list[float], numerator: np.ndarray | list[float], x: np.ndarray
) -> np.ndarray:
    """Filter ``x`` through ``ld.process_batch`` and return a float64 array.

    The coefficient sequences are converted to plain lists of Python floats
    before being handed to the library call.
    """
    reflection_coeffs = [float(k) for k in reflection]
    numerator_taps = [float(b) for b in numerator]
    filtered = ld.process_batch(reflection_coeffs, numerator_taps, x)
    return np.asarray(filtered, dtype=np.float64)
98
99
def snr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Return the signal-to-error power ratio of ``estimate`` against ``reference`` in dB.

    A tiny floor is added to both powers so that an all-zero reference or a
    perfect estimate still yields a finite value.
    """
    floor = 1e-30
    residual = reference - estimate
    signal_power = float(np.mean(np.square(reference)))
    error_power = float(np.mean(np.square(residual)))
    ratio = (signal_power + floor) / (error_power + floor)
    return 10.0 * math.log10(ratio)
105
106
def pole_radius_from_denominator(denominator: np.ndarray | list[float]) -> float:
    """Return the largest root magnitude of the denominator polynomial.

    Returns ``0.0`` when the polynomial has no roots (e.g. a constant).
    """
    poles = np.roots(np.asarray(denominator, dtype=np.float64))
    if not poles.size:
        return 0.0
    return float(np.abs(poles).max())
111
112
def frequency_response(
    denominator: np.ndarray, numerator: np.ndarray, n_freq: int = 512
) -> np.ndarray:
    """Evaluate ``numerator / denominator`` on ``n_freq`` frequencies in ``[0, pi]``.

    Coefficient ``k`` of each polynomial multiplies ``exp(-1j * w * k)``.
    """
    omega = np.linspace(0.0, math.pi, n_freq)
    z_inv = np.exp(-1j * omega)

    def evaluate(coefficients: np.ndarray) -> np.ndarray:
        # Accumulate the polynomial term by term on the frequency grid.
        total = np.zeros_like(z_inv, dtype=np.complex128)
        for power, coeff in enumerate(coefficients):
            total += coeff * z_inv**power
        return total

    numerator_values = evaluate(numerator)
    denominator_values = evaluate(denominator)
    return numerator_values / denominator_values
125
126
def max_magnitude_error_db(
    full_denominator: np.ndarray,
    full_numerator: np.ndarray,
    reduced_denominator: np.ndarray,
    reduced_numerator: np.ndarray,
) -> float:
    """Return the largest absolute dB gap between the two magnitude responses.

    Both responses are sampled with ``frequency_response`` at its default
    grid; magnitudes are floored at ``1e-14`` before taking the logarithm.
    """

    def magnitude_db(denominator: np.ndarray, numerator: np.ndarray) -> np.ndarray:
        response = frequency_response(denominator, numerator)
        return 20.0 * np.log10(np.maximum(np.abs(response), 1e-14))

    full_db = magnitude_db(full_denominator, full_numerator)
    reduced_db = magnitude_db(reduced_denominator, reduced_numerator)
    gap_db = np.abs(full_db - reduced_db)
    return float(np.max(gap_db))
138
139
def break_even_samples_per_channel(
    reduction_time_s: float, full_time_s: float, reduced_time_s: float, channels: int, samples: int
) -> float | None:
    """Estimate how many samples per channel repay the one-time reduction cost.

    Returns ``None`` when the reduced filter saves no time per processed
    sample compared with the full filter.
    """
    batch_size = channels * samples
    saving_per_sample = full_time_s / batch_size - reduced_time_s / batch_size
    if saving_per_sample <= 0.0:
        # The reduced filter is not faster, so the cost is never recovered.
        return None
    total_samples_needed = reduction_time_s / saving_per_sample
    return total_samples_needed / channels
149
150
def serializable_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``row`` with NumPy scalars and arrays turned into builtins.

    NumPy scalars become Python scalars via ``.item()``, arrays become nested
    lists via ``.tolist()``, and every other value is passed through as is.
    """

    def to_builtin(value: Any) -> Any:
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    return {key: to_builtin(value) for key, value in row.items()}
161
162
def evaluate_reduced_model(
    *,
    method: str,
    full_order: int,
    target_order: int,
    reduction_time_s: float,
    full_model: dict[str, np.ndarray],
    full_time_s: float,
    y_full: np.ndarray,
    reduced_reflection: np.ndarray,
    reduced_numerator: np.ndarray,
    reduced_denominator: np.ndarray,
    relative_impulse_error: float,
    accepted: bool,
    stable: bool,
    x: np.ndarray,
    repeats: int,
) -> dict[str, Any]:
    """Time a reduced model on ``x`` and assemble its benchmark result row.

    The reduced filter is run ``repeats`` times through ``process``; its
    output is compared with ``y_full`` (relative MSE and SNR) and its
    coefficients with ``full_model`` (magnitude-response error). Speedups are
    ``None`` when the corresponding denominator time is not positive.
    """

    def run_reduced() -> np.ndarray:
        return process(reduced_reflection, reduced_numerator, x)

    reduced_time_s, y_reduced = median_time(run_reduced, repeats)

    error_power = np.mean(np.square(y_full - y_reduced))
    reference_power = np.mean(np.square(y_full))
    rel_mse = float(error_power / (reference_power + 1e-30))

    # End-to-end cost charges the one-time reduction to the reduced path.
    if reduction_time_s + reduced_time_s > 0:
        end_to_end_speedup = float(full_time_s / (reduction_time_s + reduced_time_s))
    else:
        end_to_end_speedup = None

    if reduced_time_s > 0:
        filter_speedup = float(full_time_s / reduced_time_s)
    else:
        filter_speedup = None

    # x is indexed as (channels, samples) here.
    break_even = break_even_samples_per_channel(
        reduction_time_s, full_time_s, reduced_time_s, x.shape[0], x.shape[1]
    )
    if break_even is not None:
        break_even = float(break_even)

    magnitude_error_db = max_magnitude_error_db(
        full_model["denominator"],
        full_model["numerator"],
        reduced_denominator,
        reduced_numerator,
    )

    return {
        "method": method,
        "full_order": int(full_order),
        "target_order": int(target_order),
        "stable": bool(stable),
        "accepted": bool(accepted),
        "reduction_time_s": float(reduction_time_s),
        "full_filter_median_s": float(full_time_s),
        "reduced_filter_median_s": float(reduced_time_s),
        "filter_speedup": filter_speedup,
        "amortized_end_to_end_speedup": end_to_end_speedup,
        "break_even_samples_per_channel": break_even,
        "relative_impulse_error": float(relative_impulse_error),
        "rel_mse_on_random_batch": rel_mse,
        "snr_db_on_random_batch": snr_db(y_full, y_reduced),
        "max_magnitude_error_db": magnitude_error_db,
        "max_pole_radius": pole_radius_from_denominator(reduced_denominator),
    }
218
219
def _format_optional(value: float | None, width: int, precision: int) -> str:
    """Format ``value`` as a fixed-point column, or right-aligned ``n/a`` for ``None``."""
    if value is None:
        return f"{'n/a':>{width}s}"
    return f"{value:{width}.{precision}f}"


def main() -> None:
    """Run the finite-Hankel versus finite-section AAK/Nehari benchmark.

    Parses the command-line options, builds one compressible IIR model per
    requested full order, reduces it to every smaller requested target order
    with both methods, and records one result row per (method, full order,
    target order). Failures are recorded as rows carrying an ``error`` field
    instead of aborting the run. The rows and run metadata are written as
    JSON to ``--output`` and a summary table is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description="Compare finite-Hankel and finite-section AAK/Nehari SISO IIR reduction workflows."
    )
    parser.add_argument("--full-orders", type=int, nargs="+", default=[8, 16, 32])
    parser.add_argument("--target-orders", type=int, nargs="+", default=[3, 4, 6, 8, 12])
    parser.add_argument("--channels", type=int, default=32)
    parser.add_argument("--samples", type=int, default=30000)
    parser.add_argument("--repeats", type=int, default=3)
    parser.add_argument("--n-impulse", type=int, default=768)
    parser.add_argument("--hankel-rows", type=int, default=96)
    parser.add_argument("--hankel-cols", type=int, default=96)
    parser.add_argument("--seed", type=int, default=314)
    parser.add_argument("--output", default="reports/finite-aak-iir-reduction-speedup.json")
    args = parser.parse_args()

    # One seeded generator drives both the input batch and the pole jitter.
    rng = np.random.default_rng(args.seed)
    x = rng.normal(size=(args.channels, args.samples)).astype(np.float64)
    rows_out: list[dict[str, Any]] = []
    criteria = ld.FiniteNehariCandidateCriteria(
        max_tail_error=1.0,
        max_rational_error=1.0,
        max_pole_radius=0.999,
    )

    for full_order in args.full_orders:
        # compressible_iir rejects orders of one or less, so skip them.
        if full_order <= 1:
            continue
        full_model = compressible_iir(full_order, rng, args.n_impulse)
        full_time_s, y_full = median_time(
            lambda fm=full_model: process(fm["reflection"], fm["numerator"], x),
            args.repeats,
        )

        for target_order in args.target_orders:
            # Only genuine reductions that fit inside the Hankel section run.
            if target_order >= full_order:
                continue
            if target_order > min(args.hankel_rows, args.hankel_cols):
                continue

            # Finite-Hankel / Ho--Kalman baseline.
            try:
                reduce_time, hankel = median_time(
                    lambda ro=target_order, fm=full_model: ld.finite_hankel_reduce_iir(
                        fm["reflection"].tolist(),
                        fm["numerator"].tolist(),
                        reduced_order=ro,
                        n_impulse=args.n_impulse,
                        rows=args.hankel_rows,
                        cols=args.hankel_cols,
                    ),
                    1,
                )
                if bool(hankel["stable"]) and hankel.get("reflection"):
                    rows_out.append(
                        serializable_row(
                            evaluate_reduced_model(
                                method="finite_hankel",
                                full_order=full_order,
                                target_order=target_order,
                                reduction_time_s=reduce_time,
                                full_model=full_model,
                                full_time_s=full_time_s,
                                y_full=y_full,
                                reduced_reflection=np.asarray(
                                    hankel["reflection"], dtype=np.float64
                                ),
                                reduced_numerator=np.asarray(hankel["numerator"], dtype=np.float64),
                                reduced_denominator=np.asarray(
                                    hankel["denominator"], dtype=np.float64
                                ),
                                relative_impulse_error=float(hankel["relative_impulse_error"]),
                                accepted=True,
                                stable=True,
                                x=x,
                                repeats=args.repeats,
                            )
                        )
                    )
                else:
                    rows_out.append(
                        {
                            "method": "finite_hankel",
                            "full_order": full_order,
                            "target_order": target_order,
                            "stable": bool(hankel.get("stable", False)),
                            "accepted": False,
                            "reduction_time_s": float(reduce_time),
                            "error": "reduced model was not stable in scalar lattice coordinates",
                        }
                    )
            except Exception as exc:  # noqa: BLE001 - benchmark rows should report failures.
                rows_out.append(
                    {
                        "method": "finite_hankel",
                        "full_order": full_order,
                        "target_order": target_order,
                        "stable": False,
                        "accepted": False,
                        "error": str(exc),
                    }
                )

            # Finite-section AAK/Nehari candidate using the same target order.
            try:
                reduce_time, aak = median_time(
                    lambda ro=target_order, fm=full_model: ld.finite_aak_reduce_iir(
                        fm["reflection"],
                        fm["numerator"],
                        ranks=[ro],
                        n_impulse=args.n_impulse,
                        rows=args.hankel_rows,
                        cols=args.hankel_cols,
                        criteria=criteria,
                        attach_certificate=True,
                    ),
                    1,
                )
                if bool(aak["stable"]) and aak["reduced_reflection"].size:
                    rows_out.append(
                        serializable_row(
                            evaluate_reduced_model(
                                method="finite_aak_candidate",
                                full_order=full_order,
                                target_order=target_order,
                                reduction_time_s=reduce_time,
                                full_model=full_model,
                                full_time_s=full_time_s,
                                y_full=y_full,
                                reduced_reflection=np.asarray(
                                    aak["reduced_reflection"], dtype=np.float64
                                ),
                                reduced_numerator=np.asarray(
                                    aak["reduced_numerator"], dtype=np.float64
                                ),
                                reduced_denominator=np.asarray(
                                    aak["reduced_denominator"], dtype=np.float64
                                ),
                                relative_impulse_error=float(aak["relative_impulse_error"]),
                                accepted=bool(aak["accepted"]),
                                stable=bool(aak["stable"]),
                                x=x,
                                repeats=args.repeats,
                            )
                        )
                    )
                else:
                    rows_out.append(
                        {
                            "method": "finite_aak_candidate",
                            "full_order": full_order,
                            "target_order": target_order,
                            "stable": bool(aak.get("stable", False)),
                            "accepted": False,
                            "reduction_time_s": float(reduce_time),
                            "error": "selected candidate was not stable in scalar lattice coordinates",
                        }
                    )
            except Exception as exc:  # noqa: BLE001 - benchmark rows should report failures.
                rows_out.append(
                    {
                        "method": "finite_aak_candidate",
                        "full_order": full_order,
                        "target_order": target_order,
                        "stable": False,
                        "accepted": False,
                        "error": str(exc),
                    }
                )

    result = {
        "metadata": {
            "python": platform.python_version(),
            "platform": platform.platform(),
            "has_openmp": bool(ld.HAS_OPENMP),
            "channels": args.channels,
            "samples": args.samples,
            "repeats": args.repeats,
            "n_impulse": args.n_impulse,
            "hankel_rows": args.hankel_rows,
            "hankel_cols": args.hankel_cols,
            "seed": args.seed,
            "description": (
                "Finite-Hankel versus finite-section AAK/Nehari SISO IIR reduction benchmark. "
                "Both methods are finite-section baselines; neither is claimed to be a full infinite-dimensional solver."
            ),
        },
        "rows": rows_out,
    }

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2), encoding="utf-8")

    print(json.dumps(result["metadata"], indent=2))
    print()
    print(
        f"{'method':>21s} {'full':>5s} {'red':>5s} {'stable':>7s} {'reduce_s':>10s} "
        f"{'filter_x':>9s} {'end2end_x':>10s} {'SNR':>8s} {'mag_err':>9s} {'break_even/ch':>15s}"
    )
    print("-" * 115)
    for row in rows_out:
        if "error" in row:
            print(
                f"{row['method']:>21s} {row['full_order']:5d} {row['target_order']:5d} "
                f"{str(row.get('stable', False)):>7s} {float(row.get('reduction_time_s', 0.0)):10.4f} ERROR: {row['error']}"
            )
            continue
        be = row["break_even_samples_per_channel"]
        be_text = "n/a" if be is None else f"{be:.0f}"
        # Both speedups are None when the measured time was not positive;
        # formatting None with a float spec would raise TypeError.
        filter_text = _format_optional(row["filter_speedup"], 9, 2)
        end_to_end_text = _format_optional(row["amortized_end_to_end_speedup"], 10, 2)
        print(
            f"{row['method']:>21s} {row['full_order']:5d} {row['target_order']:5d} {str(row['stable']):>7s} "
            f"{row['reduction_time_s']:10.4f} {filter_text} "
            f"{end_to_end_text} {row['snr_db_on_random_batch']:8.2f} "
            f"{row['max_magnitude_error_db']:9.3f} {be_text:>15s}"
        )

    print()
    print(f"Wrote {output}")


if __name__ == "__main__":
    main()