Finite Hankel reduction amortization benchmark

Tutorial goal

Measure when a one-time finite-Hankel reduction pays off during repeated high-order IIR filtering.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

The finite-Hankel reducer is a preprocessing step. This benchmark makes the speed argument explicit by measuring reduction time, full-order filtering time, reduced-order filtering time, and the break-even number of samples per channel. It also makes explicit that the method is finite-Hankel/Ho–Kalman reduction, not an exact Nehari/AAK solver.

Key idea and equations

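The break-even sample count per channel is estimated as

\[N_{\text{break-even}} = \frac{t_{\text{reduce}}}{C\,\bigl(t_{\text{full/sample}} - t_{\text{reduced/sample}}\bigr)},\]

where \(C\) is the number of channels and each per-sample time is the median filtering time divided by the total number of processed samples (channels × samples per channel). When the reduced filter is not faster per sample, no break-even count is reported.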

How to read the result

Look for high filter speedup, acceptable SNR/error, stable reduced denominators, and a break-even count that is small relative to the intended workload.

Run command

python benchmarks/hankel_reduction_speedup.py --full-orders 16 32 --reduced-orders 4 8 12 --channels 32 --samples 20000 --repeats 3 --n-impulse 512 --hankel-rows 64 --hankel-cols 64 --output docs/benchmarks/generated/_artifacts/hankel_reduction_speedup/hankel-reduction-speedup.json

Run status

Return code: 0

Visual and data readout

When the benchmark gallery is built with results, this page embeds PNG summaries generated from the same JSON/CSV artifacts. The raw data stay available below as downloads so exact numbers remain reproducible without making the public page read like console output.

Figures

hankel reduction speedup error summary

hankel_reduction_speedup_error_summary.png

hankel reduction speedup quality summary

hankel_reduction_speedup_quality_summary.png

hankel reduction speedup speedup summary

hankel_reduction_speedup_speedup_summary.png

hankel reduction speedup timing comparison

hankel_reduction_speedup_timing_comparison.png

Generated data files

Source code

  1from __future__ import annotations
  2
  3import argparse
  4import json
  5import math
  6import platform
  7import statistics
  8import time
  9from pathlib import Path
 10
 11import numpy as np
 12
 13import lattice_dsp as ld
 14
 15
 16def median_time(fn, repeats: int):
 17    times = []
 18    result = None
 19    for _ in range(repeats):
 20        t0 = time.perf_counter()
 21        result = fn()
 22        times.append(time.perf_counter() - t0)
 23    return statistics.median(times), result
 24
 25
 26def stable_reflection(order: int, rng: np.random.Generator) -> np.ndarray:
 27    # Slow decay gives the reducer something nontrivial to compress while staying
 28    # comfortably inside the scalar lattice stability region.
 29    decay = np.exp(-np.arange(order) / max(8.0, order / 5.0))
 30    signs = rng.choice([-1.0, 1.0], size=order)
 31    jitter = rng.uniform(0.55, 1.0, size=order)
 32    return 0.72 * decay * signs * jitter
 33
 34
 35def numerator_for_order(order: int) -> np.ndarray:
 36    n = np.zeros(order + 1, dtype=float)
 37    n[0] = 1.0
 38    if order >= 2:
 39        n[1] = -0.18
 40        n[2] = 0.08
 41    if order >= 5:
 42        n[4] = -0.04
 43    return n
 44
 45
 46def snr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
 47    err = reference - estimate
 48    p_ref = float(np.mean(reference * reference))
 49    p_err = float(np.mean(err * err))
 50    return 10.0 * math.log10((p_ref + 1e-30) / (p_err + 1e-30))
 51
 52
 53def pole_radius_from_reflection(reflection: list[float]) -> float:
 54    if not reflection:
 55        return 0.0
 56    denominator = np.asarray(ld.reflection_to_denominator(reflection), dtype=float)
 57    roots = np.roots(denominator)
 58    return float(np.max(np.abs(roots))) if roots.size else 0.0
 59
 60
 61def process(
 62    reflection: np.ndarray | list[float], numerator: np.ndarray | list[float], x: np.ndarray
 63):
 64    return ld.process_batch(list(map(float, reflection)), list(map(float, numerator)), x)
 65
 66
 67def main() -> None:
 68    parser = argparse.ArgumentParser(
 69        description="Benchmark finite-Hankel SISO IIR reduction amortization."
 70    )
 71    parser.add_argument("--full-orders", type=int, nargs="+", default=[16, 32, 64])
 72    parser.add_argument("--reduced-orders", type=int, nargs="+", default=[4, 8, 12, 16])
 73    parser.add_argument("--channels", type=int, default=64)
 74    parser.add_argument("--samples", type=int, default=50000)
 75    parser.add_argument("--repeats", type=int, default=3)
 76    parser.add_argument("--n-impulse", type=int, default=768)
 77    parser.add_argument("--hankel-rows", type=int, default=96)
 78    parser.add_argument("--hankel-cols", type=int, default=96)
 79    parser.add_argument("--seed", type=int, default=42)
 80    parser.add_argument("--output", default="reports/hankel-reduction-speedup.json")
 81    args = parser.parse_args()
 82
 83    rng = np.random.default_rng(args.seed)
 84    x = rng.normal(size=(args.channels, args.samples)).astype(np.float64)
 85    rows_out: list[dict[str, float | int | bool | str | None]] = []
 86
 87    for full_order in args.full_orders:
 88        reflection = stable_reflection(full_order, rng)
 89        numerator = numerator_for_order(full_order)
 90
 91        full_time, y_full = median_time(
 92            lambda reflection=reflection, numerator=numerator: process(reflection, numerator, x),
 93            args.repeats,
 94        )
 95        full_per_sample = full_time / (args.channels * args.samples)
 96
 97        for reduced_order in args.reduced_orders:
 98            if reduced_order >= full_order:
 99                continue
100            if reduced_order > min(args.hankel_rows, args.hankel_cols):
101                continue
102
103            reduce_time, reduced = median_time(
104                lambda ro=reduced_order, reflection=reflection, numerator=numerator: (
105                    ld.finite_hankel_reduce_iir(
106                        reflection.tolist(),
107                        numerator.tolist(),
108                        reduced_order=ro,
109                        n_impulse=args.n_impulse,
110                        rows=args.hankel_rows,
111                        cols=args.hankel_cols,
112                    )
113                ),
114                1,
115            )
116
117            if not reduced["stable"] or not reduced["reflection"]:
118                rows_out.append(
119                    {
120                        "full_order": full_order,
121                        "reduced_order": reduced_order,
122                        "stable": bool(reduced["stable"]),
123                        "reduction_time_s": reduce_time,
124                        "error": "reduced model was not stable in reflection coordinates",
125                    }
126                )
127                continue
128
129            reduced_reflection = list(map(float, reduced["reflection"]))
130            reduced_numerator = list(map(float, reduced["numerator"]))
131            reduced_time, y_reduced = median_time(
132                lambda rr=reduced_reflection, rn=reduced_numerator: process(rr, rn, x),
133                args.repeats,
134            )
135            reduced_per_sample = reduced_time / (args.channels * args.samples)
136            delta = full_per_sample - reduced_per_sample
137            break_even_samples_per_channel = (
138                reduce_time / delta / args.channels if delta > 0 else None
139            )
140
141            rel_mse = float(np.mean((y_full - y_reduced) ** 2) / (np.mean(y_full**2) + 1e-30))
142            rows_out.append(
143                {
144                    "full_order": full_order,
145                    "reduced_order": reduced_order,
146                    "stable": bool(reduced["stable"]),
147                    "method": reduced["method"],
148                    "retained_hankel_energy": float(reduced["retained_hankel_energy"]),
149                    "relative_impulse_error": float(reduced["relative_impulse_error"]),
150                    "rel_mse_on_random_batch": rel_mse,
151                    "snr_db_on_random_batch": snr_db(y_full, y_reduced),
152                    "max_pole_radius": pole_radius_from_reflection(reduced_reflection),
153                    "reduction_time_s": reduce_time,
154                    "full_filter_median_s": full_time,
155                    "reduced_filter_median_s": reduced_time,
156                    "filter_speedup": full_time / reduced_time if reduced_time > 0 else None,
157                    "full_time_per_sample_s": full_per_sample,
158                    "reduced_time_per_sample_s": reduced_per_sample,
159                    "break_even_samples_per_channel": break_even_samples_per_channel,
160                }
161            )
162
163    result = {
164        "metadata": {
165            "python": platform.python_version(),
166            "platform": platform.platform(),
167            "has_openmp": bool(ld.HAS_OPENMP),
168            "channels": args.channels,
169            "samples": args.samples,
170            "repeats": args.repeats,
171            "n_impulse": args.n_impulse,
172            "hankel_rows": args.hankel_rows,
173            "hankel_cols": args.hankel_cols,
174            "seed": args.seed,
175            "description": "Finite-Hankel reduction amortization benchmark. Reduction is a preprocessing cost; speedup applies when the reduced model is reused.",
176        },
177        "rows": rows_out,
178    }
179
180    output = Path(args.output)
181    output.parent.mkdir(parents=True, exist_ok=True)
182    output.write_text(json.dumps(result, indent=2), encoding="utf-8")
183
184    print(json.dumps(result["metadata"], indent=2))
185    print()
186    print(
187        f"{'full':>5s} {'red':>5s} {'stable':>7s} {'reduce_s':>10s} "
188        f"{'full_s':>10s} {'red_s':>10s} {'speedup':>8s} {'SNR':>8s} {'break_even/ch':>15s}"
189    )
190    print("-" * 95)
191    for row in rows_out:
192        if "error" in row:
193            print(
194                f"{row['full_order']:5d} {row['reduced_order']:5d} {str(row['stable']):>7s} {row['reduction_time_s']:10.4f}  ERROR: {row['error']}"
195            )
196            continue
197        be = row["break_even_samples_per_channel"]
198        be_text = "n/a" if be is None else f"{be:.0f}"
199        print(
200            f"{row['full_order']:5d} {row['reduced_order']:5d} {str(row['stable']):>7s} "
201            f"{row['reduction_time_s']:10.4f} {row['full_filter_median_s']:10.4f} "
202            f"{row['reduced_filter_median_s']:10.4f} {row['filter_speedup']:8.2f} "
203            f"{row['snr_db_on_random_batch']:8.2f} {be_text:>15s}"
204        )
205
206    print()
207    print(f"Wrote {output}")
208
209
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()