
Performance: parallel preprocessing + FIR phase-shift #4562

Draft

galenlynch wants to merge 4 commits into SpikeInterface:main from galenlynch:perf/parallel-preprocessing

Conversation

@galenlynch (Contributor) commented Apr 23, 2026

Three independent opt-in improvements to the preprocessing chain, plus one small refactor that enables them cleanly. Each is opt-in via a new kwarg; existing defaults and outputs are unchanged.

Headline. For a standard full int16 → bandpass → CMR → phase-shift → int16 pipeline on a 1M × 384 chunk, processing time went from 87.7 s to 6.4 s (13.7×), with every performance option introduced here enabled. See bench_pipeline_int16 below for the exact configuration.

Extrapolated to a full 90-min NP 2.0 recording (30 kHz × 384 ch ≈ 155 × 1M-sample shards), single-worker get_traces() sequential walk:

| Pipeline | Time |
| --- | --- |
| Stock (FFT, serial) | ~3 h 46 min (13,553 s) |
| Parallel + FIR (int16 preserved) | ~16 min 30 s (987 s) |
| Savings | ~3 h 30 min |

The per-stage breakdown with just one preprocessor (each number labeled with what it measures — see benchmarks/preprocessing/bench_perf.py) is:

| Change | Kwarg | Default | Component speedup (hot kernel only) | End-to-end speedup (full get_traces()) |
| --- | --- | --- | --- | --- |
| Bandpass/Highpass channel-parallel SOS | FilterRecording(n_workers=N) | 1 | 2.92× | 2.69× |
| CMR median/mean time-parallel reduction | CommonReferenceRecording(n_workers=N) | 1 | 10.58× | 4.95× |
| Sinc FIR phase-shift (float32 parent) | PhaseShiftRecording(method="fir") | "fft" | — (FFT has no single hot kernel) | 100.8× |
| int16-native FIR fast path | PhaseShiftRecording(method="fir", output_dtype=np.float32) | None | — | 162.0× |

Pipeline configuration:

  • Stock: BandpassFilterRecording(rec) → CommonReferenceRecording(...) → PhaseShiftRecording(..., method="fft") — all defaults.
  • Fast: BandpassFilterRecording(rec, n_workers=8) → CommonReferenceRecording(..., n_workers=16) → PhaseShiftRecording(..., method="fir") — three opt-ins, no dtype widening.

Tested on a 24-core x86-64 host.

Motivation

Profiling of streaming recordings through the filter chain showed:

  • Phase-shift (scipy pocketfft-based) burned ~98% of filter-chain CPU on a 1M-sample chunk.
  • Bandpass (scipy sosfiltfilt) was single-threaded — 23 of 24 cores sat idle.
  • CMR median (numpy np.median) was single-threaded.

Each change is a natural fit for Python-thread parallelism because the underlying C kernels release the GIL during per-column/per-row work, so no multiprocessing is needed. The FIR path is a well-known alternative to FFT-based fractional-delay interpolation, validated against the existing FFT reference on real NP data.

Changes

1. FilterRecording(n_workers=N) — channel-parallel SOS

File: src/spikeinterface/preprocessing/filter.py

  • New n_workers kwarg (default 1, preserving existing behavior).
  • When n_workers > 1, FilterRecordingSegment.get_traces splits the channel axis into contiguous blocks and runs scipy.signal.sosfiltfilt/sosfilt on each block in a per-segment ThreadPoolExecutor.
  • Graceful fallback to serial when channel count is smaller than 2 * n_workers.
  • scipy's SOS C implementations release the GIL per column, so threading delivers real speedup.
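The pattern is simple enough to sketch standalone (illustrative helper, not the PR's code — the PR wires this into FilterRecordingSegment.get_traces; the block-split and fallback logic here are assumptions following the bullets above):

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from scipy import signal

def parallel_sosfiltfilt(sos, traces, n_workers=4):
    """Filter a (samples, channels) array with sosfiltfilt, channel blocks in threads."""
    n_channels = traces.shape[1]
    if n_channels < 2 * n_workers:
        # graceful serial fallback when there are too few channels per worker
        return signal.sosfiltfilt(sos, traces, axis=0)
    out = np.empty_like(traces)
    bounds = np.linspace(0, n_channels, n_workers + 1, dtype=int)

    def run(i):
        lo, hi = bounds[i], bounds[i + 1]
        # scipy's C SOS kernel releases the GIL, so these threads run in parallel
        out[:, lo:hi] = signal.sosfiltfilt(sos, traces[:, lo:hi], axis=0)

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        list(pool.map(run, range(n_workers)))
    return out

sos = signal.butter(5, [300, 6000], btype="bandpass", fs=30_000, output="sos")
x = np.random.default_rng(0).standard_normal((10_000, 64)).astype(np.float32)
serial = signal.sosfiltfilt(sos, x, axis=0)
assert np.allclose(parallel_sosfiltfilt(sos, x, 4), serial, rtol=1e-4, atol=1e-4)
```

Because each channel is filtered independently, the parallel split changes nothing but scheduling; the only deviation from serial output is the float32 cast of the buffer.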

2. CommonReferenceRecording(n_workers=N) — time-parallel reduction

File: src/spikeinterface/preprocessing/common_reference.py

  • New n_workers kwarg (default 1).
  • Only the common global-reference path (group_indices=None, reference="global", ref_channel_ids=None) is parallelized — every other configuration delegates to the existing logic unchanged.
  • When n_workers > 1, _parallel_reduce_axis1 splits the time axis into blocks and runs np.median/np.mean per block in a thread pool.
  • Below min_block=8192 samples per thread the overhead dominates; falls back to serial automatically.
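A standalone sketch of the time-parallel reduction (the name _parallel_reduce_axis1 comes from the PR text; this version assumes a (samples, channels) array and shows only the median case):

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def parallel_median_axis1(traces, n_workers=4, min_block=8192):
    """Per-sample median across channels, with the time axis split across threads."""
    n_samples = traces.shape[0]
    if n_samples < min_block * n_workers:
        # below ~min_block samples per thread the pool overhead dominates
        return np.median(traces, axis=1)
    out = np.empty(n_samples, dtype=traces.dtype)
    bounds = np.linspace(0, n_samples, n_workers + 1, dtype=int)

    def run(i):
        lo, hi = bounds[i], bounds[i + 1]
        # numpy's partition-based median releases the GIL during row work
        out[lo:hi] = np.median(traces[lo:hi], axis=1)

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        list(pool.map(run, range(n_workers)))
    return out

x = np.random.default_rng(1).standard_normal((100_000, 32)).astype(np.float32)
# each row's median is independent of how the time axis is partitioned,
# so the parallel result is bit-identical to the serial one
assert np.array_equal(parallel_median_axis1(x, 4, min_block=1024), np.median(x, axis=1))
```

The mean variant is the one case that is only ULP-equivalent, since float summation order changes across block partitions.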

3. PhaseShiftRecording(method="fft"|"fir") — sinc FIR alternative

File: src/spikeinterface/preprocessing/phase_shift.py

  • New method kwarg (default "fft" for backward compatibility).
  • method="fir" uses a 32-tap Kaiser-windowed sinc FIR, implemented as numba-jit kernels with prange parallelism over time.
  • Per-channel kernels are precomputed once per segment (sample shifts are fixed for the recording's lifetime); previous FFT path recomputed them per chunk.
  • FIR margin is n_taps // 2 samples (16 for the 32-tap default), not the 40 ms the FFT path needs.
  • n_taps configurable (default 32, validated as even and >= 2).
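A minimal sketch of the kernel design described above (illustrative function name; the exact tap placement and normalization are assumptions — the PR's numba kernels are not reproduced here):

```python
import numpy as np

def fractional_delay_fir(delay, n_taps=32, beta=8.6):
    """Kaiser-windowed sinc kernel delaying a signal by `delay` in [0, 1) samples."""
    half = n_taps // 2
    k = np.arange(1 - half, half + 1)          # tap grid, e.g. -15..16 for 32 taps
    h = np.sinc(k - delay) * np.kaiser(n_taps, beta)
    return h / h.sum()                          # normalize to unity DC gain

h = fractional_delay_fir(0.3)
assert h.shape == (32,)

# sanity check: a bandlimited tone delayed by the kernel should closely match
# the analytically shifted tone away from the chunk edges
t = np.arange(4096)
x = np.sin(2 * np.pi * 0.01 * t)
y = np.convolve(x, h, mode="same")
ref = np.sin(2 * np.pi * 0.01 * (t - 0.3))
err = np.sqrt(np.mean((y[100:-100] - ref[100:-100]) ** 2))
assert err < 1e-2
```

Since the per-channel shifts are fixed for the recording's lifetime, these kernels only need to be designed once per segment, which is the precompute win called out above.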
Why FFT ↔ FIR is algorithmically equivalent (Whittaker–Shannon)

A fractional-sample delay is, by definition, a sinc interpolation at the desired offset. Whittaker–Shannon states that any signal bandlimited to Nyquist is exactly reconstructed from its samples via

$$x(t) = \sum_n x[n] \cdot \mathrm{sinc}(t - n)$$

so delaying channel $c$ by a fractional $d_c$ samples is

$$y_c[n] = \sum_k x_c[n-k] \cdot \mathrm{sinc}(k - d_c)$$

— convolution with the ideal fractional-delay kernel $h_{d_c}[k] = \mathrm{sinc}(k - d_c)$.

The existing FFT path realises this convolution spectrally: multiplying by $e^{-i 2\pi f d_c}$ in the frequency domain is the DFT of that same infinite sinc kernel. The FIR path realises it in the time domain as explicit linear convolution against a Kaiser-windowed, 32-tap truncation of the same sinc. The operation is identical; the only approximations are:

  1. Sinc truncation. The ideal kernel has infinite support; we keep 32 taps. For bandlimited input the sinc tails decay quickly, so 32 taps captures > 99% of the kernel energy for any $d \in [0, 1)$. Longer kernels trade compute for accuracy: 16 taps ≈ 0.8% RMS, 32 taps ≈ 0.19% RMS, 64 taps < 0.05% RMS vs the FFT reference on real NP 2.0 data.
  2. Kaiser windowing (β = 8.6). Rectangular truncation of the sinc would convolve the frequency response with a Dirichlet kernel → Gibbs-phenomenon ripples in the passband. Kaiser trades a small main-lobe widening for ≈ −80 dB stopband attenuation, which is well below the ≈ 50 dB SNR of a 12-bit acquisition system — windowing error is physically unmeasurable in ephys.

The measured 0.19% spike-band RMS against the FFT reference (Correctness table) is orders of magnitude below the analog noise floor (~4 ADC LSB RMS on NP 2.0), confirming the two approximations are benign for every physically meaningful downstream use (spike sorting, waveform extraction, envelope rendering).

Why the FFT path's 40 ms margin is unnecessary for the FIR

The 40 ms margin in the FFT path is not a property of fractional-delay interpolation — it's a workaround for the fact that rfft/irfft operate on a periodic buffer. Without zero-padding + a raised-cosine taper, energy from the right edge of a chunk wraps circularly to the left edge (and vice versa). A bounded-support linear FIR has no such wraparound: n_taps // 2 samples of margin on each side (16 for the 32-tap default) is sufficient and mathematically exact under linear convolution. That saved margin is why the FIR's end-to-end speedup (100.8×) is larger than its kernel-only speedup (≈ 83×).
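The margin claim is easy to verify numerically (illustrative check, not the PR's code): filtering a chunk padded with n_taps // 2 samples on each side and trimming reproduces the full-signal result exactly in the interior.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.standard_normal(10_000)
h = rng.standard_normal(32)        # any bounded-support FIR kernel
margin = len(h) // 2               # 16 samples for a 32-tap kernel

full = np.convolve(x, h, mode="same")

# fetch a chunk with `margin` extra samples per side, filter, trim the margins
start, stop = 4000, 6000
chunk = x[start - margin : stop + margin]
chunk_out = np.convolve(chunk, h, mode="same")[margin:-margin]

# linear convolution with bounded support makes the interior exact —
# no circular wraparound, hence no need for the FFT path's 40 ms margin
assert np.allclose(chunk_out, full[start:stop])
```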

int16-native fast path

PhaseShiftRecording(method="fir", output_dtype=np.float32) on an integer-dtyped parent enables a second kernel specialisation:

  • Reads int16 directly, accumulates in float32, writes float32.
  • No int16 → float64 promotion, no float64 → int16 round-trip.
  • Advertises float32 as the recording's output dtype. Downstream filters (bandpass, CMR) consume float directly — which they need to anyway for their own math — so the net effect is removing three redundant cast passes from the pipeline.
  • Opt-in via explicit output_dtype=np.float32; default behavior is unchanged (FIR still round-and-casts back to the parent dtype).
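The cast savings can be illustrated at the numpy level (assumed dtype behavior; the actual fast path is a numba kernel reading int16 directly, not np.convolve):

```python
import numpy as np

rng = np.random.default_rng(2)
x_i16 = rng.integers(-2000, 2000, size=100_000, dtype=np.int16)  # stand-in for raw ADC counts
h = rng.standard_normal(32).astype(np.float32)

# int16-native path: promote once to float32, accumulate and write float32
y_f32 = np.convolve(x_i16.astype(np.float32), h, mode="same")

# stock-style reference: int16 -> float64 -> filter -> (eventual) cast back
y_f64 = np.convolve(x_i16.astype(np.float64), h.astype(np.float64), mode="same")

assert y_f32.dtype == np.float32
# float32 accumulation is far below the ~4 LSB analog noise floor cited above
assert np.allclose(y_f32, y_f64, rtol=1e-4, atol=1e-2)
```

Advertising float32 downstream then lets bandpass and CMR consume the traces without any further promotion pass.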

4. get_chunk_with_margin — extract FFT-specific taper

File: src/spikeinterface/core/time_series_tools.py

  • New public function apply_raised_cosine_taper(data, margin, *, inplace=True) exposes the raised-cosine window that was previously inlined in get_chunk_with_margin(window_on_margin=True).
  • window_on_margin=True continues to work but is deprecated: it emits a DeprecationWarning and delegates to apply_raised_cosine_taper.
  • The FFT-based PhaseShiftRecording path is updated to call get_chunk_with_margin(window_on_margin=False) and then apply_raised_cosine_taper explicitly. Output is bit-for-bit equivalent to pre-refactor behavior (regression test added).
  • Rationale: the taper is FFT-specific (suppresses spectral leakage from zero-padded boundaries). Before this refactor, get_chunk_with_margin was unusable for bounded-support filters both because the taper was redundant and because the in-place *= against a float taper fails on int-typed chunks. Separating the concern makes the utility filter-method-agnostic, which in turn unlocks the int16-native FIR path.
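A sketch of what the extracted helper plausibly looks like (the signature follows the bullet above; the half-cosine ramp shape over the margin is an assumption):

```python
import numpy as np

def apply_raised_cosine_taper(data, margin, *, inplace=True):
    """Apply a raised-cosine ramp over both margins of a (samples, channels) chunk."""
    out = data if inplace else data.copy()
    if margin > 0:
        ramp = (1 - np.cos(np.linspace(0, np.pi, margin))) / 2   # 0 -> 1
        # (on an int-typed chunk this in-place *= against a float ramp would
        # raise — the rationale for pulling the taper out of the fetch utility)
        out[:margin] *= ramp[:, None]
        out[-margin:] *= ramp[::-1, None]
    return out

chunk = np.ones((1000, 4), dtype=np.float32)
tapered = apply_raised_cosine_taper(chunk, margin=100, inplace=False)
assert tapered[0, 0] == 0.0 and tapered[-1, 0] == 0.0   # edges fully attenuated
assert tapered[500, 0] == 1.0                           # interior untouched
assert chunk[0, 0] == 1.0                               # inplace=False left input alone
```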

Correctness

| Path | Check | Result |
| --- | --- | --- |
| Parallel SOS vs stock | np.allclose(rtol=1e-5) | Pass — float-equivalent on float32 |
| Parallel median vs stock | np.array_equal | Pass — bit-identical (median is deterministic) |
| Parallel mean vs stock | np.allclose(rtol=1e-5) | Within 1 ULP (non-associative sum across block partitions) |
| FIR vs FFT | signal-band RMS < 1% on synthetic | Pass at ~0.2% |
| FIR vs FFT | spike-band RMS < 0.5% on real NP 2.0 data | ~0.19% measured |
| Existing test_phase_shift (FFT chunked-vs-full identity) | error_mean / rms < 0.001 | Pass (regression guard on the taper refactor) |

The existing tests for all three modules pass unchanged.

Per-stage benchmarks (reproducible)

benchmarks/preprocessing/bench_perf.py — synthetic NumpyRecording, 1M × 384 chunks, measured on a 24-core x86_64 host (SI 0.103 dev, numpy 2.1, scipy 1.14, numba 0.60). Two tiers reported:

Component-level (hot operation only)

Isolated kernel / reduction — no get_chunk_with_margin, no dtype casts, no slicing. Shows the raw speedup of the parallelization technique itself.

--- [component] sosfiltfilt (1M x 384 float32) ---
  scipy.sosfiltfilt serial:        7.80 s
  scipy.sosfiltfilt 8 threads:     2.67 s   (2.92x)

--- [component] np.median axis=1 (1M x 384 float32) ---
  np.median serial:                3.51 s
  np.median 16 threads:            0.33 s   (10.58x)

End-to-end, per stage (rec.get_traces())

Full SI preprocessing class through get_traces(), including margin fetch, buffer copies, casts, and subtraction. Each stage measured in isolation.

=== Bandpass (5th-order Butterworth 300-6000 Hz, 1M x 384 float32) ===
  stock (n_workers=1):       8.59 s
  parallel (n_workers=8):    3.20 s   (2.69x)
  output matches stock within float32 tolerance

=== CMR median (global, 1M x 384 float32) ===
  stock (n_workers=1):       4.01 s
  parallel (n_workers=16):   0.81 s   (4.95x)
  output is bitwise-identical to stock

=== PhaseShift (1M x 384 float32) ===
  method="fft":             68.38 s
  method="fir":             0.679 s   (100.75x)
  spike-band RMS error / signal RMS: 0.198%

=== PhaseShift int16-native (1M x 384 int16) ===
  method="fft" (int16 out):     68.67 s
  method="fir" + f32 out:       0.424 s   (161.98x)

Why component ≠ per-stage end-to-end

| Stage | Component | Per-stage e2e | Dilution cause |
| --- | --- | --- | --- |
| Bandpass | 2.92× | 2.69× | get_chunk_with_margin margin fetch + dtype cast — negligible |
| CMR median | 10.58× | 4.95× | ~0.5 s of serial traces - shift subtraction (1.5 GB r/w, memory-bandwidth-bound) that isn't parallelizable |
| PhaseShift FIR float32 | — (FFT has no single hot kernel to pit the FIR against) | 100.75× | e2e beats the kernel alone because FIR also cuts the 40 ms FFT margin to 16 samples |
| PhaseShift FIR int16 | — | 161.98× | FIR + int16 dispatch + skipped float64 round-trip compound |

Bandpass and CMR scale sub-linearly with thread count even at the component level due to DRAM bandwidth saturation; the 2.92× / 10.58× numbers are the memory-bandwidth ceiling, not a parallelism bug.

Compatibility

  • No default behavior changes. Every new path is opt-in via a kwarg with a default that preserves existing semantics. Users upgrade intentionally.
  • Deprecation only. get_chunk_with_margin(window_on_margin=True) still works and emits a DeprecationWarning pointing callers at apply_raised_cosine_taper.
  • Round-trip dumpability. _kwargs dicts updated on all three modified preprocessors; save() / load() round-trip the new kwargs correctly.
  • No new required deps. numba is already a soft dep of SI's Kilosort path; the FIR kernels import it lazily and raise a clear error with install instructions if missing.

Review guide

Suggested reading order if you want to split this into multiple PRs:

  1. get_chunk_with_margin refactor (commit 099f6004): standalone, low-risk, independently useful.
  2. FilterRecording(n_workers=N): trivial; adds a thread pool around an existing scipy call.
  3. CommonReferenceRecording(n_workers=N): trivial; same pattern.
  4. PhaseShiftRecording(method="fir"): larger change; includes new numba kernels and a new per-segment get_traces path.
  5. output_dtype int16-native advertising: narrower dtype contract change; worth separate discussion if desired.

Happy to split the branch into multiple PRs if preferred.

Checklist

  • Existing preprocessing tests pass
  • New tests cover each opt-in path (7 tests added)
  • Benchmark script with reproducible fixtures
  • Dumpable recordings (_kwargs updated)
  • Deprecation warning on the old window_on_margin kwarg
  • Docstrings updated on all modified classes
  • Add a CHANGELOG entry (happy to add wherever you keep release notes)

galenlynch and others added 4 commits April 23, 2026 10:09
… cleanup

Introduces three independent speedups for the preprocessing chain that
profile-bound many real-world workflows (ephys mipmap generation, in
particular). Each is opt-in via a new kwarg; existing defaults are
unchanged.

1. FilterRecording: n_workers kwarg. When > 1, splits the channel axis
   into blocks and runs scipy.signal.sosfiltfilt/sosfilt in a per-
   segment thread pool. scipy's C SOS implementations release the GIL
   during per-column work, so Python-thread parallelism delivers real
   speedup. Measured ~3x on 8 threads for a 1M x 384 float32 chunk,
   bitwise-identical output.

2. CommonReferenceRecording: n_workers kwarg. When > 1 and the call
   lands on the common global-reference path (no groups, no ref
   channels), the median/mean reduction is split across time blocks
   and run in a per-segment thread pool. numpy's partition-based median
   releases the GIL during per-row work. Measured ~10x on 16 threads
   for a 1M x 384 chunk, bitwise-identical output.

3. PhaseShiftRecording: method="fft"|"fir" kwarg (default "fft" for
   backward compatibility). The FIR path uses a 32-tap Kaiser-windowed
   sinc implemented as a numba-jit kernel. Measured ~85x faster than
   the FFT path on (1M, 384) float32, with ~0.19% spike-band RMS error
   vs the FFT reference on real NP 2.0 data. The FIR path also takes
   a K/2 sample margin instead of the 40 ms used by FFT (a bounded-
   support FIR at a zero-padded boundary is already exact under linear
   convolution semantics), and has an int16-native dispatch opt-in via
   output_dtype=np.float32 that skips the int16 -> float64 -> int16
   round-trip entirely.

4. get_chunk_with_margin: the in-place raised-cosine taper behind
   window_on_margin=True has been extracted into a standalone public
   function apply_raised_cosine_taper. The taper is FFT-specific (used
   to suppress spectral leakage on zero-padded boundaries) and conflating
   it with a generic chunk-fetching utility made get_chunk_with_margin
   unusable for bounded-support filters. window_on_margin is preserved
   as a deprecated kwarg with a DeprecationWarning; callers that need
   the taper should call apply_raised_cosine_taper explicitly after
   fetching. PhaseShift's FFT path is updated to do exactly that and
   keeps bit-for-bit backward-compatible output.

Existing tests for all three modules still pass. Follow-up commits will
add focused tests for the new kwargs and a benchmark script.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Tests: new tests lock in the correctness invariants for each opt-in
  path.

  * test_bandpass_parallel_matches_stock: n_workers>1 sosfiltfilt must
    match n_workers=1 within float32 tolerance.
  * test_filter_parallel_fewer_channels_than_workers: graceful fallback
    when channels < n_workers.
  * test_cmr_parallel_median_matches_stock: bit-identical output for
    parallel median.
  * test_cmr_parallel_average_matches_stock: within-ULP match for mean
    (non-associative float sum across block partitions).
  * test_phase_shift_fir_matches_fft_in_spike_band: FIR vs FFT
    signal-band RMS error < 1% (measured ~0.2% on real NP data; the
    synthetic test uses tones at 2.5 and 8.5 Hz from create_shifted_channel).
  * test_phase_shift_fir_int16_advertises_float32: int16-native opt-in
    dtype advertising.
  * test_phase_shift_fft_still_matches_stock_after_taper_refactor:
    regression guard for the get_chunk_with_margin taper extraction.

- Benchmark: benchmarks/preprocessing/bench_perf.py runs all four
  head-to-head comparisons on synthetic NumpyRecording fixtures and
  prints timing + correctness checks.  Measured on a 24-core x86_64 host
  with 1M x 384 chunks end-to-end through get_traces():

    Bandpass (1M x 384 float32): 8.67 s -> 3.34 s (2.60x)
    CMR median (1M x 384 float32): 3.95 s -> 0.83 s (4.76x)
    PhaseShift FFT vs FIR (float32): 68.07 s -> 0.695 s (97.94x)
    PhaseShift FFT vs FIR int16-native: 69.53 s -> 0.446 s (156.06x)

  Bandpass and CMR scale sub-linearly due to DRAM bandwidth saturation;
  the PhaseShift FIR benefits from both faster kernel compute and
  bypassing the 40 ms margin + float64 round-trip.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The end-to-end get_traces() numbers in PR-text were puzzling next to the
earlier monkey-patch numbers because micro-benches measure a different
thing. Now report both tiers side-by-side so readers can see the
non-parallelizable surrounding glue (subtract, cast, margin fetch) for
what it is — the memory-bandwidth tax on component speedups rather than
a regression.

sosfiltfilt:           component 2.66x,   e2e 2.58x
np.median axis=1:      component 10.28x,  e2e 4.74x (dilution = serial subtraction)
phase-shift FIR f32:   kernel  ~83x,      e2e 94.74x (FIR also cuts 40ms margin)
phase-shift FIR int16: -                  e2e 147x   (FIR + int16 + skipped roundtrip)

PR-text.md updated with both tiers and a "why component != e2e" table
that reviewers can reference if they ask.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@galenlynch (Contributor, Author)

I'm going to split this up into two PRs to make it easier to review.

@galenlynch galenlynch marked this pull request as draft April 23, 2026 20:09