Performance: parallel preprocessing + FIR phase-shift#4562
Draft
galenlynch wants to merge 4 commits intoSpikeInterface:mainfrom
Draft
Performance: parallel preprocessing + FIR phase-shift#4562galenlynch wants to merge 4 commits intoSpikeInterface:mainfrom
galenlynch wants to merge 4 commits intoSpikeInterface:mainfrom
Conversation
… cleanup Introduces three independent speedups for the preprocessing chain that profile-bound many real-world workflows (ephys mipmap generation, in particular). Each is opt-in via a new kwarg; existing defaults are unchanged. 1. FilterRecording: n_workers kwarg. When > 1, splits the channel axis into blocks and runs scipy.signal.sosfiltfilt/sosfilt in a per- segment thread pool. scipy's C SOS implementations release the GIL during per-column work, so Python-thread parallelism delivers real speedup. Measured ~3x on 8 threads for a 1M x 384 float32 chunk, bitwise-identical output. 2. CommonReferenceRecording: n_workers kwarg. When > 1 and the call lands on the common global-reference path (no groups, no ref channels), the median/mean reduction is split across time blocks and run in a per-segment thread pool. numpy's partition-based median releases the GIL during per-row work. Measured ~10x on 16 threads for a 1M x 384 chunk, bitwise-identical output. 3. PhaseShiftRecording: method="fft"|"fir" kwarg (default "fft" for backward compatibility). The FIR path uses a 32-tap Kaiser-windowed sinc implemented as a numba-jit kernel. Measured ~85x faster than the FFT path on (1M, 384) float32, with ~0.19% spike-band RMS error vs the FFT reference on real NP 2.0 data. The FIR path also takes a K/2 sample margin instead of the 40 ms used by FFT (a bounded- support FIR at a zero-padded boundary is already exact under linear convolution semantics), and has an int16-native dispatch opt-in via output_dtype=np.float32 that skips the int16 -> float64 -> int16 round-trip entirely. 4. get_chunk_with_margin: the in-place raised-cosine taper behind window_on_margin=True has been extracted into a standalone public function apply_raised_cosine_taper. The taper is FFT-specific (used to suppress spectral leakage on zero-padded boundaries) and conflating it with a generic chunk-fetching utility made get_chunk_with_margin unusable for bounded-support filters. window_on_margin is preserved as a deprecated kwarg with a DeprecationWarning; callers that need the taper should call apply_raised_cosine_taper explicitly after fetching. PhaseShift's FFT path is updated to do exactly that and keeps bit-for-bit backward-compatible output. Existing tests for all three modules still pass. Follow-up commits will add focused tests for the new kwargs and a benchmark script. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Tests: new tests lock in the correctness invariants for each opt-in
path.
* test_bandpass_parallel_matches_stock: n_workers>1 sosfiltfilt must
match n_workers=1 within float32 tolerance.
* test_filter_parallel_fewer_channels_than_workers: graceful fallback
when channels < n_workers.
* test_cmr_parallel_median_matches_stock: bit-identical output for
parallel median.
* test_cmr_parallel_average_matches_stock: within-ULP match for mean
(non-associative float sum across block partitions).
* test_phase_shift_fir_matches_fft_in_spike_band: FIR vs FFT
signal-band RMS error < 1% (measured ~0.2% on real NP data; the
synthetic test uses tones at 2.5 and 8.5 Hz from create_shifted_channel).
* test_phase_shift_fir_int16_advertises_float32: int16-native opt-in
dtype advertising.
* test_phase_shift_fft_still_matches_stock_after_taper_refactor:
regression guard for the get_chunk_with_margin taper extraction.
- Benchmark: benchmarks/preprocessing/bench_perf.py runs all four
head-to-head comparisons on synthetic NumpyRecording fixtures and
prints timing + correctness checks. Measured on a 24-core x86_64 host
with 1M x 384 chunks end-to-end through get_traces():
Bandpass (1M x 384 float32): 8.67 s -> 3.34 s (2.60x)
CMR median (1M x 384 float32): 3.95 s -> 0.83 s (4.76x)
PhaseShift FFT vs FIR (float32): 68.07 s -> 0.695 s (97.94x)
PhaseShift FFT vs FIR int16-native: 69.53 s -> 0.446 s (156.06x)
Bandpass and CMR scale sub-linearly due to DRAM bandwidth saturation;
the PhaseShift FIR benefits from both faster kernel compute and
bypassing the 40 ms margin + float64 round-trip.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The end-to-end get_traces() numbers in PR-text were puzzling next to the earlier monkey-patch numbers because micro-benches measure a different thing. Now report both tiers side-by-side so readers can see the non-parallelizable surrounding glue (subtract, cast, margin fetch) for what it is — the memory-bandwidth tax on component speedups rather than a regression. sosfiltfilt: component 2.66x, e2e 2.58x np.median axis=1: component 10.28x, e2e 4.74x (dilution = serial subtraction) phase-shift FIR f32: kernel ~83x, e2e 94.74x (FIR also cuts 40ms margin) phase-shift FIR int16: - e2e 147x (FIR + int16 + skipped roundtrip) PR-text.md updated with both tiers and a "why component != e2e" table that reviewers can reference if they ask. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
for more information, see https://pre-commit.ci
Contributor
Author
|
I'm going to split this up into two PRs to make it easier to review. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Three independet opt-in improvements to the preprocessing chain, plus one small refactor that enables them cleanly. Each is opt-in via a new kwarg; existing defaults and outputs are unchanged.
Headline. For a standard preprocessing pipeline, with full int16 → bandpass → CMR → phase-shift → int16 pipeline on a 1M × 384 chunk, processing time went from 87.7 s → 6.4 s (13.7×). This was with every performance option introduced here enabled. See
bench_pipeline_int16below for the exact configuration.Extrapolated to a full 90-min NP 2.0 recording (30 kHz × 384 ch ≈ 155 × 1M-sample shards), single-worker
get_traces()sequential walk:The per-stage breakdown with just one preprocessor (each number labeled with what it measures — see
benchmarks/preprocessing/bench_perf.py) is:get_traces())FilterRecording(n_workers=N)1CommonReferenceRecording(n_workers=N)1PhaseShiftRecording(method="fir")"fft"PhaseShiftRecording(method="fir", output_dtype=np.float32)NonePipeline configuration:
BandpassFilterRecording(rec)→CommonReferenceRecording(...)→PhaseShiftRecording(..., method="fft")— all defaults.BandpassFilterRecording(rec, n_workers=8)→CommonReferenceRecording(..., n_workers=16)→PhaseShiftRecording(..., method="fir")— three opt-ins, no dtype widening.Tested on a 24-core x86-64 host.
Motivation
Profiling of streaming recordings through filter chain showed:
sosfiltfilt) was single-threaded — 23 of 24 cores sat idle.np.median) was single-threaded.Each change is a natural fit for Python-thread parallelism because the underlying C kernels release the GIL during per-column/per-row work, so no multiprocessing is needed. The FIR path is a well-known alternative to FFT-based fractional-delay interpolation, validated against the existing FFT reference on real NP data.
Changes
1.
FilterRecording(n_workers=N)— channel-parallel SOSFile:
src/spikeinterface/preprocessing/filter.pyn_workerskwarg (default1, preserving existing behavior).n_workers > 1,FilterRecordingSegment.get_tracessplits the channel axis into contiguous blocks and runsscipy.signal.sosfiltfilt/sosfilton each block in a per-segmentThreadPoolExecutor.2 * n_workers.2.
CommonReferenceRecording(n_workers=N)— time-parallel reductionFile:
src/spikeinterface/preprocessing/common_reference.pyn_workerskwarg (default1).group_indices=None,reference="global",ref_channel_ids=None) is parallelized — every other configuration delegates to the existing logic unchanged.n_workers > 1,_parallel_reduce_axis1splits the time axis into blocks and runsnp.median/np.meanper block in a thread pool.min_block=8192samples per thread the overhead dominates; falls back to serial automatically.3.
PhaseShiftRecording(method="fft"|"fir")— sinc FIR alternativeFile:
src/spikeinterface/preprocessing/phase_shift.pymethodkwarg (default"fft"for backward compatibility).method="fir"uses a 32-tap Kaiser-windowed sinc FIR, implemented as numba-jit kernels withprangeparallelism over time.n_taps // 2samples (16 for the 32-tap default), not the 40 ms the FFT path needs.n_tapsconfigurable (default 32, validated as even and >= 2).Why FFT ↔ FIR is algorithmically equivalent (Whittaker–Shannon)
A fractional-sample delay is, by definition, a sinc interpolation at the desired offset. Whittaker–Shannon states that any signal bandlimited to Nyquist is exactly reconstructed from its samples via
so delaying channel$c$ by a fractional $d_c$ samples is
— convolution with the ideal fractional-delay kernel$h_{d_c}[k] = \mathrm{sinc}(k - d_c)$ .
The existing FFT path realises this convolution spectrally: multiplying by$e^{i,2\pi f, d_c}$ in the frequency domain is the DFT of that same infinite sinc kernel. The FIR path realises it in the time domain as explicit linear convolution against a Kaiser-windowed, 32-tap truncation of the same sinc. The operation is identical; the only approximations are:
The measured 0.19% spike-band RMS against the FFT reference (Correctness table) is orders of magnitude below the analog noise floor (~4 ADC LSB RMS on NP 2.0), confirming the two approximations are benign for every physically meaningful downstream use (spike sorting, waveform extraction, envelope rendering).
Why the FFT path's 40 ms margin is unnecessary for the FIR
The 40 ms margin in the FFT path is not a property of fractional-delay interpolation — it's a workaround for the fact that
rfft/irfftoperate on a periodic buffer. Without zero-padding + a raised-cosine taper, energy from the right edge of a chunk wraps circularly to the left edge (and vice versa). A bounded-support linear FIR has no such wraparound:n_taps // 2samples of margin on each side (16 for the 32-tap default) is sufficient and mathematically exact under linear convolution. That saved margin is why the FIR's end-to-end speedup (100.8×) is larger than its kernel-only speedup (≈ 83×).int16-native fast path
PhaseShiftRecording(method="fir", output_dtype=np.float32)on an integer-dtyped parent enables a second kernel specialisation:float32as the recording's output dtype. Downstream filters (bandpass, CMR) consume float directly — which they need to anyway for their own math — so the net effect is removing three redundant cast passes from the pipeline.output_dtype=np.float32; default behavior is unchanged (FIR still round-and-casts back to the parent dtype).4.
get_chunk_with_margin— extract FFT-specific taperFile:
src/spikeinterface/core/time_series_tools.pyapply_raised_cosine_taper(data, margin, *, inplace=True)exposes the raised-cosine window that was previously inlined inget_chunk_with_margin(window_on_margin=True).window_on_margin=Truecontinues to work but is deprecated: it emits aDeprecationWarningand delegates toapply_raised_cosine_taper.PhaseShiftRecordingpath is updated to callget_chunk_with_margin(window_on_margin=False)and thenapply_raised_cosine_taperexplicitly. Output is bit-for-bit equivalent to pre-refactor behavior (regression test added).get_chunk_with_marginwas unusable for bounded-support filters both because the taper was redundant and because the in-place*=against a float taper fails on int-typed chunks. Separating the concern makes the utility filter-method-agnostic, which in turn unlocks the int16-native FIR path.Correctness
np.allclose(rtol=1e-5)np.array_equalnp.allclose(rtol=1e-5)< 1%on synthetic< 0.5%on real NP 2.0 datatest_phase_shift(FFT chunked-vs-full identity)error_mean / rms < 0.001The existing tests for all three modules pass unchanged.
Per-stage benchmarks (reproducible)
benchmarks/preprocessing/bench_perf.py— synthetic NumpyRecording, 1M × 384 chunks, measured on a 24-core x86_64 host (SI 0.103 dev, numpy 2.1, scipy 1.14, numba 0.60). Two tiers reported:Component-level (hot operation only)
Isolated kernel / reduction — no
get_chunk_with_margin, no dtype casts, no slicing. Shows the raw speedup of the parallelization technique itself.End-to-end, per stage (
rec.get_traces())Full SI preprocessing class through
get_traces(), including margin fetch, buffer copies, casts, and subtraction. Each stage measured in isolation.Why component ≠ per-stage end-to-end
get_chunk_with_marginmargin fetch + dtype cast — negligibletraces - shiftsubtraction (1.5 GB r/w, memory-bandwidth-bound) that isn't parallelizableBandpass and CMR scale sub-linearly with thread count even at the component level due to DRAM bandwidth saturation; the 2.92× / 10.58× numbers are the memory-bandwidth ceiling, not a parallelism bug.
Compatibility
get_chunk_with_margin(window_on_margin=True)still works and emits aDeprecationWarningpointing callers atapply_raised_cosine_taper._kwargsdicts updated on all three modified preprocessors;save()/load()round-trip the new kwargs correctly.numbais already a soft dep of SI's Kilosort path; the FIR kernels import it lazily and raise a clear error with install instructions if missing.Review guide
Suggested reading order if you want to split this into multiple PRs:
get_chunk_with_marginrefactor (commit099f6004): standalone, low-risk, independently useful.FilterRecording(n_workers=N): trivial; adds a thread pool around an existing scipy call.CommonReferenceRecording(n_workers=N): trivial; same pattern.PhaseShiftRecording(method="fir"): larger change; includes new numba kernels and a new per-segment get_traces path.output_dtypeint16-native advertising: narrower dtype contract change; worth separate discussion if desired.Happy to split the branch into multiple PRs if preferred.
Checklist
_kwargsupdated)window_on_marginkwarg