Million-sample IIR throughput for long acoustic-like tails

Tutorial goal

Show why a compact IIR/lattice representation can process very long signals efficiently when a long decay has a low-order recursive description.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

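Long acoustic paths and reverberant decays are often represented as long FIR impulse responses. That representation is flexible, but a tail with hundreds of thousands of taps is expensive to apply repeatedly. Direct convolution costs one multiply-add per tap for every output sample, and even FFT convolution must store and transform the whole tail. The cost adds up when the signal itself has millions of samples. When the dominant decay is well described by a stable recursive model, an IIR/lattice representation keeps the long memory implicitly in a small state vector, so the per-sample cost is set by the model order rather than by the tail length.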
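To make "small state vector" concrete, here is a minimal sketch in plain NumPy (not the lattice_dsp API; the helper name `smooth_block` is hypothetical). It runs the first-order recursion y[n] = (1-r) x[n] + r y[n-1] from the equations section over a signal in uneven blocks, handing a single float from one block to the next. Each sample costs two multiplies and one add, however long the decay is.

```python
import numpy as np


def smooth_block(x: np.ndarray, r: float, state: float) -> tuple[np.ndarray, float]:
    """Hypothetical helper: y[n] = (1 - r) x[n] + r y[n-1] over one block.

    `state` is y[-1] for this block; the returned float is the new y[-1]
    for the next block. That single number is the filter's entire memory.
    """
    y = np.empty_like(x, dtype=np.float64)
    g = 1.0 - r
    for n, xn in enumerate(x):
        state = g * xn + r * state
        y[n] = state
    return y, state


rng = np.random.default_rng(0)
x = rng.normal(size=10_000)
r = 0.999

# One pass over the whole signal, starting from rest.
y_full, _ = smooth_block(x, r, 0.0)

# The same signal in 7 uneven blocks, carrying only one float between them.
state = 0.0
pieces = []
for block in np.array_split(x, 7):
    y_blk, state = smooth_block(block, r, state)
    pieces.append(y_blk)
y_blocks = np.concatenate(pieces)

# Identical operations in identical order, so the outputs match bit for bit.
print(np.array_equal(y_full, y_blocks))
```

A direct L-tap FIR streamed the same way would have to carry the last L - 1 input samples across every block boundary instead of one number.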

Key idea and equations

A long FIR tail computes

\[y[n] = \sum_{m=0}^{L-1} h[m] x[n-m].\]

For the exponential tail

\[h[m] = (1-r) r^m, \qquad 0 < r < 1,\]

taken without truncation (\(L \to \infty\)), an exactly equivalent stable IIR recursion is

\[y[n] = (1-r) x[n] + r y[n-1].\]
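The equivalence follows from the z-transform of the tail, and the same geometric series gives the truncation loss reported in the output. A short derivation, added for reference and using only the definitions above:

```latex
\[
H(z) = \sum_{m=0}^{\infty} (1-r) r^m z^{-m} = \frac{1-r}{1 - r z^{-1}}, \qquad |z| > r,
\]
\[
Y(z)\,(1 - r z^{-1}) = (1-r)\,X(z) \;\Longleftrightarrow\; y[n] = (1-r) x[n] + r y[n-1],
\]
\[
\sum_{m=0}^{\infty} h[m] = 1, \qquad \sum_{m=L}^{\infty} h[m] = (1-r)\,\frac{r^L}{1-r} = r^L.
\]
```

So an L-tap truncation keeps a DC gain of \(1 - r^L\) and drops \(r^L\). Because every omitted tap is positive, the truncation error for any bounded input is at most \(\max_n |x[n]| \, r^L\).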

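In the scalar lattice convention used by lattice-dsp, a first-order reflection coefficient \(k_1\) gives the denominator \(A(z) = 1 + k_1 z^{-1}\), whose pole is at \(z = -k_1\). The recursion above has denominator \(1 - r z^{-1}\), so \(k_1 = -r\), and stability is exposed directly by \(|k_1| < 1\).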
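For readers who want to see where the reflection coefficient enters, here is a minimal sketch of a first-order all-pole lattice section followed by a ladder (tap) stage, written from the textbook Gray-Markel lattice-ladder structure. It assumes the convention stated in the example's source comments, A(z) = 1 + k_1 z^{-1}, and it assumes that ladder weights v_0, v_1 multiply the backward errors g_0[n], g_1[n]. Under that assumption the weights [1 - r, 0] (the `numerator` passed to `ld.LatticeIIR` in the script) reproduce the recursion above. This is an illustration, not the lattice_dsp implementation, and the function name is hypothetical.

```python
import numpy as np


def lattice_ladder_order1(x: np.ndarray, k1: float, v: tuple[float, float]) -> np.ndarray:
    """Hypothetical first-order Gray-Markel lattice-ladder filter.

    All-pole part: 1 / A(z) with A(z) = 1 + k1 z^-1 (pole at -k1).
    Ladder part:   y[n] = v[0] * g0[n] + v[1] * g1[n].
    """
    if not abs(k1) < 1.0:
        raise ValueError("unstable: need |k1| < 1")
    g0_prev = 0.0                      # g0[n-1], the only stored state
    y = np.empty(x.size, dtype=np.float64)
    for n in range(x.size):
        f0 = x[n] - k1 * g0_prev       # forward error after the single stage
        g1 = k1 * f0 + g0_prev         # backward error at stage 1
        g0 = f0                        # backward error at stage 0
        y[n] = v[0] * g0 + v[1] * g1   # ladder (numerator) combination
        g0_prev = g0
    return y


r = 0.99992
impulse = np.zeros(8)
impulse[0] = 1.0
h = lattice_ladder_order1(impulse, k1=-r, v=(1.0 - r, 0.0))
expected = (1.0 - r) * r ** np.arange(8)
print(np.allclose(h, expected))
```

With k1 = -r the forward error obeys f0[n] = x[n] + r f0[n-1], so scaling it by v[0] = 1 - r gives the exponential tail; a nonzero v[1] would add the zero contributed by the reversed polynomial k1 + z^{-1}.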

How to read the result

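Compare the IIR recursive state count (1) with the FIR truncation length (131,072 taps), and compare the local median timings on the million-sample input. The FIR path uses FFT convolution, which is already a strong baseline for long FIR filtering, and all timings depend on your machine and NumPy build. The last two lines report what the FIR reference gives up by truncating the infinite tail: the omitted tail amplitude \(r^L\) is also the fraction of the unit DC gain that is discarded, and the relative RMS error between the truncated FIR output and the IIR output is of the same order.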
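Several printed lines are derived from the others, so they can be checked by hand. A minimal sketch using only values copied from the captured stdout (the median times are the rounded printed values, so the recomputed speedup is approximate):

```python
import math

samples, taps, pole = 1_000_000, 131_072, 0.99992

# FFT length: next power of two >= samples + taps - 1 (full linear convolution),
# computed the same way as next_power_of_two() in the script.
n_out = samples + taps - 1
n_fft = 1 << (n_out - 1).bit_length()

# Omitted tail: sum_{m>=L} (1-r) r^m = r^L, the fraction of unit DC gain discarded.
omitted = math.pow(pole, taps)

# Speedup from the printed (rounded) median times.
iir_s, fft_s = 0.010744, 0.124499
speedup = fft_s / iir_s

print(f"{n_fft:,}  {omitted:.3e}  {speedup:.2f}x")  # 2,097,152  2.792e-05  11.59x
```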

Run command

python examples/million_sample_iir_throughput.py

Run status

Return code: 0

Captured stdout

million-sample IIR throughput demonstration
====================================================
samples: 1,000,000
stable pole radius: 0.99992000
IIR reflection coefficient: -0.99992000
IIR recursive state count: 1
IIR median time: 0.010744 s
IIR throughput: 93.07 million samples/s

FIR truncation taps: 131,072
FFT length for FIR reference: 2,097,152
FFT/FIR median time: 0.124499 s
IIR speedup over FFT/FIR reference: 11.59x
relative RMS error from truncating the infinite IIR tail: 2.640e-05
omitted tail amplitude after 131,072 taps: 2.792e-05

Generated data files

Source code

"""Million-sample throughput: long acoustic-like tail as a low-order IIR.

This example shows the speed motivation for recursive models.  A long FIR tail
can represent an acoustic or decay response by storing many taps.  When the tail
has a compact recursive description, an IIR/lattice filter can process millions
of samples with a small fixed state instead of carrying the full tail length.

The example uses a first-order stable IIR whose impulse response is an exponential
reverberant-like decay,

    h[m] = (1 - r) r**m,    0 < r < 1.

The equivalent FIR approximation truncates that tail to many taps.  The timing
comparison is intentionally local to your machine and NumPy build; it is a
reproducibility aid, not a universal benchmark number.
"""

from __future__ import annotations

import argparse
import csv
import math
import os
import statistics
import time
from pathlib import Path

import numpy as np

import lattice_dsp as ld


def artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def median_time(fn, repeats: int) -> tuple[float, np.ndarray]:
    times: list[float] = []
    result: np.ndarray | None = None
    for _ in range(max(1, repeats)):
        t0 = time.perf_counter()
        result = np.asarray(fn(), dtype=np.float64)
        times.append(time.perf_counter() - t0)
    assert result is not None
    return statistics.median(times), result


def next_power_of_two(n: int) -> int:
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


def fft_convolve_truncated(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Dependency-free full FFT convolution, truncated to len(x)."""
    n_out = int(x.size + h.size - 1)
    n_fft = next_power_of_two(n_out)
    spectrum = np.fft.rfft(x, n_fft) * np.fft.rfft(h, n_fft)
    return np.fft.irfft(spectrum, n_fft)[: x.size]


def relative_rms_error(reference: np.ndarray, estimate: np.ndarray) -> float:
    err = reference - estimate
    return float(np.sqrt(np.mean(err * err)) / (np.sqrt(np.mean(reference * reference)) + 1e-30))


def main() -> None:
    parser = argparse.ArgumentParser(description="Million-sample IIR throughput demonstration.")
    parser.add_argument("--samples", type=int, default=1_000_000, help="number of input samples")
    parser.add_argument(
        "--tail-taps", type=int, default=131_072, help="FIR taps used to truncate the IIR tail"
    )
    parser.add_argument(
        "--pole",
        type=float,
        default=0.99992,
        help="stable IIR pole radius for the exponential tail",
    )
    parser.add_argument("--repeats", type=int, default=3, help="median timing repeats")
    parser.add_argument("--seed", type=int, default=2026)
    parser.add_argument("--skip-fft", action="store_true", help="skip the FFT/FIR reference timing")
    args = parser.parse_args()

    if args.samples <= 0:
        raise ValueError("--samples must be positive")
    if args.tail_taps <= 0:
        raise ValueError("--tail-taps must be positive")
    if not (0.0 < args.pole < 1.0):
        raise ValueError("--pole must satisfy 0 < pole < 1")

    rng = np.random.default_rng(args.seed)
    x = rng.normal(size=args.samples).astype(np.float64)

    # Under the lattice convention used by lattice-dsp, a first-order reflection
    # coefficient k gives A(z) = 1 + k z^-1.  The pole is therefore -k.  Choosing
    # k = -pole gives the stable smoother y[n] = (1-pole) x[n] + pole y[n-1].
    reflection = [-float(args.pole)]
    numerator = [1.0 - float(args.pole), 0.0]

    iir_time, y_iir = median_time(
        lambda: ld.LatticeIIR(reflection, numerator).process(x), args.repeats
    )

    rows: list[dict[str, object]] = [
        {
            "method": "lattice_iir_recursive",
            "samples": args.samples,
            "state_or_taps": len(reflection),
            "median_seconds": iir_time,
            "throughput_msamples_per_s": args.samples / max(iir_time, 1e-30) / 1e6,
            "relative_rms_error_vs_iir": 0.0,
        }
    ]

    print("million-sample IIR throughput demonstration")
    print("=" * 52)
    print(f"samples: {args.samples:,}")
    print(f"stable pole radius: {args.pole:.8f}")
    print(f"IIR reflection coefficient: {reflection[0]:.8f}")
    print(f"IIR recursive state count: {len(reflection)}")
    print(f"IIR median time: {iir_time:.6f} s")
    print(f"IIR throughput: {args.samples / max(iir_time, 1e-30) / 1e6:.2f} million samples/s")

    if not args.skip_fft:
        # Finite FIR approximation to the same infinite exponential tail.  This is
        # not meant as the slowest possible comparison: FFT convolution is already
        # a strong baseline for long FIR filtering.  The point is that the IIR has
        # constant state for this structured tail, while the FIR path stores and
        # transforms many coefficients.
        taps = (1.0 - args.pole) * np.power(args.pole, np.arange(args.tail_taps, dtype=np.float64))
        fft_time, y_fir = median_time(lambda: fft_convolve_truncated(x, taps), args.repeats)
        rel_err = relative_rms_error(y_iir, y_fir)
        n_fft = next_power_of_two(args.samples + args.tail_taps - 1)
        speedup = fft_time / max(iir_time, 1e-30)
        rows.append(
            {
                "method": "truncated_fir_fft_convolution",
                "samples": args.samples,
                "state_or_taps": args.tail_taps,
                "median_seconds": fft_time,
                "throughput_msamples_per_s": args.samples / max(fft_time, 1e-30) / 1e6,
                "relative_rms_error_vs_iir": rel_err,
                "fft_length": n_fft,
                "speedup_iir_vs_fft_fir": speedup,
            }
        )
        print()
        print(f"FIR truncation taps: {args.tail_taps:,}")
        print(f"FFT length for FIR reference: {n_fft:,}")
        print(f"FFT/FIR median time: {fft_time:.6f} s")
        print(f"IIR speedup over FFT/FIR reference: {speedup:.2f}x")
        print(f"relative RMS error from truncating the infinite IIR tail: {rel_err:.3e}")
        print(
            f"omitted tail amplitude after {args.tail_taps:,} taps: {math.pow(args.pole, args.tail_taps):.3e}"
        )

    out_dir = artifact_dir()
    csv_path = out_dir / "million_sample_iir_throughput.csv"
    fieldnames = sorted({key for row in rows for key in row})
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print()
    print(f"wrote {csv_path}")


if __name__ == "__main__":
    main()
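The same accuracy comparison can be reproduced at a smaller scale without lattice_dsp. This is a minimal NumPy-only sketch (the function names `iir_recursion` and `truncated_fir` are illustrative): it runs the recursion directly, builds the truncated FIR with `np.convolve` instead of the script's FFT path, and applies the same relative RMS formula as the script's `relative_rms_error`. Before sample L the two outputs agree up to rounding; from sample L on, the error cannot exceed max|x[n]| · r^L, because every omitted tap is positive and the omitted taps sum to r^L.

```python
import numpy as np


def iir_recursion(x: np.ndarray, r: float) -> np.ndarray:
    """y[n] = (1 - r) x[n] + r y[n-1], starting from rest (y[-1] = 0)."""
    y = np.empty(x.size, dtype=np.float64)
    prev = 0.0
    for n in range(x.size):
        prev = (1.0 - r) * x[n] + r * prev
        y[n] = prev
    return y


def truncated_fir(x: np.ndarray, r: float, taps: int) -> np.ndarray:
    """Direct convolution with h[m] = (1 - r) r**m for m < taps, cut to len(x)."""
    h = (1.0 - r) * r ** np.arange(taps, dtype=np.float64)
    return np.convolve(x, h)[: x.size]


def relative_rms_error(reference: np.ndarray, estimate: np.ndarray) -> float:
    # Same formula as the helper in the script above.
    err = reference - estimate
    return float(np.sqrt(np.mean(err * err)) / (np.sqrt(np.mean(reference * reference)) + 1e-30))


rng = np.random.default_rng(1)
x = rng.normal(size=4_000)
r, L = 0.995, 1_000

y_iir = iir_recursion(x, r)
y_fir = truncated_fir(x, r, L)

head_err = np.max(np.abs(y_iir[:L] - y_fir[:L]))  # rounding only
tail_err = np.max(np.abs(y_iir[L:] - y_fir[L:]))  # truncation error
bound = np.max(np.abs(x)) * r**L                  # max|x| times omitted tail mass
rel = relative_rms_error(y_iir, y_fir)
print(f"head {head_err:.1e}  tail {tail_err:.1e}  bound {bound:.1e}  "
      f"rel RMS {rel:.2e}  r^L {r**L:.2e}")
```

As in the million-sample run, the relative RMS error lands near r^L for a white-noise input, while the timing question (recursion versus long convolution) is what the full script measures.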