Finite Hankel reduction amortization benchmark

Tutorial goal

Measure when a one-time finite-Hankel reduction pays off during repeated high-order IIR filtering.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

The finite-Hankel reducer is a preprocessing step. This benchmark makes the speed argument explicit by measuring reduction time, full-order filtering time, reduced-order filtering time, and the break-even number of samples per channel. It also makes clear that the method is a finite-Hankel/Ho–Kalman reduction, not an exact Nehari/AAK solver.

Key idea and equations

The break-even sample count is estimated as

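\[N_{\text{break-even}} = \frac{t_{\text{reduce}}}{t_{\text{full/sample}} - t_{\text{reduced/sample}}}.\]

Here the per-sample times are the measured filter times divided by (channels × samples), so \(N_{\text{break-even}}\) counts samples across all channels; the script reports it per channel by dividing by the channel count. When the reduced filter is not faster than the full one, no break-even point exists and the value is reported as n/a.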

How to read the result

Look for high filter speedup, acceptable SNR/error, stable reduced denominators, and a break-even count that is small relative to the intended workload.

Run command

python benchmarks/hankel_reduction_speedup.py --full-orders 16 32 --reduced-orders 4 8 12 --channels 32 --samples 20000 --repeats 3 --n-impulse 512 --hankel-rows 64 --hankel-cols 64 --output docs/benchmarks/generated/_artifacts/hankel_reduction_speedup/hankel-reduction-speedup.json

Visual and data readout

When the benchmark gallery is built with results, this page embeds PNG summaries generated from the same JSON/CSV artifacts. The raw data stay available below as downloads so exact numbers remain reproducible without making the public page read like console output.

Source code

  1from __future__ import annotations
  2
  3import argparse
  4import json
  5import math
  6import platform
  7import statistics
  8import time
  9from pathlib import Path
 10
 11import numpy as np
 12
 13import lattice_dsp as ld
 14
 15
 16def median_time(fn, repeats: int):
 17    times = []
 18    result = None
 19    for _ in range(repeats):
 20        t0 = time.perf_counter()
 21        result = fn()
 22        times.append(time.perf_counter() - t0)
 23    return statistics.median(times), result
 24
 25
 26def stable_reflection(order: int, rng: np.random.Generator) -> np.ndarray:
 27    # Slow decay gives the reducer something nontrivial to compress while staying
 28    # comfortably inside the scalar lattice stability region.
 29    decay = np.exp(-np.arange(order) / max(8.0, order / 5.0))
 30    signs = rng.choice([-1.0, 1.0], size=order)
 31    jitter = rng.uniform(0.55, 1.0, size=order)
 32    return 0.72 * decay * signs * jitter
 33
 34
 35def numerator_for_order(order: int) -> np.ndarray:
 36    n = np.zeros(order + 1, dtype=float)
 37    n[0] = 1.0
 38    if order >= 2:
 39        n[1] = -0.18
 40        n[2] = 0.08
 41    if order >= 5:
 42        n[4] = -0.04
 43    return n
 44
 45
 46def snr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
 47    err = reference - estimate
 48    p_ref = float(np.mean(reference * reference))
 49    p_err = float(np.mean(err * err))
 50    return 10.0 * math.log10((p_ref + 1e-30) / (p_err + 1e-30))
 51
 52
 53def pole_radius_from_reflection(reflection: list[float]) -> float:
 54    if not reflection:
 55        return 0.0
 56    denominator = np.asarray(ld.reflection_to_denominator(reflection), dtype=float)
 57    roots = np.roots(denominator)
 58    return float(np.max(np.abs(roots))) if roots.size else 0.0
 59
 60
 61def process(
 62    reflection: np.ndarray | list[float], numerator: np.ndarray | list[float], x: np.ndarray
 63):
 64    return ld.process_batch(list(map(float, reflection)), list(map(float, numerator)), x)
 65
 66
 67def main() -> None:
 68    parser = argparse.ArgumentParser(
 69        description="Benchmark finite-Hankel SISO IIR reduction amortization."
 70    )
 71    parser.add_argument("--full-orders", type=int, nargs="+", default=[16, 32, 64])
 72    parser.add_argument("--reduced-orders", type=int, nargs="+", default=[4, 8, 12, 16])
 73    parser.add_argument("--channels", type=int, default=64)
 74    parser.add_argument("--samples", type=int, default=50000)
 75    parser.add_argument("--repeats", type=int, default=3)
 76    parser.add_argument("--n-impulse", type=int, default=768)
 77    parser.add_argument("--hankel-rows", type=int, default=96)
 78    parser.add_argument("--hankel-cols", type=int, default=96)
 79    parser.add_argument("--seed", type=int, default=42)
 80    parser.add_argument("--output", default="reports/hankel-reduction-speedup.json")
 81    args = parser.parse_args()
 82
 83    rng = np.random.default_rng(args.seed)
 84    x = rng.normal(size=(args.channels, args.samples)).astype(np.float64)
 85    rows_out: list[dict[str, float | int | bool | str | None]] = []
 86
 87    for full_order in args.full_orders:
 88        reflection = stable_reflection(full_order, rng)
 89        numerator = numerator_for_order(full_order)
 90
 91        full_time, y_full = median_time(
 92            lambda reflection=reflection, numerator=numerator: process(reflection, numerator, x),
 93            args.repeats,
 94        )
 95        full_per_sample = full_time / (args.channels * args.samples)
 96
 97        for reduced_order in args.reduced_orders:
 98            if reduced_order >= full_order:
 99                continue
100            if reduced_order > min(args.hankel_rows, args.hankel_cols):
101                continue
102
103            reduce_time, reduced = median_time(
104                lambda ro=reduced_order, reflection=reflection, numerator=numerator: (
105                    ld.finite_hankel_reduce_iir(
106                        reflection.tolist(),
107                        numerator.tolist(),
108                        reduced_order=ro,
109                        n_impulse=args.n_impulse,
110                        rows=args.hankel_rows,
111                        cols=args.hankel_cols,
112                    )
113                ),
114                1,
115            )
116
117            if not reduced["stable"] or not reduced["reflection"]:
118                rows_out.append(
119                    {
120                        "full_order": full_order,
121                        "reduced_order": reduced_order,
122                        "stable": bool(reduced["stable"]),
123                        "reduction_time_s": reduce_time,
124                        "error": "reduced model was not stable in reflection coordinates",
125                    }
126                )
127                continue
128
129            reduced_reflection = list(map(float, reduced["reflection"]))
130            reduced_numerator = list(map(float, reduced["numerator"]))
131            reduced_time, y_reduced = median_time(
132                lambda rr=reduced_reflection, rn=reduced_numerator: process(rr, rn, x),
133                args.repeats,
134            )
135            reduced_per_sample = reduced_time / (args.channels * args.samples)
136            delta = full_per_sample - reduced_per_sample
137            break_even_samples_per_channel = (
138                reduce_time / delta / args.channels if delta > 0 else None
139            )
140
141            rel_mse = float(np.mean((y_full - y_reduced) ** 2) / (np.mean(y_full**2) + 1e-30))
142            rows_out.append(
143                {
144                    "full_order": full_order,
145                    "reduced_order": reduced_order,
146                    "stable": bool(reduced["stable"]),
147                    "method": reduced["method"],
148                    "retained_hankel_energy": float(reduced["retained_hankel_energy"]),
149                    "relative_impulse_error": float(reduced["relative_impulse_error"]),
150                    "rel_mse_on_random_batch": rel_mse,
151                    "snr_db_on_random_batch": snr_db(y_full, y_reduced),
152                    "max_pole_radius": pole_radius_from_reflection(reduced_reflection),
153                    "reduction_time_s": reduce_time,
154                    "full_filter_median_s": full_time,
155                    "reduced_filter_median_s": reduced_time,
156                    "filter_speedup": full_time / reduced_time if reduced_time > 0 else None,
157                    "full_time_per_sample_s": full_per_sample,
158                    "reduced_time_per_sample_s": reduced_per_sample,
159                    "break_even_samples_per_channel": break_even_samples_per_channel,
160                }
161            )
162
163    result = {
164        "metadata": {
165            "python": platform.python_version(),
166            "platform": platform.platform(),
167            "has_openmp": bool(ld.HAS_OPENMP),
168            "channels": args.channels,
169            "samples": args.samples,
170            "repeats": args.repeats,
171            "n_impulse": args.n_impulse,
172            "hankel_rows": args.hankel_rows,
173            "hankel_cols": args.hankel_cols,
174            "seed": args.seed,
175            "description": "Finite-Hankel reduction amortization benchmark. Reduction is a preprocessing cost; speedup applies when the reduced model is reused.",
176        },
177        "rows": rows_out,
178    }
179
180    output = Path(args.output)
181    output.parent.mkdir(parents=True, exist_ok=True)
182    output.write_text(json.dumps(result, indent=2), encoding="utf-8")
183
184    print(json.dumps(result["metadata"], indent=2))
185    print()
186    print(
187        f"{'full':>5s} {'red':>5s} {'stable':>7s} {'reduce_s':>10s} "
188        f"{'full_s':>10s} {'red_s':>10s} {'speedup':>8s} {'SNR':>8s} {'break_even/ch':>15s}"
189    )
190    print("-" * 95)
191    for row in rows_out:
192        if "error" in row:
193            print(
194                f"{row['full_order']:5d} {row['reduced_order']:5d} {str(row['stable']):>7s} {row['reduction_time_s']:10.4f}  ERROR: {row['error']}"
195            )
196            continue
197        be = row["break_even_samples_per_channel"]
198        be_text = "n/a" if be is None else f"{be:.0f}"
199        print(
200            f"{row['full_order']:5d} {row['reduced_order']:5d} {str(row['stable']):>7s} "
201            f"{row['reduction_time_s']:10.4f} {row['full_filter_median_s']:10.4f} "
202            f"{row['reduced_filter_median_s']:10.4f} {row['filter_speedup']:8.2f} "
203            f"{row['snr_db_on_random_batch']:8.2f} {be_text:>15s}"
204        )
205
206    print()
207    print(f"Wrote {output}")
208
209
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()