MIMO long-signal state-space stress

Tutorial goal

Reduce a coupled MIMO state-space model with the finite block-Hankel workflow, then process long batched multichannel signals through the compiled C++ runtime.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

This is the multichannel counterpart to the scalar long-signal stress examples. A coupled MIMO system is converted to Markov parameters, reduced with finite_hankel_reduce_mimo, and then reused on a long batched input through mimo_state_space_process_batch. The printed comparison numbers are scale diagnostics for MIMO echo-style paths. They do not claim accuracy equivalence with a long FIR model, and the reducer is not presented as a matrix-valued AAK/Nehari solver.

Key idea and equations

A MIMO state-space model uses

\[x_s[n+1] = A x_s[n] + B u[n], \qquad y[n] = C x_s[n] + D u[n].\]

Its Markov matrices M_k map input channels to output channels at lag k. The finite MIMO block-Hankel reducer builds a matrix whose blocks are these Markov matrices and returns a lower-order state-space realization. For comparison, a direct MIMO FIR echo model with L taps per input-output path needs a multiply-add visit count per stream on the order of

\[N \, L \, m \, p,\]

for N samples, m inputs, and p outputs. The dense state-space runtime instead needs about N (r^2 + r m + p r + p m) visits per stream for state order r, so its cost is tied to the chosen state order rather than to the FIR tap count.
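To make the Markov matrices and the block-Hankel matrix above concrete, here is a minimal NumPy sketch. It assumes the common convention M_0 = D and M_k = C A^(k-1) B, and it uses a textbook Ho-Kalman/Kung-style SVD truncation as a stand-in for the reduction step. The helper names markov_parameters, block_hankel, and svd_realization are hypothetical and are not part of lattice_dsp; the library's finite_hankel_reduce_mimo may use a different indexing, balancing, or stabilization strategy.

```python
import numpy as np


def markov_parameters(A, B, C, D, count):
    """Return M_0..M_{count-1}, assuming M_0 = D and M_k = C A^(k-1) B."""
    M = np.empty((count,) + D.shape)
    M[0] = D
    AkB = B.copy()  # holds A^(k-1) B at the top of each iteration
    for k in range(1, count):
        M[k] = C @ AkB
        AkB = A @ AkB
    return M


def block_hankel(M, block_rows, block_cols):
    """Place M_{i+j+1} in block (i, j); the feedthrough M_0 is left out."""
    _, p, m = M.shape
    H = np.empty((block_rows * p, block_cols * m))
    for i in range(block_rows):
        for j in range(block_cols):
            H[i * p:(i + 1) * p, j * m:(j + 1) * m] = M[i + j + 1]
    return H


def svd_realization(M, order, block_rows, block_cols):
    """Textbook SVD truncation of the block-Hankel matrix (illustrative only)."""
    _, p, m = M.shape
    H = block_hankel(M, block_rows, block_cols)
    U, s, Vt = np.linalg.svd(H, full_matrices=False)
    root = np.sqrt(s[:order])
    obs = U[:, :order] * root          # extended observability factor
    ctr = root[:, None] * Vt[:order]   # extended controllability factor
    # Shift invariance: obs[p:] = obs[:-p] @ Ar, solved in the least-squares sense.
    Ar = np.linalg.pinv(obs[:-p]) @ obs[p:]
    return Ar, ctr[:, :m], obs[:p], M[0]


# Toy 2-input, 2-output system of order 3 (values chosen only for illustration).
rng = np.random.default_rng(0)
A = np.diag([0.9, -0.5, 0.3])
B = rng.normal(size=(3, 2))
C = rng.normal(size=(2, 3))
D = rng.normal(size=(2, 2))

M = markov_parameters(A, B, C, D, 20)
H = block_hankel(M, 6, 6)
Ar, Br, Cr, Dr = svd_realization(M, 3, 6, 6)
Mr = markov_parameters(Ar, Br, Cr, Dr, 20)
err = np.max(np.abs(M - Mr))
print("block-Hankel shape:", H.shape)
print("max Markov mismatch at full order:", err)
```

Keeping the full order in this toy case should reproduce the Markov sequence up to rounding. Choosing a smaller order discards the trailing singular values of the block-Hankel matrix, which is the source of the approximation error that the tutorial reports.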

How to read the result

Inspect the finite block-Hankel reduction time, retained energy, reduced runtime, output-channel throughput, and the printed direct MIMO FIR tap-visit scale.
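The printed visit counts are plain arithmetic on the problem dimensions, so they can be checked by hand. The sketch below mirrors the two counting helpers from the example script (under shortened names) and evaluates them for the default 8-by-8 settings. The counts are rough multiply-add tallies, not measured operation counts.

```python
def dense_state_space_visits(batch, samples, order, inputs, outputs):
    """Multiply-add visits for A x, B u, C x, and D u at every sample."""
    per_sample = order * order + order * inputs + outputs * order + outputs * inputs
    return batch * samples * per_sample


def direct_fir_visits(batch, samples, taps, inputs, outputs):
    """Multiply-add visits for one FIR of length `taps` per input-output path."""
    return batch * samples * taps * inputs * outputs


# Default settings of the stress example: 1 stream, 1,000,000 samples, 8 x 8 channels.
reduced = dense_state_space_visits(1, 1_000_000, 16, 8, 8)
full = dense_state_space_visits(1, 1_000_000, 64, 8, 8)
fir = direct_fir_visits(1, 1_000_000, 131_072, 8, 8)

print(f"reduced dense visits: {reduced:,}")   # 576,000,000
print(f"full dense visits:    {full:,}")      # 5,184,000,000
print(f"direct FIR visits:    {fir:,}")       # 8,388,608,000,000
print(f"FIR / reduced:        {fir / reduced:,.0f}x")
```

With these defaults the direct FIR count is roughly four orders of magnitude above the reduced state-space count, while the full order-64 model costs nine times as many visits as the order-16 model.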

Run command

python examples/mimo_long_signal_state_space_stress.py

Representative local output

The following output is from one local run of the default 8-by-8 stress command. Exact timings are machine-dependent, but the scale relationship is the point: the finite MIMO reduction produced a stable order-16 model that processed one million multichannel samples with 576 million dense visits, while the reference direct MIMO FIR tap-visit count (about 8.4 trillion) was roughly four orders of magnitude larger.

MIMO long-signal finite-Hankel/state-space stress
================================================================
batch streams: 1
samples per stream: 1,000,000
inputs x outputs: 8 x 8
full MIMO state order: 64
dominant full-model pole radius target: 0.985000
full-model spectral radius: 0.985000
Markov samples for reduction: 320
block-Hankel matrix: 192 x 192
reduced order: 16
Markov generation time: 0.004333 s
finite MIMO block-Hankel reduction time: 0.332755 s
retained Hankel energy: 0.999787
relative Markov error: 1.963e-03
reduced model stable: True
reduced spectral radius: 0.983021

compiled reduced MIMO runtime
----------------------------------------------------------------
median reduced state-space time: 0.151889 s
throughput: 6.58 million multichannel samples/s
output-channel throughput: 52.67 million output samples/s
dense reduced state-space visits: 576,000,000
dense visit rate: 3.79 billion visits/s
output RMS: 1.914513

MIMO echo-scale comparison numbers
----------------------------------------------------------------
reference MIMO FIR taps per input-output path: 131,072
FIR taps / reduced state order: 8192.0x
full dense state-space visits at same signal size: 5,184,000,000
reduced dense state-space visits: 576,000,000
direct MIMO FIR filter visits: 8,388,608,000,000
direct MIMO FIR LMS filter+update visits, rough scale: 16,777,216,000,000
note: these are scale diagnostics, not an accuracy equivalence claim

Run status

Return code: 0

Captured stdout

MIMO long-signal finite-Hankel/state-space stress
================================================================
batch streams: 1
samples per stream: 1,000,000
inputs x outputs: 8 x 8
full MIMO state order: 64
dominant full-model pole radius target: 0.985000
full-model spectral radius: 0.985000
Markov samples for reduction: 320
block-Hankel matrix: 192 x 192
reduced order: 16
Markov generation time: 0.005340 s
finite MIMO block-Hankel reduction time: 0.442447 s
retained Hankel energy: 0.999787
relative Markov error: 1.963e-03
reduced model stable: True
reduced spectral radius: 0.983021

compiled reduced MIMO runtime
----------------------------------------------------------------
median reduced state-space time: 0.307337 s
throughput: 3.25 million multichannel samples/s
output-channel throughput: 26.03 million output samples/s
dense reduced state-space visits: 576,000,000
dense visit rate: 1.87 billion visits/s
output RMS: 1.914513

MIMO echo-scale comparison numbers
----------------------------------------------------------------
reference MIMO FIR taps per input-output path: 131,072
FIR taps / reduced state order: 8192.0x
full dense state-space visits at same signal size: 5,184,000,000
reduced dense state-space visits: 576,000,000
direct MIMO FIR filter visits: 8,388,608,000,000
direct MIMO FIR LMS filter+update visits, rough scale: 16,777,216,000,000
note: these are scale diagnostics, not an accuracy equivalence claim

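Generated data files

The script writes mimo_long_signal_state_space_stress.csv to the directory named by the LATTICE_DSP_ARTIFACT_DIR environment variable (default: reports/example-artifacts).
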
Source code

"""MIMO long-signal stress for finite block-Hankel reduction and compiled state-space runtime.

This example complements the scalar long-signal and echo-scale stress tutorials.
It focuses on the multichannel niche: a coupled MIMO system is reduced with the
finite block-Hankel workflow, then the reduced state-space model is reused on a
long batched multichannel signal through the compiled C++ runtime.

The comparison numbers are deliberately scale diagnostics.  They show the cost
profile of a dense recursive MIMO state-space model against a long direct-form
MIMO FIR echo model.  They are not an accuracy-equivalence claim and they are not
a matrix-valued AAK/Nehari solver claim.
"""

from __future__ import annotations

import argparse
import csv
import os
import statistics
import time
from pathlib import Path

import numpy as np

import lattice_dsp as ld


def artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def median_time(fn, repeats: int) -> tuple[float, object]:
    times: list[float] = []
    result: object | None = None
    for _ in range(max(1, repeats)):
        t0 = time.perf_counter()
        result = fn()
        times.append(time.perf_counter() - t0)
    assert result is not None
    return statistics.median(times), result


def stable_coupled_state_space(
    order: int,
    outputs: int,
    inputs: int,
    radius: float,
    seed: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Construct a deterministic stable dense MIMO state-space model."""
    rng = np.random.default_rng(seed)

    # A random orthogonal similarity keeps the system coupled while making the
    # spectral radius explicit.  The descending radii create a long but stable
    # memory profile.
    q, _ = np.linalg.qr(rng.normal(size=(order, order)))
    radii = np.linspace(radius, 0.15, order, dtype=np.float64)
    A = q @ np.diag(radii) @ q.T

    # Scale B and C by sqrt(order) so the output RMS remains controlled when the
    # state dimension changes.
    scale = max(np.sqrt(order), 1.0)
    B = rng.normal(size=(order, inputs)) / scale
    C = rng.normal(size=(outputs, order)) / scale
    D = 0.015 * rng.normal(size=(outputs, inputs))
    return A.astype(np.float64), B.astype(np.float64), C.astype(np.float64), D.astype(np.float64)


def colored_multichannel_input(
    batch: int,
    samples: int,
    inputs: int,
    seed: int,
) -> np.ndarray:
    """Generate dependency-free multichannel input with light temporal coloring."""
    rng = np.random.default_rng(seed)
    white = rng.normal(size=(batch, samples, inputs)).astype(np.float64)
    x = np.empty_like(white)
    state = np.zeros((batch, inputs), dtype=np.float64)
    for n in range(samples):
        state = 0.88 * state + 0.12 * white[:, n, :]
        x[:, n, :] = state
    rms = np.sqrt(np.mean(x * x, axis=(1, 2), keepdims=True))
    return x / np.maximum(rms, 1e-30)


def dense_state_space_visit_scale(
    batch: int,
    samples: int,
    order: int,
    inputs: int,
    outputs: int,
) -> int:
    """Approximate dense multiply-add visits for a batched MIMO state-space pass."""
    per_sample = order * order + order * inputs + outputs * order + outputs * inputs
    return int(batch) * int(samples) * int(per_sample)


def direct_mimo_fir_visit_scale(
    batch: int,
    samples: int,
    taps: int,
    inputs: int,
    outputs: int,
) -> int:
    """Approximate direct-form MIMO FIR multiply-add visits."""
    return int(batch) * int(samples) * int(taps) * int(inputs) * int(outputs)


def spectral_radius(A: np.ndarray) -> float:
    if A.size == 0:
        return 0.0
    return float(np.max(np.abs(np.linalg.eigvals(A))))


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--samples", type=int, default=1_000_000, help="number of time samples per stream"
    )
    parser.add_argument(
        "--batch", type=int, default=1, help="number of independent multichannel streams"
    )
    parser.add_argument("--inputs", type=int, default=8, help="input channels")
    parser.add_argument("--outputs", type=int, default=8, help="output channels")
    parser.add_argument(
        "--full-order",
        type=int,
        default=64,
        help="order of the synthetic full MIMO state-space model",
    )
    parser.add_argument(
        "--reduced-order",
        type=int,
        default=16,
        help="finite block-Hankel reduced order used for long-signal processing",
    )
    parser.add_argument(
        "--markov-samples",
        type=int,
        default=320,
        help="number of MIMO Markov parameters used for reduction",
    )
    parser.add_argument(
        "--block-rows", type=int, default=24, help="block rows in the finite block-Hankel matrix"
    )
    parser.add_argument(
        "--block-cols", type=int, default=24, help="block columns in the finite block-Hankel matrix"
    )
    parser.add_argument(
        "--fir-taps",
        type=int,
        default=131_072,
        help="reference MIMO FIR echo-tap count for scale estimates",
    )
    parser.add_argument(
        "--radius",
        type=float,
        default=0.985,
        help="dominant stable pole radius of the generated full model",
    )
    parser.add_argument("--repeats", type=int, default=3, help="median timing repeats")
    parser.add_argument("--seed", type=int, default=2028)
    parser.add_argument(
        "--n-threads",
        type=int,
        default=0,
        help="threads for the compiled state-space runtime; 0 uses backend default",
    )
    parser.add_argument(
        "--time-full",
        action="store_true",
        help="also time the unreduced full-order state-space model",
    )
    args = parser.parse_args()

    if args.samples <= 0:
        raise ValueError("--samples must be positive")
    if args.batch <= 0:
        raise ValueError("--batch must be positive")
    if args.inputs <= 0 or args.outputs <= 0:
        raise ValueError("--inputs and --outputs must be positive")
    if args.full_order <= 0 or args.reduced_order <= 0:
        raise ValueError("--full-order and --reduced-order must be positive")
    if args.reduced_order > args.full_order:
        raise ValueError("--reduced-order must not exceed --full-order")
    if args.markov_samples <= args.block_rows + args.block_cols:
        raise ValueError("--markov-samples must exceed --block-rows + --block-cols")
    if args.fir_taps <= 0:
        raise ValueError("--fir-taps must be positive")
    if not (0.0 < args.radius < 1.0):
        raise ValueError("--radius must satisfy 0 < radius < 1")

    A, B, C, D = stable_coupled_state_space(
        args.full_order,
        args.outputs,
        args.inputs,
        args.radius,
        args.seed,
    )

    markov_time, markov_obj = median_time(
        lambda: ld.mimo_state_space_markov_response(A, B, C, D, args.markov_samples),
        1,
    )
    markov = np.asarray(markov_obj, dtype=np.float64)

    reduce_time, result_obj = median_time(
        lambda: ld.finite_hankel_reduce_mimo(
            markov,
            reduced_order=args.reduced_order,
            block_rows=args.block_rows,
            block_cols=args.block_cols,
        ),
        1,
    )
    result = dict(result_obj)
    Ar = np.asarray(result["A"], dtype=np.float64)
    Br = np.asarray(result["B"], dtype=np.float64)
    Cr = np.asarray(result["C"], dtype=np.float64)
    Dr = np.asarray(result["D"], dtype=np.float64)

    approx = np.asarray(
        ld.mimo_state_space_markov_response(Ar, Br, Cr, Dr, args.markov_samples), dtype=np.float64
    )
    relative_markov_error = float(
        np.sum((markov - approx) ** 2) / np.maximum(np.sum(markov * markov), 1e-30)
    )

    x = colored_multichannel_input(args.batch, args.samples, args.inputs, args.seed + 1)

    reduced_time, y_obj = median_time(
        lambda: ld.mimo_state_space_process_batch(Ar, Br, Cr, Dr, x, n_threads=args.n_threads),
        args.repeats,
    )
    y = np.asarray(y_obj, dtype=np.float64)

    reduced_throughput_streams = args.batch * args.samples / max(reduced_time, 1e-30) / 1e6
    reduced_channel_throughput = (
        args.batch * args.samples * args.outputs / max(reduced_time, 1e-30) / 1e6
    )
    reduced_visits = dense_state_space_visit_scale(
        args.batch, args.samples, args.reduced_order, args.inputs, args.outputs
    )
    full_visits = dense_state_space_visit_scale(
        args.batch, args.samples, args.full_order, args.inputs, args.outputs
    )
    fir_visits = direct_mimo_fir_visit_scale(
        args.batch, args.samples, args.fir_taps, args.inputs, args.outputs
    )
    fir_lms_visits = 2 * fir_visits
    hankel_rows = args.block_rows * args.outputs
    hankel_cols = args.block_cols * args.inputs

    rows: list[dict[str, object]] = [
        {
            "method": "mimo_markov_generation_cpp",
            "seconds": markov_time,
            "markov_samples": args.markov_samples,
            "full_order": args.full_order,
            "inputs": args.inputs,
            "outputs": args.outputs,
        },
        {
            "method": "finite_block_hankel_reduce_mimo_cpp",
            "seconds": reduce_time,
            "full_order": args.full_order,
            "reduced_order": args.reduced_order,
            "block_hankel_rows": hankel_rows,
            "block_hankel_cols": hankel_cols,
            "retained_hankel_energy": float(result.get("retained_hankel_energy", np.nan)),
            "relative_markov_error": relative_markov_error,
            "reduced_stable": bool(result.get("stable", False)),
        },
        {
            "method": "mimo_state_space_process_batch_reduced_cpp",
            "seconds": reduced_time,
            "samples": args.samples,
            "batch": args.batch,
            "inputs": args.inputs,
            "outputs": args.outputs,
            "reduced_order": args.reduced_order,
            "throughput_mstreams_per_s": reduced_throughput_streams,
            "throughput_moutput_channels_per_s": reduced_channel_throughput,
            "dense_state_space_visits": reduced_visits,
            "dense_state_space_visit_rate_giga_per_s": reduced_visits
            / max(reduced_time, 1e-30)
            / 1e9,
            "output_rms": float(np.sqrt(np.mean(y * y))),
            "reduced_spectral_radius": spectral_radius(Ar),
        },
        {
            "method": "direct_mimo_fir_echo_scale_estimate_filter_only",
            "samples": args.samples,
            "batch": args.batch,
            "inputs": args.inputs,
            "outputs": args.outputs,
            "fir_taps": args.fir_taps,
            "direct_fir_visits": fir_visits,
            "fir_taps_per_reduced_state": args.fir_taps / args.reduced_order,
        },
        {
            "method": "direct_mimo_fir_lms_scale_estimate_filter_plus_update",
            "samples": args.samples,
            "batch": args.batch,
            "inputs": args.inputs,
            "outputs": args.outputs,
            "fir_taps": args.fir_taps,
            "direct_fir_lms_visits": fir_lms_visits,
            "fir_taps_per_reduced_state": args.fir_taps / args.reduced_order,
        },
    ]

    print("MIMO long-signal finite-Hankel/state-space stress")
    print("=" * 64)
    print(f"batch streams: {args.batch:,}")
    print(f"samples per stream: {args.samples:,}")
    print(f"inputs x outputs: {args.inputs} x {args.outputs}")
    print(f"full MIMO state order: {args.full_order:,}")
    print(f"dominant full-model pole radius target: {args.radius:.6f}")
    print(f"full-model spectral radius: {spectral_radius(A):.6f}")
    print(f"Markov samples for reduction: {args.markov_samples:,}")
    print(f"block-Hankel matrix: {hankel_rows:,} x {hankel_cols:,}")
    print(f"reduced order: {args.reduced_order:,}")
    print(f"Markov generation time: {markov_time:.6f} s")
    print(f"finite MIMO block-Hankel reduction time: {reduce_time:.6f} s")
    print(f"retained Hankel energy: {float(result.get('retained_hankel_energy', np.nan)):.6f}")
    print(f"relative Markov error: {relative_markov_error:.3e}")
    print(f"reduced model stable: {bool(result.get('stable', False))}")
    print(f"reduced spectral radius: {spectral_radius(Ar):.6f}")
    print()
    print("compiled reduced MIMO runtime")
    print("-" * 64)
    print(f"median reduced state-space time: {reduced_time:.6f} s")
    print(f"throughput: {reduced_throughput_streams:.2f} million multichannel samples/s")
    print(f"output-channel throughput: {reduced_channel_throughput:.2f} million output samples/s")
    print(f"dense reduced state-space visits: {reduced_visits:,}")
    print(
        f"dense visit rate: {reduced_visits / max(reduced_time, 1e-30) / 1e9:.2f} billion visits/s"
    )
    print(f"output RMS: {np.sqrt(np.mean(y * y)):.6f}")
    print()
    print("MIMO echo-scale comparison numbers")
    print("-" * 64)
    print(f"reference MIMO FIR taps per input-output path: {args.fir_taps:,}")
    print(f"FIR taps / reduced state order: {args.fir_taps / args.reduced_order:.1f}x")
    print(f"full dense state-space visits at same signal size: {full_visits:,}")
    print(f"reduced dense state-space visits: {reduced_visits:,}")
    print(f"direct MIMO FIR filter visits: {fir_visits:,}")
    print(f"direct MIMO FIR LMS filter+update visits, rough scale: {fir_lms_visits:,}")
    print("note: these are scale diagnostics, not an accuracy equivalence claim")

    if args.time_full:
        full_time, y_full_obj = median_time(
            lambda: ld.mimo_state_space_process_batch(A, B, C, D, x, n_threads=args.n_threads),
            args.repeats,
        )
        y_full = np.asarray(y_full_obj, dtype=np.float64)
        rows.append(
            {
                "method": "mimo_state_space_process_batch_full_cpp",
                "seconds": full_time,
                "samples": args.samples,
                "batch": args.batch,
                "inputs": args.inputs,
                "outputs": args.outputs,
                "full_order": args.full_order,
                "throughput_mstreams_per_s": args.batch
                * args.samples
                / max(full_time, 1e-30)
                / 1e6,
                "throughput_moutput_channels_per_s": args.batch
                * args.samples
                * args.outputs
                / max(full_time, 1e-30)
                / 1e6,
                "dense_state_space_visits": full_visits,
                "dense_state_space_visit_rate_giga_per_s": full_visits
                / max(full_time, 1e-30)
                / 1e9,
                "output_rms": float(np.sqrt(np.mean(y_full * y_full))),
                "speedup_reduced_vs_full": full_time / max(reduced_time, 1e-30),
            }
        )
        print()
        print("optional full-order runtime")
        print("-" * 64)
        print(f"median full state-space time: {full_time:.6f} s")
        print(f"reduced/full runtime speedup: {full_time / max(reduced_time, 1e-30):.2f}x")
        print(f"full output RMS: {np.sqrt(np.mean(y_full * y_full)):.6f}")

    out_dir = artifact_dir()
    csv_path = out_dir / "mimo_long_signal_state_space_stress.csv"
    fieldnames = sorted({key for row in rows for key in row})
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print()
    print(f"wrote {csv_path}")


if __name__ == "__main__":
    main()