Multichannel audio decorrelation with energy preservation

Tutorial goal

Reduce channel correlation while keeping total signal energy roughly unchanged.

Note

New to the terminology? See the lattice DSP concept map and the causality/data-use guide for how online, offline, block, and MIMO examples should be read.

Context

Decorrelators are useful for spatial audio and multichannel processing demos. The goal here is not perceptual tuning; it is to show that matrix all-pass (lattice) transforms can change the inter-channel correlation structure while leaving total power nearly unchanged.

Key idea and equations

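Let \(x[n]\in\mathbb{R}^c\) be the \(c\)-channel input frame at sample \(n\) and \(y[n]\in\mathbb{R}^c\) the corresponding decorrelated output frame. For zero-mean signals the covariance matrices are

\[R_x = \mathbb{E}\{x[n]x[n]^T\}, \qquad R_y = \mathbb{E}\{y[n]y[n]^T\},\]

which in practice are estimated by averaging over the \(N\) samples of a block, e.g. \(\hat R_x = \frac{1}{N}\sum_n x[n]x[n]^T\).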

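Decorrelation aims to shrink the off-diagonal covariance terms relative to the diagonal (per-channel power) terms. One summary measure is

\[r_{off}(R)= \frac{\sum_{i\ne j}|R_{ij}|}{\sum_i |R_{ii}|}.\]

The script itself reports the mean absolute off-diagonal correlation coefficient \(\bar\rho = \frac{2}{c(c-1)}\sum_{i<j}|\rho_{ij}|\) with \(\rho_{ij} = R_{ij}/\sqrt{R_{ii}R_{jj}}\); for channels standardized to unit variance, \(r_{off} = (c-1)\,\bar\rho\).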
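A minimal NumPy sketch of these estimates, assuming zero-mean channels stored as an `(n_samples, channels)` array (the layout the example script uses). The helper names `sample_covariance`, `r_off`, and `mean_abs_offdiag_corr` are hypothetical. Because the script prints the mean absolute off-diagonal correlation coefficient rather than \(r_{off}\), the sketch also checks numerically that the two differ by the factor \(c-1\) once channels are standardized.

```python
import numpy as np


def sample_covariance(x: np.ndarray) -> np.ndarray:
    """Block estimate (1/N) sum_n x[n] x[n]^T; assumes zero-mean channels."""
    return x.T @ x / x.shape[0]


def r_off(R: np.ndarray) -> float:
    """Sum of |off-diagonal| entries divided by the sum of |diagonal| entries."""
    diag = np.abs(np.diag(R)).sum()
    return float((np.abs(R).sum() - diag) / diag)


def mean_abs_offdiag_corr(x: np.ndarray) -> float:
    """Mean |rho_ij| over the upper triangle of the correlation matrix."""
    corr = np.corrcoef(x, rowvar=False)
    return float(np.mean(np.abs(corr[np.triu_indices_from(corr, k=1)])))


rng = np.random.default_rng(0)
c, n = 4, 10_000
shared = rng.normal(size=(n, 1))
x = 0.8 * shared + 0.6 * rng.normal(size=(n, c))  # correlated channels
x = (x - x.mean(axis=0)) / x.std(axis=0)          # standardize: R_x is now the correlation matrix

R_x = sample_covariance(x)
rho_bar = mean_abs_offdiag_corr(x)
print(f"r_off = {r_off(R_x):.4f}, mean |rho| = {rho_bar:.4f}")
```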

The lattice transform is a matrix all-pass (paraunitary) filter, so it preserves total energy once its decaying tail is included; over a finite block the energy ratio should therefore be close to one:

\[\frac{\sum_n\lVert y[n]\rVert_2^2} {\sum_n\lVert x[n]\rVert_2^2} \approx 1.\]
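A memoryless special case makes the trade-off concrete. The sketch below is an illustration, not the lattice method: it rotates each frame by the orthogonal eigenvector matrix \(Q\) of the block covariance, so \(R_y = Q^T R_x Q\) is diagonal while the energy ratio stays at one up to rounding. Unlike the lattice all-pass, this rotation is estimated from the whole block (so it is neither causal nor data-independent) and removes only zero-lag covariance. The helper names are hypothetical.

```python
import numpy as np


def energy_ratio(x: np.ndarray, y: np.ndarray) -> float:
    """sum_n ||y[n]||^2 / sum_n ||x[n]||^2 for (n_samples, channels) arrays."""
    return float(np.sum(y * y) / np.sum(x * x))


def offdiag_ratio(R: np.ndarray) -> float:
    """The r_off measure: off-diagonal mass relative to the diagonal."""
    diag = np.abs(np.diag(R)).sum()
    return float((np.abs(R).sum() - diag) / diag)


rng = np.random.default_rng(1)
n, c = 8192, 4
shared = rng.normal(size=(n, 1))
x = 0.9 * shared + 0.3 * rng.normal(size=(n, c))  # strongly correlated channels
x -= x.mean(axis=0)

R_x = x.T @ x / n
# Eigenvectors of the symmetric covariance form an orthogonal matrix Q.
_, Q = np.linalg.eigh(R_x)
y = x @ Q          # row n is (Q^T x[n])^T: a unitary map applied to every frame
R_y = y.T @ y / n  # equals Q^T R_x Q, which is diagonal

print(offdiag_ratio(R_x), offdiag_ratio(R_y), energy_ratio(x, y))
```

The rotation preserves energy exactly because \(\lVert Q^T x[n]\rVert_2 = \lVert x[n]\rVert_2\) for every frame; a matrix all-pass filter extends the same norm-preserving idea across frequency, which is what lets it decorrelate without first measuring the data.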

The correlation heatmaps show the before/after coupling, while the energy ratio checks that decorrelation did not simply attenuate the signal.
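Why both diagnostics are needed can be checked with a pure gain. Correlation coefficients are scale invariant, so halving the output leaves every \(\rho_{ij}\) unchanged and only the energy ratio exposes the loss. The sketch also shows that rescaling each channel to unit variance before computing the ratio hides the attenuation completely, so the ratio has to be taken on the unnormalized output. The data and helper names are illustrative.

```python
import numpy as np


def mean_abs_offdiag_corr(x: np.ndarray) -> float:
    corr = np.corrcoef(x, rowvar=False)
    return float(np.mean(np.abs(corr[np.triu_indices_from(corr, k=1)])))


def energy_ratio(x: np.ndarray, y: np.ndarray) -> float:
    return float(np.sum(y * y) / np.sum(x * x))


rng = np.random.default_rng(2)
x = rng.normal(size=(4096, 3)) @ rng.normal(size=(3, 3))  # mixed, correlated channels
x -= x.mean(axis=0)

y = 0.5 * x  # pure attenuation: no decorrelation at all
corr_x, corr_y = mean_abs_offdiag_corr(x), mean_abs_offdiag_corr(y)
print(corr_x, corr_y)      # identical: correlation ignores the gain
print(energy_ratio(x, y))  # 0.25: only the energy check reveals it

# Standardizing each channel first makes the ratio 1.0 whatever the gain was.
x_std = x / x.std(axis=0)
y_std = y / y.std(axis=0)
print(energy_ratio(x_std, y_std))
```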

Causality and data use

This decorrelator uses the causal OnlineMatrixLatticeAllPass runtime. Each output frame depends only on the current input frame and previous lattice states. The energy ratio reported for a finite block can fall slightly below one, because the block ends before the decaying all-pass tail (energy still held in the lattice state) has been output.
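The library's OnlineMatrixLatticeAllPass is not reproduced here. As a hypothetical stand-in, this sketch uses a diagonal first-order all-pass (one scalar section per channel, no cross-channel mixing) with explicit state, to check the properties described above: chunked processing matches one-shot processing, changing future input leaves earlier output untouched, and a finite prefix has an energy ratio of at most one, with the remainder recovered by flushing zeros through the state.

```python
import numpy as np


class OnlineFirstOrderAllPass:
    """Hypothetical stand-in for an online all-pass runtime (not the library's class).

    Channel k runs H_k(z) = (a_k + z^-1) / (1 + a_k z^-1) with |a_k| < 1; the state
    (previous input and output frame) persists across process() calls.
    """

    def __init__(self, coeffs):
        self.a = np.asarray(coeffs, dtype=float)
        self.x_prev = np.zeros_like(self.a)  # x[n-1]
        self.y_prev = np.zeros_like(self.a)  # y[n-1]

    def process(self, x: np.ndarray) -> np.ndarray:
        """Filter an (n_samples, channels) block one frame at a time."""
        y = np.empty(x.shape, dtype=float)
        for n in range(x.shape[0]):
            frame = np.asarray(x[n], dtype=float)
            # y[n] = a x[n] + x[n-1] - a y[n-1]: only current and past samples are used.
            y[n] = self.a * frame + self.x_prev - self.a * self.y_prev
            self.x_prev, self.y_prev = frame.copy(), y[n].copy()
        return y


rng = np.random.default_rng(3)
x = rng.normal(size=(2000, 2))
coeffs = [0.6, -0.6]

one_shot = OnlineFirstOrderAllPass(coeffs).process(x)

# Same input in uneven chunks: carrying state between calls gives the same output.
runtime = OnlineFirstOrderAllPass(coeffs)
chunked = np.concatenate(
    [runtime.process(x[:333]), runtime.process(x[333:1200]), runtime.process(x[1200:])]
)

# Causality: changing future input leaves earlier output frames untouched.
x_changed = x.copy()
x_changed[1500:] = 0.0
y_changed = OnlineFirstOrderAllPass(coeffs).process(x_changed)

# A finite block never gains energy; flushing zeros through the state recovers the tail.
prefix_ratio = np.sum(one_shot**2) / np.sum(x**2)
tail = runtime.process(np.zeros((200, 2)))  # runtime has already consumed all of x
flushed_ratio = (np.sum(one_shot**2) + np.sum(tail**2)) / np.sum(x**2)
print(f"prefix ratio {prefix_ratio:.6f}, with tail {flushed_ratio:.12f}")
```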

What this example verifies

This verifies a causal forward decorrelator. The online all-pass runtime should reduce off-diagonal covariance/correlation while preserving total energy after the filter tail is included. It is a DSP diagnostic, not a perceptual audio product claim.
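A classic way to see all-pass decorrelation is to give each channel a different phase response. This sketch is a simplified diagonal variant, not the matrix lattice (which also mixes channels): it feeds one white-noise signal into two channels, filters them with first-order all-passes using `a = 0.5` and `a = -0.5`, and zero-pads the input so the tail is included. For white input the theoretical output correlation for these two coefficients is 0.2, and the energy ratio with the tail included is one up to rounding.

```python
import numpy as np


def first_order_allpass(x: np.ndarray, a: float) -> np.ndarray:
    """Apply H(z) = (a + z^-1) / (1 + a z^-1) to a 1-D signal (|a| < 1)."""
    y = np.zeros(len(x))
    x_prev = y_prev = 0.0
    for n, xn in enumerate(x):
        y[n] = a * xn + x_prev - a * y_prev
        x_prev, y_prev = xn, y[n]
    return y


rng = np.random.default_rng(5)
mono = rng.normal(size=8192)
padded = np.concatenate([mono, np.zeros(256)])  # room for the decaying tail

x = np.stack([padded, padded], axis=1)  # two identical channels: correlation 1
y = np.stack(
    [first_order_allpass(padded, 0.5), first_order_allpass(padded, -0.5)], axis=1
)

corr_in = np.corrcoef(x, rowvar=False)[0, 1]
corr_out = np.corrcoef(y, rowvar=False)[0, 1]
energy = np.sum(y * y) / np.sum(x * x)
print(f"corr in {corr_in:.3f}, corr out {corr_out:.3f}, energy ratio {energy:.9f}")
```

The correlation drops from 1 to roughly 0.2 without either channel losing power, which is the same behaviour the example checks with a richer, channel-mixing filter.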

How to read the result

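Use the before/after correlation matrices and the summary bar plot to see the decorrelation effect; the energy ratio should remain close to one. The script writes its figures (correlation heatmaps, a channel-0 waveform overlay, and the summary bar plot) to reports/example-artifacts, or to the directory named by the LATTICE_DSP_ARTIFACT_DIR environment variable, and skips them if matplotlib is not installed.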

Run command

python examples/multichannel_audio_decorrelator.py

Source code

```python
"""Streaming multichannel audio decorrelation with a real matrix-lattice all-pass filter.

A real-coefficient matrix lattice all-pass filter can redistribute energy across
channels and frequency while preserving total power.  This example uses the
causal online runtime, so each output frame depends only on the current input
frame and previous lattice states.
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np

from lattice_dsp import MatrixLatticeAllPass, contractive_matrix_from_raw, unitary_polar_factor


def _artifact_dir() -> Path:
    path = Path(os.environ.get("LATTICE_DSP_ARTIFACT_DIR", "reports/example-artifacts"))
    path.mkdir(parents=True, exist_ok=True)
    return path


def _make_real_filter(rng: np.random.Generator, channels: int, order: int) -> MatrixLatticeAllPass:
    # One contractive reflection matrix per lattice stage, plus an orthogonal residue.
    reflections = [
        contractive_matrix_from_raw(0.45 * rng.normal(size=(channels, channels)))
        for _ in range(order)
    ]
    residue = unitary_polar_factor(rng.normal(size=(channels, channels)))
    return MatrixLatticeAllPass(reflections, residue=residue)


def _apply_real_streaming_filter(x: np.ndarray, filt: MatrixLatticeAllPass) -> np.ndarray:
    runtime = filt.to_online_filter()
    y = runtime.process(x)
    # The coefficients are real, so any imaginary part is numerical noise.
    return np.real_if_close(y, tol=1000).real


def _mean_abs_off_diagonal_correlation(x: np.ndarray) -> float:
    corr = np.corrcoef(x, rowvar=False)
    upper = corr[np.triu_indices_from(corr, k=1)]
    return float(np.mean(np.abs(upper)))


def _save_figures(x: np.ndarray, y: np.ndarray, input_corr: float, output_corr: float) -> None:
    try:
        import matplotlib.pyplot as plt
    except ImportError:  # pragma: no cover - optional plotting dependency
        print("matplotlib is not installed; skipped figures")
        return

    out_dir = _artifact_dir()
    corr_x = np.corrcoef(x, rowvar=False)
    corr_y = np.corrcoef(y, rowvar=False)

    fig, axes = plt.subplots(1, 2, figsize=(8.4, 3.6))
    for ax, title, corr in (
        (axes[0], f"input correlation\nmean |offdiag|={input_corr:.3f}", corr_x),
        (axes[1], f"output correlation\nmean |offdiag|={output_corr:.3f}", corr_y),
    ):
        im = ax.imshow(corr, vmin=-1.0, vmax=1.0)
        ax.set_title(title)
        ax.set_xlabel("channel")
        ax.set_ylabel("channel")
        fig.colorbar(im, ax=ax, shrink=0.78)
    fig.tight_layout()
    path = out_dir / "multichannel_audio_decorrelator_correlation.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    segment = slice(0, 700)
    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    ax.plot(x[segment, 0], label="input ch0")
    ax.plot(y[segment, 0], linestyle="--", label="output ch0")
    ax.set_xlabel("sample")
    ax.set_ylabel("normalized amplitude")
    ax.set_title("Same-energy decorrelation changes waveform shape")
    ax.legend(loc="best")
    fig.tight_layout()
    path = out_dir / "multichannel_audio_decorrelator_waveform.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")

    fig, ax = plt.subplots(figsize=(6.6, 4.0))
    ax.bar(["input", "output"], [input_corr, output_corr])
    ax.set_ylabel("mean |off-diagonal correlation|")
    ax.set_title("Correlation decreases while total energy is preserved")
    fig.tight_layout()
    path = out_dir / "multichannel_audio_decorrelator_corr_summary.png"
    fig.savefig(path, dpi=160)
    plt.close(fig)
    print(f"wrote {path}")


rng = np.random.default_rng(4)
channels = 4
order = 5
n_samples = 8192

# A smooth shared source creates highly correlated channels.  Small circular
# delays (np.roll) and independent noise make it more realistic than exact copies.
white = rng.normal(size=n_samples + 128)
source = np.convolve(white, np.ones(64) / 64.0, mode="valid")[:n_samples]
input_channels = []
for ch in range(channels):
    delayed = np.roll(source, 3 * ch)
    input_channels.append(0.9 * delayed + 0.1 * rng.normal(size=n_samples))
x = np.stack(input_channels, axis=1)
# Standardize each input channel to zero mean and unit variance.
x = (x - x.mean(axis=0)) / x.std(axis=0)

filt = _make_real_filter(rng, channels, order)
y_raw = _apply_real_streaming_filter(x, filt)

input_corr = _mean_abs_off_diagonal_correlation(x)
output_corr = _mean_abs_off_diagonal_correlation(y_raw)
# Measure energy on the raw filter output: standardizing y first would force the ratio to exactly 1.
energy_ratio = float(np.sum(y_raw * y_raw) / np.sum(x * x))

# Standardize only for plotting; correlation coefficients ignore per-channel offset and scale.
y = (y_raw - y_raw.mean(axis=0)) / y_raw.std(axis=0)

print("channels:", channels)
print("order:", order)
print("max reflection singular value:", round(filt.max_reflection_singular_value(), 6))
print("input mean |offdiag corr|:", round(input_corr, 4))
print("output mean |offdiag corr|:", round(output_corr, 4))
print("decorrelation factor:", round(input_corr / output_corr, 2), "x")
print("energy ratio (output/input, finite block):", f"{energy_ratio:.6f}")
print(
    "takeaway: causal MIMO all-pass filtering can decorrelate channels without using future samples"
)

_save_figures(x, y, input_corr, output_corr)
```