Large echo-scale recursive model stress

Tutorial goal

Run a million-sample signal through a high-order stable lattice-ladder model and compare its per-sample work with that of a long FIR echo tap vector.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

Echo-style paths often have long memory: delay, early reflections, and a slowly decaying room tail. FIR adaptive echo cancellers model that memory by adding taps, so a long path becomes a large parameter vector that is filtered and updated at every sample. This example keeps adaptation out of scope and instead stresses the fixed-filter processing axis: a million-sample input and a high-order stable recursive lattice-ladder model.

Key idea and equations

A direct FIR echo model with L taps has the filtering relation

\[y[n] = \sum_{m=0}^{L-1} h[m] x[n-m],\]

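and an LMS-style update touches the same large tap vector again. With h_n the tap vector at time n, x_n the vector of the L most recent input samples, e[n] the error sample, and μ the step size,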

\[h_{n+1} = h_n + \mu e[n] x_n.\]
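The two relations above can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the example script: the helper names `fir_filter_direct` and `lms_step`, the four-tap path, and the step size are all assumptions chosen for the demo.

```python
import numpy as np


def fir_filter_direct(h, x):
    """Direct-form FIR: y[n] = sum_m h[m] x[n-m], zero initial conditions."""
    L = len(h)
    y = np.zeros(len(x))
    for n in range(len(x)):
        m_max = min(L, n + 1)
        # x[n::-1] is x[n], x[n-1], ..., x[0]; keep only the taps that exist.
        y[n] = np.dot(h[:m_max], x[n::-1][:m_max])
    return y


def lms_step(h, x_vec, d, mu):
    """One LMS update: h_{n+1} = h_n + mu * e[n] * x_n (hypothetical helper)."""
    e = d - np.dot(h, x_vec)
    return h + mu * e * x_vec, e


rng = np.random.default_rng(0)
h_true = np.array([0.8, -0.4, 0.2, 0.1])  # assumed toy echo path
x = rng.normal(size=2000)
d = fir_filter_direct(h_true, x)  # noiseless desired signal

L = len(h_true)
x_padded = np.concatenate([np.zeros(L - 1), x])
h_est = np.zeros(L)
for n in range(len(x)):
    # Regressor x_n = [x[n], x[n-1], ..., x[n-L+1]] with zeros before n = 0.
    x_vec = x_padded[n : n + L][::-1]
    h_est, _ = lms_step(h_est, x_vec, d[n], mu=0.05)

print(np.round(h_est, 4))
```

Every sample visits all L taps once for filtering and once more for the update, which is the source of the factor of two used in the LMS scale estimate.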

A lattice-ladder IIR model stores recursive state and stage parameters. For the scalar all-pole section, stability is guaranteed by keeping every reflection coefficient strictly bounded in magnitude,

\[|k_i| < 1.\]
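The bound can be checked numerically by converting reflection coefficients to a denominator polynomial with the standard step-up recursion and inspecting the pole magnitudes. The sketch below is an illustration only: `reflection_to_denominator` is a hypothetical helper, and its sign convention is one common textbook choice that may differ from the one used inside `lattice_dsp`. The stability conclusion does not depend on that sign.

```python
import numpy as np


def reflection_to_denominator(k):
    """Step-up recursion: a_m[j] = a_{m-1}[j] + k_m * a_{m-1}[m-j]."""
    a = np.array([1.0])
    for ki in k:
        a_ext = np.concatenate([a, [0.0]])
        a = a_ext + ki * a_ext[::-1]
    return a


def max_pole_magnitude(k):
    """Largest pole magnitude of 1/A(z) built from reflection coefficients k."""
    a = reflection_to_denominator(k)
    return float(np.max(np.abs(np.roots(a))))


k_inside = [0.5, -0.3, 0.2]  # all |k_i| < 1
k_outside = [0.5, 1.2]       # last coefficient violates the bound

print(reflection_to_denominator(k_inside))
print(max_pole_magnitude(k_inside) < 1.0)
print(max_pole_magnitude(k_outside) > 1.0)
```

The second case fails because the product of the pole magnitudes equals the magnitude of the last reflection coefficient, so at least one pole must lie outside the unit circle.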

The comparison is a scale diagnostic: for N input samples, a direct FIR filter makes N × L tap visits, while a lattice-ladder model of recursive order p makes N × p stage visits. It is not an accuracy-equivalence claim.
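To make the stage-visit count concrete, here is a minimal pure-NumPy sketch of one common textbook lattice-ladder recursion that counts its own stage visits. The function name `lattice_ladder_filter` and the sign convention are assumptions for illustration. This is not the implementation behind `ld.LatticeLadderIIR`, which may order or sign its stages differently.

```python
import numpy as np


def lattice_ladder_filter(x, k, v):
    """Filter x with reflection coefficients k (length p) and ladder taps v (length p + 1).

    Per sample, with f_p[n] = x[n] and m = p, ..., 1:
        f_{m-1}[n] = f_m[n] - k_m * g_{m-1}[n-1]
        g_m[n]     = k_m * f_{m-1}[n] + g_{m-1}[n-1]
    then g_0[n] = f_0[n] and y[n] = sum_m v_m * g_m[n].
    """
    p = len(k)
    g_prev = np.zeros(p + 1)  # backward signals g_m[n-1], m = 0..p
    y = np.empty(len(x))
    stage_visits = 0
    for n, xn in enumerate(x):
        f = xn
        g = np.empty(p + 1)
        for m in range(p, 0, -1):
            f = f - k[m - 1] * g_prev[m - 1]
            g[m] = k[m - 1] * f + g_prev[m - 1]
            stage_visits += 1
        g[0] = f
        y[n] = np.dot(v, g)
        g_prev = g
    return y, stage_visits


rng = np.random.default_rng(1)
x_demo = rng.normal(size=200)
k_demo = np.array([0.5, -0.3, 0.2])
v_allpole = np.array([1.0, 0.0, 0.0, 0.0])  # ladder passes only g_0, the all-pole output

y_allpole, visits = lattice_ladder_filter(x_demo, k_demo, v_allpole)
print(visits)  # samples * order
```

The inner loop runs p times per sample regardless of how long the resulting impulse response is, which is why the work scales with N × p rather than with the length of the echo tail.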

How to read the result

Compare the local lattice-ladder timing with the printed FIR echo-scale tap-visit estimates, especially the FIR taps / lattice order ratio.
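For orientation, the scale numbers printed with the script's default arguments (`--samples 1000000`, `--order 512`, `--echo-taps 131072`) can be reproduced with plain arithmetic. The variable names below mirror those in the script. The snippet only recomputes the counts and does no timing.

```python
samples = 1_000_000
order = 512
echo_taps = 131_072

lattice_stage_visits = samples * order
fir_filter_tap_visits = samples * echo_taps
fir_lms_tap_visits = 2 * fir_filter_tap_visits  # filter pass plus update pass
tap_to_order_ratio = echo_taps / order

print(f"FIR taps / lattice order: {tap_to_order_ratio:.1f}x")
print(f"lattice stage visits: {lattice_stage_visits:,}")
print(f"FIR filter tap visits, direct form: {fir_filter_tap_visits:,}")
print(f"FIR LMS filter+update tap visits, rough scale: {fir_lms_tap_visits:,}")
```

With these defaults the ratio is 256x, so the direct FIR estimate is 256 times the lattice stage-visit count and the LMS estimate is 512 times it.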

Run command

python examples/large_order_echo_stress.py
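
After a run, the script writes `large_order_echo_stress.csv` under `reports/example-artifacts`, or under the directory named by `LATTICE_DSP_ARTIFACT_DIR` if that variable is set. Rows for different methods carry different columns, so blank cells are expected. A minimal sketch for reading the file back follows. `load_rows` is a hypothetical helper and is not part of the example script.

```python
import csv
from pathlib import Path


def load_rows(csv_path):
    """Read the artifact CSV and drop the blank cells left by sparse rows."""
    with Path(csv_path).open(newline="", encoding="utf-8") as f:
        return [
            {key: value for key, value in row.items() if value != ""}
            for row in csv.DictReader(f)
        ]


# Default location used by the script when LATTICE_DSP_ARTIFACT_DIR is unset.
default_path = Path("reports/example-artifacts") / "large_order_echo_stress.csv"
if default_path.exists():
    for row in load_rows(default_path):
        print(row["method"], row.get("median_seconds", "n/a"))
```

All values come back as strings, so convert fields such as `median_seconds` with `float()` before doing arithmetic on them.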

Source code

"""Large echo-scale stress: long signals and high-order stable recursive models.

This example complements ``million_sample_iir_throughput.py``.  The first
throughput example shows a very long tail that collapses to one recursive state.
Here we stress the other axis that matters in echo-style work: a long signal and
a much larger stable recursive order.

The example is intentionally a fixed-filter throughput and scale diagnostic, not
an adaptive echo canceller.  Classical FIR echo cancellation often needs a very
large tap vector to cover delay, early reflections, and a reverberant tail.  An
LMS-style FIR pass touches that large vector at every sample, and adaptation can
roughly double the coefficient traffic because it both filters and updates taps.

A stable lattice/lattice-ladder IIR model has a different cost profile.  It keeps
recursive state and enforces scalar all-pole stability through bounded reflection
coefficients.  This does not remove step-size tuning or identification difficulty
in adaptive systems, but it shows why compact recursive models are attractive
when the physical path has long memory.
"""

from __future__ import annotations

import argparse
import csv
import os
import statistics
import time
from pathlib import Path

import numpy as np

import lattice_dsp as ld


def artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def median_time(fn, repeats: int) -> tuple[float, np.ndarray]:
    times: list[float] = []
    result: np.ndarray | None = None
    for _ in range(max(1, repeats)):
        t0 = time.perf_counter()
        result = np.asarray(fn(), dtype=np.float64)
        times.append(time.perf_counter() - t0)
    assert result is not None
    return statistics.median(times), result


def synthetic_speech_like_signal(samples: int, rng: np.random.Generator) -> np.ndarray:
    """Create a deterministic, dependency-free speech/noise-like input."""
    white = rng.normal(size=samples).astype(np.float64)
    # A tiny AR coloring loop avoids external audio dependencies while making the
    # input less white than a pure RNG sequence.
    x = np.empty_like(white)
    s1 = 0.0
    s2 = 0.0
    for n, v in enumerate(white):
        s1 = 0.92 * s1 + 0.08 * v
        s2 = 0.65 * s2 + 0.35 * s1
        x[n] = s2
    rms = float(np.sqrt(np.mean(x * x)))
    return x / max(rms, 1e-30)


def stable_echo_lattice(
    order: int, max_reflection: float, seed: int
) -> tuple[np.ndarray, np.ndarray]:
    """Build a stable high-order lattice-ladder echo-like model.

    The reflection coefficients are deliberately bounded well inside the unit
    disk.  Ladder taps decay with stage number so the model behaves like a large
    but controlled recursive echo tail instead of an arbitrary unstable IIR.
    """
    rng = np.random.default_rng(seed)
    stages = np.arange(1, order + 1, dtype=np.float64)
    signs = rng.choice(np.array([-1.0, 1.0]), size=order)
    smooth = np.sin(0.071 * stages) + 0.35 * np.sin(0.019 * stages + 0.4)
    envelope = np.exp(-stages / max(order / 3.0, 1.0))
    reflection = max_reflection * envelope * smooth / max(1.0, np.max(np.abs(smooth)))
    reflection *= signs
    reflection = np.clip(reflection, -0.98, 0.98).astype(np.float64)

    ladder_stages = np.arange(order + 1, dtype=np.float64)
    ladder = rng.normal(scale=1.0, size=order + 1) * np.exp(-ladder_stages / max(order / 7.0, 1.0))
    ladder[0] += 1.0
    ladder /= max(np.linalg.norm(ladder), 1e-30)
    ladder *= 0.35
    return reflection, ladder.astype(np.float64)


def maybe_time_fft_tail(x: np.ndarray, echo_taps: int, repeats: int) -> tuple[float, np.ndarray]:
    """Optional FFT/FIR echo-tail timing used only when requested."""
    tail_index = np.arange(echo_taps, dtype=np.float64)
    decay = np.exp(-tail_index / max(echo_taps / 8.0, 1.0))
    h = decay / max(np.linalg.norm(decay), 1e-30)
    n_out = x.size + h.size - 1
    n_fft = 1 << (n_out - 1).bit_length()

    def run() -> np.ndarray:
        spectrum = np.fft.rfft(x, n_fft) * np.fft.rfft(h, n_fft)
        return np.fft.irfft(spectrum, n_fft)[: x.size]

    return median_time(run, repeats)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Large echo-scale stable recursive model stress test."
    )
    parser.add_argument("--samples", type=int, default=1_000_000, help="number of input samples")
    parser.add_argument(
        "--order", type=int, default=512, help="stable lattice-ladder recursive order"
    )
    parser.add_argument(
        "--echo-taps",
        type=int,
        default=131_072,
        help="reference FIR echo-tap count for scale estimates",
    )
    parser.add_argument(
        "--max-reflection", type=float, default=0.45, help="maximum reflection magnitude envelope"
    )
    parser.add_argument("--repeats", type=int, default=3, help="median timing repeats")
    parser.add_argument("--seed", type=int, default=2027)
    parser.add_argument(
        "--time-fft-tail",
        action="store_true",
        help="also time a dependency-free FFT/FIR tail reference",
    )
    args = parser.parse_args()

    if args.samples <= 0:
        raise ValueError("--samples must be positive")
    if args.order <= 0:
        raise ValueError("--order must be positive")
    if args.echo_taps <= 0:
        raise ValueError("--echo-taps must be positive")
    if not (0.0 < args.max_reflection < 1.0):
        raise ValueError("--max-reflection must satisfy 0 < max_reflection < 1")

    rng = np.random.default_rng(args.seed)
    x = synthetic_speech_like_signal(args.samples, rng)
    reflection, ladder = stable_echo_lattice(args.order, args.max_reflection, args.seed + 1)

    def run_lattice() -> np.ndarray:
        filt = ld.LatticeLadderIIR(reflection.tolist(), ladder.tolist())
        return filt.process(x)

    iir_time, y = median_time(run_lattice, args.repeats)
    throughput = args.samples / max(iir_time, 1e-30) / 1e6
    stage_rate = args.samples * args.order / max(iir_time, 1e-30) / 1e9
    lattice_stage_visits = args.samples * args.order
    fir_filter_tap_visits = args.samples * args.echo_taps
    fir_lms_tap_visits = 2 * fir_filter_tap_visits
    tap_to_order_ratio = args.echo_taps / args.order

    rows: list[dict[str, object]] = [
        {
            "method": "lattice_ladder_iir_fixed_model",
            "samples": args.samples,
            "order_or_taps": args.order,
            "median_seconds": iir_time,
            "throughput_msamples_per_s": throughput,
            "stage_visits_giga_per_s": stage_rate,
            "output_rms": float(np.sqrt(np.mean(y * y))),
            "max_abs_reflection": float(np.max(np.abs(reflection))),
        },
        {
            "method": "fir_echo_scale_estimate_filter_only",
            "samples": args.samples,
            "order_or_taps": args.echo_taps,
            "coefficient_touches": fir_filter_tap_visits,
            "tap_to_lattice_order_ratio": tap_to_order_ratio,
        },
        {
            "method": "fir_lms_scale_estimate_filter_plus_update",
            "samples": args.samples,
            "order_or_taps": args.echo_taps,
            "coefficient_touches": fir_lms_tap_visits,
            "tap_to_lattice_order_ratio": tap_to_order_ratio,
        },
    ]

    print("large echo-scale stable recursive model stress")
    print("=" * 56)
    print(f"samples: {args.samples:,}")
    print(f"lattice-ladder recursive order: {args.order:,}")
    print(f"max |reflection|: {np.max(np.abs(reflection)):.6f}")
    print(f"recursive state count: {args.order:,}")
    print(f"ladder parameter count: {args.order + 1:,}")
    print(f"median IIR/lattice-ladder time: {iir_time:.6f} s")
    print(f"throughput: {throughput:.2f} million samples/s")
    print(f"stage update rate: {stage_rate:.2f} billion stage-visits/s")
    print(f"output RMS: {np.sqrt(np.mean(y * y)):.6f}")
    print()
    print("echo-scale comparison numbers")
    print("-" * 56)
    print(f"reference FIR echo taps: {args.echo_taps:,}")
    print(f"FIR taps / lattice order: {tap_to_order_ratio:.1f}x")
    print(f"lattice stage visits: {lattice_stage_visits:,}")
    print(f"FIR filter tap visits, direct form: {fir_filter_tap_visits:,}")
    print(f"FIR LMS filter+update tap visits, rough scale: {fir_lms_tap_visits:,}")
    print("note: the tap-visit numbers are scale diagnostics, not an accuracy equivalence claim")

    if args.time_fft_tail:
        fft_time, y_fft = maybe_time_fft_tail(x, args.echo_taps, args.repeats)
        rows.append(
            {
                "method": "optional_fft_fir_tail_reference",
                "samples": args.samples,
                "order_or_taps": args.echo_taps,
                "median_seconds": fft_time,
                "throughput_msamples_per_s": args.samples / max(fft_time, 1e-30) / 1e6,
                "output_rms": float(np.sqrt(np.mean(y_fft * y_fft))),
                "speedup_iir_vs_fft_tail": fft_time / max(iir_time, 1e-30),
            }
        )
        print()
        print("optional FFT/FIR reference")
        print("-" * 56)
        print(f"FFT/FIR median time: {fft_time:.6f} s")
        print(f"IIR/lattice speedup over FFT/FIR tail: {fft_time / max(iir_time, 1e-30):.2f}x")
        print(f"FFT/FIR output RMS: {np.sqrt(np.mean(y_fft * y_fft)):.6f}")

    out_dir = artifact_dir()
    csv_path = out_dir / "large_order_echo_stress.csv"
    fieldnames = sorted({key for row in rows for key in row})
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print()
    print(f"wrote {csv_path}")


if __name__ == "__main__":
    main()