Million-sample IIR throughput for long acoustic-like tails
Tutorial goal
Show why a compact IIR/lattice representation can process very long signals efficiently when a long decay has a low-order recursive description.
Note
New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.
Context
Long acoustic paths and reverberant decays are often represented as long FIR impulse responses. That representation is flexible, but a tail with hundreds of thousands of taps is expensive to process repeatedly, especially when the signal itself has millions of samples. When the dominant decay is well described by a stable recursive model, an IIR/lattice representation can keep the long memory implicitly in a small state vector.
Key idea and equations
A long FIR tail with $L$ stored taps computes
$$y[n] = \sum_{m=0}^{L-1} h[m]\, x[n-m].$$
For the exponential tail
$$h[m] = (1 - r)\, r^{m}, \qquad 0 < r < 1,$$
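an equivalent stable IIR recursion (the $L \to \infty$ limit of the FIR sum) is
$$y[n] = (1 - r)\, x[n] + r\, y[n-1], \qquad H(z) = \frac{1 - r}{1 - r z^{-1}}.$$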
In the scalar lattice convention, the first-order denominator $A(z) = 1 + k_1 z^{-1}$ has reflection coefficient $k_1 = -r$. Stability is therefore exposed directly by $|k_1| < 1$.
How to read the result
Compare the IIR recursive state count with the FIR truncation length and the local median timing on million-sample inputs.
Run command¶
python examples/million_sample_iir_throughput.py
Source code
1"""Million-sample throughput: long acoustic-like tail as a low-order IIR.
2
3This example shows the speed motivation for recursive models. A long FIR tail
4can represent an acoustic or decay response by storing many taps. When the tail
5has a compact recursive description, an IIR/lattice filter can process millions
6of samples with a small fixed state instead of carrying the full tail length.
7
8The example uses a first-order stable IIR whose impulse response is an exponential
9reverberant-like decay,
10
11 h[m] = (1 - r) r**m, 0 < r < 1.
12
13The equivalent FIR approximation truncates that tail to many taps. The timing
14comparison is intentionally local to your machine and NumPy build; it is a
15reproducibility aid, not a universal benchmark number.
16"""
17
18from __future__ import annotations
19
20import argparse
21import csv
22import math
23import os
24import statistics
25import time
26from pathlib import Path
27
28import numpy as np
29
30import lattice_dsp as ld
31
32
def artifact_dir() -> Path:
    """Return the directory that receives example artifacts, creating it if needed.

    The location is read from the ``LATTICE_DSP_ARTIFACT_DIR`` environment
    variable. If it is unset, ``reports/example-artifacts`` is used. Missing
    parent directories are created, and an existing directory is accepted.
    """
    configured = os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts")
    target = Path(configured)
    target.mkdir(exist_ok=True, parents=True)
    return target
37
38
def median_time(fn, repeats: int) -> tuple[float, np.ndarray]:
    """Time ``fn`` repeatedly and return ``(median_seconds, last_result)``.

    ``fn`` is called ``max(1, repeats)`` times, so it always runs at least once.
    Each call, including the conversion of its return value to a float64 NumPy
    array, is timed with ``time.perf_counter``. The array produced by the final
    call is returned together with the median of the measured durations.
    """
    run_count = max(1, repeats)
    durations: list[float] = []
    last_output: np.ndarray | None = None
    for _ in range(run_count):
        started = time.perf_counter()
        last_output = np.asarray(fn(), dtype=np.float64)
        elapsed = time.perf_counter() - started
        durations.append(elapsed)
    # run_count >= 1, so the loop body executed at least once.
    assert last_output is not None
    return statistics.median(durations), last_output
48
49
def next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is >= ``n``.

    Inputs of 1 or less (including zero and negatives) map to 1.
    """
    return (1 << (n - 1).bit_length()) if n > 1 else 1
54
55
def fft_convolve_truncated(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Linear convolution of ``x`` with ``h`` via NumPy FFTs, truncated to ``len(x)``.

    Both inputs are zero-padded to a power-of-two length that is at least the
    full linear-convolution length ``len(x) + len(h) - 1``, so the circular FFT
    product does not wrap around. Only the first ``len(x)`` output samples are
    returned. The function uses only NumPy and no SciPy.
    """
    full_length = int(x.size + h.size - 1)
    padded_length = next_power_of_two(full_length)
    x_spectrum = np.fft.rfft(x, padded_length)
    h_spectrum = np.fft.rfft(h, padded_length)
    full_output = np.fft.irfft(x_spectrum * h_spectrum, padded_length)
    return full_output[: x.size]
62
63
def relative_rms_error(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Return ``RMS(reference - estimate) / RMS(reference)`` as a Python float.

    A tiny ``1e-30`` is added to the denominator so an all-zero reference does
    not cause a division by zero.
    """

    def _rms(values):
        return np.sqrt(np.mean(values * values))

    residual = reference - estimate
    return float(_rms(residual) / (_rms(reference) + 1e-30))
67
68
def _throughput_msamples_per_s(samples: int, seconds: float) -> float:
    """Convert a sample count and a wall-clock duration to millions of samples per second."""
    # Clamp the duration so a zero timer reading cannot divide by zero.
    return samples / max(seconds, 1e-30) / 1e6


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for this example.

    Raises:
        ValueError: if ``--samples`` or ``--tail-taps`` is not positive, or if
            ``--pole`` is not strictly inside (0, 1). NaN also fails this check.
    """
    parser = argparse.ArgumentParser(description="Million-sample IIR throughput demonstration.")
    parser.add_argument("--samples", type=int, default=1_000_000, help="number of input samples")
    parser.add_argument(
        "--tail-taps", type=int, default=131_072, help="FIR taps used to truncate the IIR tail"
    )
    parser.add_argument(
        "--pole",
        type=float,
        default=0.99992,
        help="stable IIR pole radius for the exponential tail",
    )
    parser.add_argument("--repeats", type=int, default=3, help="median timing repeats")
    parser.add_argument("--seed", type=int, default=2026)
    parser.add_argument("--skip-fft", action="store_true", help="skip the FFT/FIR reference timing")
    args = parser.parse_args()

    if args.samples <= 0:
        raise ValueError("--samples must be positive")
    if args.tail_taps <= 0:
        raise ValueError("--tail-taps must be positive")
    if not (0.0 < args.pole < 1.0):
        raise ValueError("--pole must satisfy 0 < pole < 1")
    return args


def main() -> None:
    """Time a one-pole lattice IIR against a truncated-FIR FFT reference.

    The input is Gaussian white noise drawn from a seeded NumPy generator. It
    is filtered with ``ld.LatticeIIR`` using a single reflection coefficient.
    Unless ``--skip-fft`` is given, the same input is also filtered by FFT
    convolution with the exponential tail ``(1 - pole) * pole**m`` truncated
    to ``--tail-taps`` coefficients. The function prints median timings,
    speedup, and truncation error, then writes one CSV row per method to
    ``million_sample_iir_throughput.csv`` in ``artifact_dir()``.

    Raises:
        ValueError: on invalid command-line options (see ``_parse_args``).
    """
    args = _parse_args()

    rng = np.random.default_rng(args.seed)
    # Generator.normal already yields float64. np.asarray keeps the explicit
    # dtype guarantee without the extra full-length copy that .astype() makes.
    x = np.asarray(rng.normal(size=args.samples), dtype=np.float64)

    # Under the lattice convention used by lattice-dsp, a first-order reflection
    # coefficient k gives A(z) = 1 + k z^-1. The pole is therefore -k. Choosing
    # k = -pole gives the stable smoother y[n] = (1-pole) x[n] + pole y[n-1].
    pole = float(args.pole)
    reflection = [-pole]
    numerator = [1.0 - pole, 0.0]

    # A new LatticeIIR is built inside the timed callable, so every repeat uses a
    # freshly constructed filter (presumably with initial state; confirm against
    # LatticeIIR) instead of reusing one from the previous run.
    iir_time, y_iir = median_time(
        lambda: ld.LatticeIIR(reflection, numerator).process(x), args.repeats
    )
    iir_throughput = _throughput_msamples_per_s(args.samples, iir_time)

    rows: list[dict[str, object]] = [
        {
            "method": "lattice_iir_recursive",
            "samples": args.samples,
            "state_or_taps": len(reflection),
            "median_seconds": iir_time,
            "throughput_msamples_per_s": iir_throughput,
            "relative_rms_error_vs_iir": 0.0,
        }
    ]

    print("million-sample IIR throughput demonstration")
    print("=" * 52)
    print(f"samples: {args.samples:,}")
    print(f"stable pole radius: {args.pole:.8f}")
    print(f"IIR reflection coefficient: {reflection[0]:.8f}")
    print(f"IIR recursive state count: {len(reflection)}")
    print(f"IIR median time: {iir_time:.6f} s")
    print(f"IIR throughput: {iir_throughput:.2f} million samples/s")

    if not args.skip_fft:
        # Finite FIR approximation to the same infinite exponential tail. This is
        # not meant as the slowest possible comparison: FFT convolution is already
        # a strong baseline for long FIR filtering. The point is that the IIR has
        # constant state for this structured tail, while the FIR path stores and
        # transforms many coefficients.
        taps = (1.0 - pole) * np.power(pole, np.arange(args.tail_taps, dtype=np.float64))
        fft_time, y_fir = median_time(lambda: fft_convolve_truncated(x, taps), args.repeats)
        rel_err = relative_rms_error(y_iir, y_fir)
        # Same padded length that fft_convolve_truncated uses internally.
        n_fft = next_power_of_two(args.samples + args.tail_taps - 1)
        speedup = fft_time / max(iir_time, 1e-30)
        rows.append(
            {
                "method": "truncated_fir_fft_convolution",
                "samples": args.samples,
                "state_or_taps": args.tail_taps,
                "median_seconds": fft_time,
                "throughput_msamples_per_s": _throughput_msamples_per_s(args.samples, fft_time),
                "relative_rms_error_vs_iir": rel_err,
                "fft_length": n_fft,
                "speedup_iir_vs_fft_fir": speedup,
            }
        )
        print()
        print(f"FIR truncation taps: {args.tail_taps:,}")
        print(f"FFT length for FIR reference: {n_fft:,}")
        print(f"FFT/FIR median time: {fft_time:.6f} s")
        print(f"IIR speedup over FFT/FIR reference: {speedup:.2f}x")
        print(f"relative RMS error from truncating the infinite IIR tail: {rel_err:.3e}")
        # pole**tail_taps equals h[tail_taps] / h[0], the relative size of the first dropped tap.
        print(
            f"omitted tail amplitude after {args.tail_taps:,} taps: {math.pow(args.pole, args.tail_taps):.3e}"
        )

    out_dir = artifact_dir()
    csv_path = out_dir / "million_sample_iir_throughput.csv"
    # Take the union of keys across rows. The IIR row has no FFT-only columns,
    # and DictWriter leaves those cells empty.
    fieldnames = sorted({key for row in rows for key in row})
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print()
    print(f"wrote {csv_path}")


if __name__ == "__main__":
    main()