Finite-section AAK/Nehari IIR reduction benchmark

Tutorial goal

Compare finite-Hankel and finite-section AAK/Nehari candidate reductions on the same stable SISO IIR filters.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

The finite-Hankel reducer and the finite-section AAK/Nehari candidate workflow are both useful baselines. This benchmark runs them side by side on compressible stable SISO IIR filters and measures the practical tradeoff: reduction cost, filtering speedup, end-to-end speedup including reduction, SNR, magnitude-response error, pole radius, and break-even samples per channel.

The benchmark is deliberately finite-section. It is not a claim of exact infinite-dimensional AAK/Nehari optimality; it is a reproducible comparison of the mature baselines currently implemented in the package.

Key idea and equations

The end-to-end speedup includes the one-time reduction cost,

\[S_{\mathrm{end\text{-}to\text{-}end}} = \frac{t_{\mathrm{full}}}{t_{\mathrm{reduce}} + t_{\mathrm{reduced}}}.\]

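The break-even sample count estimates how many samples per channel must be filtered before the one-time reduction cost has paid for itself,

\[N_{\mathrm{break\text{-}even}} = \frac{t_{\mathrm{reduce}}}{t_{\mathrm{full/sample}} - t_{\mathrm{reduced/sample}}},\]

where each per-sample time is the measured batch filtering time divided by the number of samples per channel. No break-even point is reported when the reduced filter is not faster than the full one.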

How to read the result

Look for stable reduced models that combine high SNR and low magnitude-response error with an end-to-end speedup above one for the intended signal length.

Run command

python benchmarks/finite_aak_iir_reduction_speedup.py --full-orders 8 16 --target-orders 3 4 6 8 --channels 16 --samples 12000 --repeats 2 --n-impulse 384 --hankel-rows 48 --hankel-cols 48 --output docs/benchmarks/generated/_artifacts/finite_aak_iir_reduction_speedup/finite-aak-iir-reduction-speedup.json

Run status

Return code: 0

Visual and data readout

When the benchmark gallery is built with results, this page embeds PNG summaries generated from the same JSON/CSV artifacts. The raw data stay available below as downloads so exact numbers remain reproducible without making the public page read like console output.

Figures

finite aak iir reduction speedup error summary

finite_aak_iir_reduction_speedup_error_summary.png

finite aak iir reduction speedup quality summary

finite_aak_iir_reduction_speedup_quality_summary.png

finite aak iir reduction speedup speedup summary

finite_aak_iir_reduction_speedup_speedup_summary.png

finite aak iir reduction speedup timing comparison

finite_aak_iir_reduction_speedup_timing_comparison.png

Generated data files

Source code

  1from __future__ import annotations
  2
  3import argparse
  4import json
  5import math
  6import platform
  7import statistics
  8import time
  9from pathlib import Path
 10from typing import Any
 11from collections.abc import Callable
 12
 13import numpy as np
 14
 15import lattice_dsp as ld
 16
 17
 18def median_time(fn: Callable[[], Any], repeats: int) -> tuple[float, Any]:
 19    times: list[float] = []
 20    result: Any = None
 21    for _ in range(repeats):
 22        t0 = time.perf_counter()
 23        result = fn()
 24        times.append(time.perf_counter() - t0)
 25    return statistics.median(times), result
 26
 27
 28def impulse_from_poles(poles: np.ndarray, weights: np.ndarray, n_terms: int) -> np.ndarray:
 29    n = np.arange(n_terms, dtype=np.float64)
 30    return np.sum(weights[:, None] * poles[:, None] ** n[None, :], axis=0)
 31
 32
 33def numerator_from_impulse_and_denominator(
 34    impulse: np.ndarray, denominator: np.ndarray
 35) -> np.ndarray:
 36    order = denominator.size - 1
 37    numerator = np.zeros(order + 1, dtype=np.float64)
 38    for i in range(order + 1):
 39        numerator[i] = sum(float(denominator[j]) * float(impulse[i - j]) for j in range(i + 1))
 40    return numerator
 41
 42
 43def compressible_iir(order: int, rng: np.random.Generator, n_impulse: int) -> dict[str, np.ndarray]:
 44    """Build a stable real-pole IIR with decaying modal weights.
 45
 46    The construction gives the reduction methods a meaningful compressible model:
 47    a few slow poles dominate, while many smaller modes are cheap to discard.
 48    """
 49
 50    if order <= 1:
 51        raise ValueError("order must be greater than one")
 52
 53    slow = np.array([0.92, 0.78, -0.58, 0.42], dtype=np.float64)
 54    remaining = order - slow.size
 55    if remaining > 0:
 56        grid = np.linspace(0.30, 0.04, remaining)
 57        signs = np.where(np.arange(remaining) % 2 == 0, 1.0, -1.0)
 58        small = signs * grid
 59        poles = np.concatenate([slow, small])
 60    else:
 61        poles = slow[:order]
 62
 63    # Small deterministic jitter avoids perfectly repeated benchmark cases while
 64    # preserving stability and reproducibility.
 65    jitter = rng.uniform(-0.008, 0.008, size=poles.size)
 66    poles = np.clip(poles + jitter, -0.94, 0.94)
 67
 68    weights = np.zeros(order, dtype=np.float64)
 69    weights[: min(4, order)] = np.array([1.0, 0.28, -0.17, 0.08], dtype=np.float64)[: min(4, order)]
 70    if order > 4:
 71        weights[4:] = (
 72            0.035
 73            * np.exp(-np.arange(order - 4) / 4.0)
 74            * np.where(np.arange(order - 4) % 2 == 0, 1.0, -1.0)
 75        )
 76
 77    denominator = np.asarray(np.poly(poles), dtype=np.float64)
 78    impulse = impulse_from_poles(poles, weights, n_impulse)
 79    numerator = numerator_from_impulse_and_denominator(impulse, denominator)
 80    reflection = np.asarray(ld.denominator_to_reflection(denominator.tolist()), dtype=np.float64)
 81    return {
 82        "poles": poles,
 83        "weights": weights,
 84        "denominator": denominator,
 85        "numerator": numerator,
 86        "reflection": reflection,
 87        "impulse": impulse,
 88    }
 89
 90
 91def process(
 92    reflection: np.ndarray | list[float], numerator: np.ndarray | list[float], x: np.ndarray
 93) -> np.ndarray:
 94    return np.asarray(
 95        ld.process_batch(list(map(float, reflection)), list(map(float, numerator)), x),
 96        dtype=np.float64,
 97    )
 98
 99
def snr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Return the signal-to-noise ratio of ``estimate`` against ``reference`` in dB.

    Both mean powers are offset by ``1e-30`` so that an all-zero reference or a
    perfect estimate still yields a finite value.
    """
    residual = reference - estimate
    signal_power = float(np.mean(np.square(reference)))
    noise_power = float(np.mean(np.square(residual)))
    ratio = (signal_power + 1e-30) / (noise_power + 1e-30)
    return 10.0 * math.log10(ratio)
105
106
107def pole_radius_from_denominator(denominator: np.ndarray | list[float]) -> float:
108    denominator_arr = np.asarray(denominator, dtype=np.float64)
109    roots = np.roots(denominator_arr)
110    return float(np.max(np.abs(roots))) if roots.size else 0.0
111
112
def frequency_response(
    denominator: np.ndarray, numerator: np.ndarray, n_freq: int = 512
) -> np.ndarray:
    """Evaluate ``numerator / denominator`` on ``n_freq`` points from 0 to pi.

    Both coefficient sequences are treated as polynomials in ``exp(-1j * w)``,
    with coefficient ``k`` multiplying the ``k``-th power.
    """
    omega = np.linspace(0.0, math.pi, n_freq)
    z_inv = np.exp(-1j * omega)

    def evaluate(coefficients: np.ndarray) -> np.ndarray:
        total = np.zeros(z_inv.shape, dtype=np.complex128)
        for power, coefficient in enumerate(coefficients):
            total += coefficient * z_inv**power
        return total

    numerator_values = evaluate(numerator)
    denominator_values = evaluate(denominator)
    return numerator_values / denominator_values
125
126
def max_magnitude_error_db(
    full_denominator: np.ndarray,
    full_numerator: np.ndarray,
    reduced_denominator: np.ndarray,
    reduced_numerator: np.ndarray,
) -> float:
    """Return the largest absolute dB gap between the two magnitude responses.

    Magnitudes are floored at ``1e-14`` before the logarithm so that exact
    zeros in either response do not produce ``-inf``.
    """

    def magnitude_db(denominator: np.ndarray, numerator: np.ndarray) -> np.ndarray:
        response = frequency_response(denominator, numerator)
        return 20.0 * np.log10(np.maximum(np.abs(response), 1e-14))

    full_db = magnitude_db(full_denominator, full_numerator)
    reduced_db = magnitude_db(reduced_denominator, reduced_numerator)
    deviation = np.abs(full_db - reduced_db)
    return float(np.max(deviation))
138
139
140def break_even_samples_per_channel(
141    reduction_time_s: float, full_time_s: float, reduced_time_s: float, channels: int, samples: int
142) -> float | None:
143    full_per_sample = full_time_s / (channels * samples)
144    reduced_per_sample = reduced_time_s / (channels * samples)
145    delta = full_per_sample - reduced_per_sample
146    if delta <= 0.0:
147        return None
148    return reduction_time_s / delta / channels
149
150
def serializable_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``row`` with NumPy scalars and arrays converted to builtins.

    NumPy scalars become Python scalars via ``.item()``, arrays become nested
    lists via ``.tolist()``, and every other value is passed through unchanged.
    """

    def to_builtin(value: Any) -> Any:
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    return {key: to_builtin(value) for key, value in row.items()}
161
162
def evaluate_reduced_model(
    *,
    method: str,
    full_order: int,
    target_order: int,
    reduction_time_s: float,
    full_model: dict[str, np.ndarray],
    full_time_s: float,
    y_full: np.ndarray,
    reduced_reflection: np.ndarray,
    reduced_numerator: np.ndarray,
    reduced_denominator: np.ndarray,
    relative_impulse_error: float,
    accepted: bool,
    stable: bool,
    x: np.ndarray,
    repeats: int,
) -> dict[str, Any]:
    """Time the reduced filter on ``x`` and assemble one benchmark result row.

    The reduced model is run ``repeats`` times on the batch ``x``; its median
    time and output are compared against the full model's time ``full_time_s``
    and output ``y_full``. The row records timing, speedups, break-even
    samples per channel, relative MSE, SNR, maximum magnitude-response error
    and the maximum pole radius of the reduced denominator.

    Speedup entries are ``None`` when the corresponding time in the
    denominator is not positive; the break-even entry is ``None`` when the
    helper reports none.
    """

    def run_reduced() -> np.ndarray:
        return process(reduced_reflection, reduced_numerator, x)

    reduced_time_s, y_reduced = median_time(run_reduced, repeats)

    error_power = np.mean((y_full - y_reduced) ** 2)
    rel_mse = float(error_power / (np.mean(y_full**2) + 1e-30))

    # The reduction cost is charged once, on top of the reduced filtering time.
    total_reduced_cost = reduction_time_s + reduced_time_s
    end_to_end_speedup = None
    if total_reduced_cost > 0:
        end_to_end_speedup = full_time_s / total_reduced_cost

    filter_speedup = None
    if reduced_time_s > 0:
        filter_speedup = float(full_time_s / reduced_time_s)

    break_even = break_even_samples_per_channel(
        reduction_time_s, full_time_s, reduced_time_s, x.shape[0], x.shape[1]
    )

    snr = snr_db(y_full, y_reduced)
    magnitude_error = max_magnitude_error_db(
        full_model["denominator"],
        full_model["numerator"],
        reduced_denominator,
        reduced_numerator,
    )
    pole_radius = pole_radius_from_denominator(reduced_denominator)

    row: dict[str, Any] = {
        "method": method,
        "full_order": int(full_order),
        "target_order": int(target_order),
        "stable": bool(stable),
        "accepted": bool(accepted),
        "reduction_time_s": float(reduction_time_s),
        "full_filter_median_s": float(full_time_s),
        "reduced_filter_median_s": float(reduced_time_s),
        "filter_speedup": filter_speedup,
        "amortized_end_to_end_speedup": (
            None if end_to_end_speedup is None else float(end_to_end_speedup)
        ),
        "break_even_samples_per_channel": None if break_even is None else float(break_even),
        "relative_impulse_error": float(relative_impulse_error),
        "rel_mse_on_random_batch": rel_mse,
        "snr_db_on_random_batch": snr,
        "max_magnitude_error_db": magnitude_error,
        "max_pole_radius": pole_radius,
    }
    return row
218
219
def _failure_row(
    method: str,
    full_order: int,
    target_order: int,
    *,
    stable: bool,
    error: str,
    reduction_time_s: float | None = None,
) -> dict[str, Any]:
    """Build the result row recorded when a reduction fails or is rejected."""
    row: dict[str, Any] = {
        "method": method,
        "full_order": full_order,
        "target_order": target_order,
        "stable": stable,
        "accepted": False,
    }
    # The reduction time is only known when the reducer itself returned.
    if reduction_time_s is not None:
        row["reduction_time_s"] = float(reduction_time_s)
    row["error"] = error
    return row


def _format_metric(value: float | None, width: int, precision: int) -> str:
    """Format ``value`` as fixed-point, or right-aligned ``n/a`` when it is ``None``."""
    if value is None:
        return f"{'n/a':>{width}s}"
    return f"{value:{width}.{precision}f}"


def main() -> None:
    """Run the finite-Hankel versus finite-section AAK/Nehari benchmark.

    Parses the command-line options, builds one compressible IIR model per
    requested full order, reduces it to every admissible target order with both
    methods, evaluates each stable reduced model on a shared random batch, then
    writes the metadata and result rows as JSON to ``--output`` and prints a
    summary table.

    Target orders are skipped when they are not smaller than the full order or
    when they exceed the smaller Hankel dimension. Exceptions raised by a
    reducer or by the evaluation are recorded as error rows rather than
    aborting the run.
    """
    parser = argparse.ArgumentParser(
        description="Compare finite-Hankel and finite-section AAK/Nehari SISO IIR reduction workflows."
    )
    parser.add_argument("--full-orders", type=int, nargs="+", default=[8, 16, 32])
    parser.add_argument("--target-orders", type=int, nargs="+", default=[3, 4, 6, 8, 12])
    parser.add_argument("--channels", type=int, default=32)
    parser.add_argument("--samples", type=int, default=30000)
    parser.add_argument("--repeats", type=int, default=3)
    parser.add_argument("--n-impulse", type=int, default=768)
    parser.add_argument("--hankel-rows", type=int, default=96)
    parser.add_argument("--hankel-cols", type=int, default=96)
    parser.add_argument("--seed", type=int, default=314)
    parser.add_argument("--output", default="reports/finite-aak-iir-reduction-speedup.json")
    args = parser.parse_args()

    # A median over zero timing samples is undefined, so reject it up front
    # with a usage error instead of failing deep inside the timing helper.
    if args.repeats < 1:
        parser.error("--repeats must be at least 1")

    rng = np.random.default_rng(args.seed)
    x = rng.normal(size=(args.channels, args.samples)).astype(np.float64)
    rows_out: list[dict[str, Any]] = []
    criteria = ld.FiniteNehariCandidateCriteria(
        max_tail_error=1.0,
        max_rational_error=1.0,
        max_pole_radius=0.999,
    )

    for full_order in args.full_orders:
        if full_order <= 1:
            continue
        full_model = compressible_iir(full_order, rng, args.n_impulse)
        full_time_s, y_full = median_time(
            lambda fm=full_model: process(fm["reflection"], fm["numerator"], x),
            args.repeats,
        )

        for target_order in args.target_orders:
            if target_order >= full_order:
                continue
            if target_order > min(args.hankel_rows, args.hankel_cols):
                continue

            # Finite-Hankel / Ho--Kalman baseline.
            try:
                reduce_time, hankel = median_time(
                    lambda ro=target_order, fm=full_model: ld.finite_hankel_reduce_iir(
                        fm["reflection"].tolist(),
                        fm["numerator"].tolist(),
                        reduced_order=ro,
                        n_impulse=args.n_impulse,
                        rows=args.hankel_rows,
                        cols=args.hankel_cols,
                    ),
                    1,
                )
                if bool(hankel["stable"]) and hankel.get("reflection"):
                    rows_out.append(
                        serializable_row(
                            evaluate_reduced_model(
                                method="finite_hankel",
                                full_order=full_order,
                                target_order=target_order,
                                reduction_time_s=reduce_time,
                                full_model=full_model,
                                full_time_s=full_time_s,
                                y_full=y_full,
                                reduced_reflection=np.asarray(
                                    hankel["reflection"], dtype=np.float64
                                ),
                                reduced_numerator=np.asarray(hankel["numerator"], dtype=np.float64),
                                reduced_denominator=np.asarray(
                                    hankel["denominator"], dtype=np.float64
                                ),
                                relative_impulse_error=float(hankel["relative_impulse_error"]),
                                accepted=True,
                                stable=True,
                                x=x,
                                repeats=args.repeats,
                            )
                        )
                    )
                else:
                    rows_out.append(
                        _failure_row(
                            "finite_hankel",
                            full_order,
                            target_order,
                            stable=bool(hankel.get("stable", False)),
                            error="reduced model was not stable in scalar lattice coordinates",
                            reduction_time_s=reduce_time,
                        )
                    )
            except Exception as exc:  # noqa: BLE001 - benchmark rows should report failures.
                rows_out.append(
                    _failure_row(
                        "finite_hankel",
                        full_order,
                        target_order,
                        stable=False,
                        error=str(exc),
                    )
                )

            # Finite-section AAK/Nehari candidate using the same target order.
            try:
                reduce_time, aak = median_time(
                    lambda ro=target_order, fm=full_model: ld.finite_aak_reduce_iir(
                        fm["reflection"],
                        fm["numerator"],
                        ranks=[ro],
                        n_impulse=args.n_impulse,
                        rows=args.hankel_rows,
                        cols=args.hankel_cols,
                        criteria=criteria,
                        attach_certificate=True,
                    ),
                    1,
                )
                if bool(aak["stable"]) and aak["reduced_reflection"].size:
                    rows_out.append(
                        serializable_row(
                            evaluate_reduced_model(
                                method="finite_aak_candidate",
                                full_order=full_order,
                                target_order=target_order,
                                reduction_time_s=reduce_time,
                                full_model=full_model,
                                full_time_s=full_time_s,
                                y_full=y_full,
                                reduced_reflection=np.asarray(
                                    aak["reduced_reflection"], dtype=np.float64
                                ),
                                reduced_numerator=np.asarray(
                                    aak["reduced_numerator"], dtype=np.float64
                                ),
                                reduced_denominator=np.asarray(
                                    aak["reduced_denominator"], dtype=np.float64
                                ),
                                relative_impulse_error=float(aak["relative_impulse_error"]),
                                accepted=bool(aak["accepted"]),
                                stable=bool(aak["stable"]),
                                x=x,
                                repeats=args.repeats,
                            )
                        )
                    )
                else:
                    rows_out.append(
                        _failure_row(
                            "finite_aak_candidate",
                            full_order,
                            target_order,
                            stable=bool(aak.get("stable", False)),
                            error="selected candidate was not stable in scalar lattice coordinates",
                            reduction_time_s=reduce_time,
                        )
                    )
            except Exception as exc:  # noqa: BLE001 - benchmark rows should report failures.
                rows_out.append(
                    _failure_row(
                        "finite_aak_candidate",
                        full_order,
                        target_order,
                        stable=False,
                        error=str(exc),
                    )
                )

    result = {
        "metadata": {
            "python": platform.python_version(),
            "platform": platform.platform(),
            "has_openmp": bool(ld.HAS_OPENMP),
            "channels": args.channels,
            "samples": args.samples,
            "repeats": args.repeats,
            "n_impulse": args.n_impulse,
            "hankel_rows": args.hankel_rows,
            "hankel_cols": args.hankel_cols,
            "seed": args.seed,
            "description": (
                "Finite-Hankel versus finite-section AAK/Nehari SISO IIR reduction benchmark. "
                "Both methods are finite-section baselines; neither is claimed to be a full infinite-dimensional solver."
            ),
        },
        "rows": rows_out,
    }

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2), encoding="utf-8")

    print(json.dumps(result["metadata"], indent=2))
    print()
    print(
        f"{'method':>21s} {'full':>5s} {'red':>5s} {'stable':>7s} {'reduce_s':>10s} "
        f"{'filter_x':>9s} {'end2end_x':>10s} {'SNR':>8s} {'mag_err':>9s} {'break_even/ch':>15s}"
    )
    print("-" * 115)
    for row in rows_out:
        if "error" in row:
            print(
                f"{row['method']:>21s} {row['full_order']:5d} {row['target_order']:5d} "
                f"{str(row.get('stable', False)):>7s} {float(row.get('reduction_time_s', 0.0)):10.4f}  ERROR: {row['error']}"
            )
            continue
        be = row["break_even_samples_per_channel"]
        be_text = "n/a" if be is None else f"{be:.0f}"
        # Both speedups are None when the measured time in their denominator
        # is zero; render them as n/a instead of crashing on the format spec.
        filter_text = _format_metric(row["filter_speedup"], 9, 2)
        end_to_end_text = _format_metric(row["amortized_end_to_end_speedup"], 10, 2)
        print(
            f"{row['method']:>21s} {row['full_order']:5d} {row['target_order']:5d} {str(row['stable']):>7s} "
            f"{row['reduction_time_s']:10.4f} {filter_text} "
            f"{end_to_end_text} {row['snr_db_on_random_batch']:8.2f} "
            f"{row['max_magnitude_error_db']:9.3f} {be_text:>15s}"
        )

    print()
    print(f"Wrote {output}")
439
440
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()