ibl-neuropixel 1.8.1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ibl-neuropixel
3
- Version: 1.8.1
3
+ Version: 1.9.1
4
4
  Summary: Collection of tools for Neuropixel 1.0 and 2.0 probes data
5
5
  Home-page: https://github.com/int-brain-lab/ibl-neuropixel
6
6
  Author: The International Brain Laboratory
@@ -41,6 +41,53 @@ Minimum Python version supported is 3.10
41
41
 
42
42
  ## Destriping
43
43
  ### Getting started
44
+
45
+ #### Compress a binary file losslessly using `mtscomp`
46
+
47
+ The mtscomp util implements fast chunked compression for neurophysiology data in a single shard.
48
+ Package repository is [here](https://github.com/int-brain-lab/mtscomp).
49
+
50
+
51
+ ```python
52
+ from pathlib import Path
53
+ import spikeglx
54
+ file_spikeglx = Path('/datadisk/neuropixel/file.imec0.ap.bin')
55
+ sr = spikeglx.Reader(file_spikeglx)
56
+ sr.compress_file()
57
+ # note: you can use sr.compress_file(keep_original=False) to also remove the original bin file
58
+ ```
59
+
60
+ #### Reading raw spikeglx file and manipulating arrays
61
+
62
+ The `spikeglx.Reader` class provides array-like access to SpikeGLX binary files, whether compressed with mtscomp or not.
63
+ Slicing the reader returns an array of samples by channels; transpose it to get the (channels, samples) layout expected by the `ibldsp` functions.
64
+
65
+ ```python
66
+ from pathlib import Path
67
+ import spikeglx
68
+
69
+ import ibldsp.voltage
70
+
71
+ file_spikeglx = Path('/datadisk/Data/neuropixel/human/Pt01.imec0.ap.bin')
72
+ sr = spikeglx.Reader(file_spikeglx)
73
+
74
+ # reads in 10,000 samples of data (~333 ms at the 30 kHz AP sampling rate)
75
+ raw = sr[10_300_000:10_310_000, :sr.nc - sr.nsync].T
76
+ destripe = ibldsp.voltage.destripe(raw, fs=sr.fs, neuropixel_version=1)
77
+
78
+ # display with matplotlib backend
79
+ import ibldsp.plots
80
+ ibldsp.plots.voltageshow(raw, fs=sr.fs, title='raw')
81
+ ibldsp.plots.voltageshow(destripe, fs=sr.fs, title='destripe')
82
+
83
+ # display with QT backend
84
+ from viewephys.gui import viewephys
85
+ eqc = {}
86
+ eqc['raw'] = viewephys(raw, fs=sr.fs, title='raw')
87
+ eqc['destripe'] = viewephys(destripe, fs=sr.fs, title='destripe')
88
+ ```
89
+
90
+ #### Destripe a binary file
44
91
  This relies on a fast fourier transform external library: `pip install pyfftw`.
45
92
 
46
93
  Minimal working example to destripe a neuropixel binary file.
@@ -71,22 +118,4 @@ The following describes the methods implemented in this repository.
71
118
  https://doi.org/10.6084/m9.figshare.19705522
72
119
 
73
120
  ## Contribution
74
- Contribution checklist:
75
- - run tests
76
- - ruff format
77
- - PR to main
78
-
79
-
80
- Pypi Release checklist:
81
- - Edit the version number in `setup.py`
82
- - add release notes in `release_notes.md`
83
-
84
-
85
- ```shell
86
- ruff format
87
- tag=X.Y.Z
88
- git tag -a $tag
89
- git push origin $tag
90
- ```
91
-
92
- Create new release with tag X.Y.Z (will automatically publish to PyPI)
121
+ Please see our [contribution guidelines](CONTRIBUTING.md) for details on how to contribute to this project.
@@ -1,6 +1,6 @@
1
1
  neuropixel.py,sha256=P7sIBAtGIqKReK7OqMBqdwPaTeHjhHMyfyBRL_AvuQY,37987
2
- spikeglx.py,sha256=OPvkZdnMguBAuajA8XjJko9N6-UDo4GdKeYp3DhIClU,40865
3
- ibl_neuropixel-1.8.1.dist-info/licenses/LICENSE,sha256=JJCjBeS78UPiX7TZpE-FnMjNNpCyrFb4s8VDGG2wD10,1087
2
+ spikeglx.py,sha256=4TPXnFGhJahClxr4fA9HwTeiiHBQS9ZEfkWl6t20q2s,41068
3
+ ibl_neuropixel-1.9.1.dist-info/licenses/LICENSE,sha256=JJCjBeS78UPiX7TZpE-FnMjNNpCyrFb4s8VDGG2wD10,1087
4
4
  ibldsp/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  ibldsp/cadzow.py,sha256=pAtxDxBwoNhoxFNc2R5WLwUrmKsq4rQuaglRNgW2Lj8,7251
6
6
  ibldsp/cuda_tools.py,sha256=6LpVhYOCuOXEEg8kJ3aOCE4hzA1Yq1dojsbbBQmQCF4,2387
@@ -8,26 +8,28 @@ ibldsp/destripe_gpu.py,sha256=I5jzFocpsYw36kMMd533YThbrQaZix5e1sHqsUjHvO4,2824
8
8
  ibldsp/filter_gpu.py,sha256=DPrPBLRXeCh_6BcJWJnPFaxS9Q6kX4nPENZg-c2q5rc,5789
9
9
  ibldsp/fourier.py,sha256=RI58nhs4ZZXx1M6EtuhA0vbtkNaBRS2QNv7tPkVomao,10608
10
10
  ibldsp/icsd.py,sha256=y9NWOXBB4Nfb5A1fQMKlOu0PdVDVOZ39v2pwk2zzB84,44923
11
- ibldsp/plots.py,sha256=lgSqnGXMKnJ7fAa3ru30oeIZIAAd-Fz9Cgzx-p2w04k,2064
11
+ ibldsp/plots.py,sha256=XmYC4yca_seZYNEmC5hE5wBiJAl_fi_KU00DbNcM6jI,4577
12
12
  ibldsp/raw_metrics.py,sha256=Ie4b7unuFc-XiFc9-tpTsUkph29G-20NvM7iJ25jAPI,5198
13
13
  ibldsp/smooth.py,sha256=m_mByXHG_JyFErnYsZ27gXjcqpfwCEuWa6eOb9eFuyg,8033
14
14
  ibldsp/spiketrains.py,sha256=lYP1PD4l6T-4KhFu8ZXlbnUUnEQLOriGxN1szacolPY,6878
15
- ibldsp/utils.py,sha256=p3yvxXdfW36PNmN8qZQ237ZlkPvNvrJ3qCWkDFuy5Q8,13398
16
- ibldsp/voltage.py,sha256=ID1FDpA9s4qhr6GBy-1SqQlsuU7YjvX5-T25bUaGpDI,39815
17
- ibldsp/waveform_extraction.py,sha256=IZWMDmsDnC7FdwvhYKzCwriq0rbWw4esGZYKPMQWkhY,26544
15
+ ibldsp/utils.py,sha256=uvEPw1adkppiGXuYBkM_fuuX5owq7LRmA6vm438rrYc,17959
16
+ ibldsp/voltage.py,sha256=Iias93xAvxfRDrzgZT-aw-w4xfWtykx2zWLhI2CxzVI,45408
17
+ ibldsp/waveform_extraction.py,sha256=yKrldgHqpwQ_Dq6xdoSCceKkfrL9FUXnpwKJUM3R41M,26570
18
18
  ibldsp/waveforms.py,sha256=5OBLYuM902WS_9WGDDmiTh4BpYWGe7-bQYTMxc2mYII,35166
19
19
  neurowaveforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  neurowaveforms/model.py,sha256=YOPWMMNNS_Op5TyK4Br1i9_Ni41jLSqHie5r1vb5VjY,6729
21
21
  tests/integration/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  tests/integration/csd_experiments.py,sha256=bddMl2SCzeEM_QnBrZGypUYMKxFVDc6qderyUyX-iew,3158
23
- tests/integration/test_destripe.py,sha256=6OwqWWz3hJSPGAeEGDcJJkG4bZMnNeaU80AlH7vyrno,6170
23
+ tests/integration/test_destripe.py,sha256=ZV7gasuFib2AbVb63WczgQvc13PbIX9t4pQgamBMgRY,6161
24
24
  tests/unit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  tests/unit/test_ephys_np2.py,sha256=1wsgS_C5W8tUO_qDyORBRUKGsrB0Gq3wMLAjJcjrNZ4,15599
26
- tests/unit/test_ibldsp.py,sha256=bekaSW02sZxdn4xML_7WRssUOdVpWPTXmj3IPxyLlWA,24872
27
26
  tests/unit/test_neuropixel.py,sha256=ZFKrvTYaYgK5WgOfoHa5x9BNUpRomACPiIm6Kr-A3gw,2511
28
- tests/unit/test_spikeglx.py,sha256=p0cATRg7xK4WsS_PP_fng_qMATSx6FbRz9kBe8bgIUk,33130
27
+ tests/unit/test_plots.py,sha256=PhCxrEN1Zd1jTgmiwd16_dEghcI7kwmHT3AQmAPpzkA,850
28
+ tests/unit/test_spikeglx.py,sha256=9PrSOPGrYAAQEeJPAOmqc3Rhgia6ftv-zihVWXglhqw,34388
29
+ tests/unit/test_utils.py,sha256=37XQDUqcABYrrsdX17kX54H4e5jld7GOn1ISxtgoa5U,21859
30
+ tests/unit/test_voltage.py,sha256=Nr6KqNGn2yOGPJYnvVzxdM5IiEHvK2FicDR_7fzvTHQ,6228
29
31
  tests/unit/test_waveforms.py,sha256=VnFvUi1pteROwwbC5Ebp2lqSxF3a8a7eXHpD8OUeuTg,16237
30
- ibl_neuropixel-1.8.1.dist-info/METADATA,sha256=nctIH01vFPOU42nUCLmbqWG6ByMZQoZGBqQgCgswQ0c,2505
31
- ibl_neuropixel-1.8.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
32
- ibl_neuropixel-1.8.1.dist-info/top_level.txt,sha256=WtVcEUptnwU6BT72cgGmrWYFGM9d9qCEqe3LwR9FIw4,48
33
- ibl_neuropixel-1.8.1.dist-info/RECORD,,
32
+ ibl_neuropixel-1.9.1.dist-info/METADATA,sha256=RaS1xeg11qze-sAmPqKVqdKOgUcSk-5l01HNrkX9kIw,3746
33
+ ibl_neuropixel-1.9.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
+ ibl_neuropixel-1.9.1.dist-info/top_level.txt,sha256=WtVcEUptnwU6BT72cgGmrWYFGM9d9qCEqe3LwR9FIw4,48
35
+ ibl_neuropixel-1.9.1.dist-info/RECORD,,
ibldsp/plots.py CHANGED
@@ -1,9 +1,17 @@
1
1
  import numpy as np
2
2
  import matplotlib.pyplot as plt
3
3
 
4
+ AP_RANGE_UV = 75
5
+ LF_RANGE_UV = 250
6
+
4
7
 
5
8
  def show_channels_labels(
6
- raw, fs, channel_labels, xfeats, similarity_threshold, psd_hf_threshold=0.02
9
+ raw,
10
+ fs,
11
+ channel_labels,
12
+ xfeats,
13
+ similarity_threshold=(-0.5, 1),
14
+ psd_hf_threshold=0.02,
7
15
  ):
8
16
  """
9
17
  Shows the features side by side a snippet of raw data
@@ -13,14 +21,20 @@ def show_channels_labels(
13
21
  nc, ns = raw.shape
14
22
  raw = raw - np.mean(raw, axis=-1)[:, np.newaxis] # removes DC offset
15
23
  ns_plot = np.minimum(ns, 3000)
16
- vaxis_uv = 250 if fs < 2600 else 75
17
24
  fig, ax = plt.subplots(
18
25
  1, 5, figsize=(18, 6), gridspec_kw={"width_ratios": [1, 1, 1, 8, 0.2]}
19
26
  )
20
27
  ax[0].plot(xfeats["xcor_hf"], np.arange(nc))
21
- ax[0].plot(
28
+ ax[0].plot( # plot channel below the similarity threshold as dead in black
22
29
  xfeats["xcor_hf"][(iko := channel_labels == 1)], np.arange(nc)[iko], "k*"
23
30
  )
31
+ ax[0].plot( # plot the values above the similarity threshold as noisy in red
32
+ xfeats["xcor_hf"][
33
+ (iko := np.where(xfeats["xcor_hf"] > similarity_threshold[1]))
34
+ ],
35
+ np.arange(nc)[iko],
36
+ "r*",
37
+ )
24
38
  ax[0].plot(similarity_threshold[0] * np.ones(2), [0, nc], "k--")
25
39
  ax[0].plot(similarity_threshold[1] * np.ones(2), [0, nc], "r--")
26
40
  ax[0].set(
@@ -30,7 +44,11 @@ def show_channels_labels(
30
44
  title="a) dead channel",
31
45
  )
32
46
  ax[1].plot(xfeats["psd_hf"], np.arange(nc))
33
- ax[1].plot(xfeats["psd_hf"][(iko := channel_labels == 2)], np.arange(nc)[iko], "r*")
47
+ ax[1].plot(
48
+ xfeats["psd_hf"][(iko := xfeats["psd_hf"] > psd_hf_threshold)],
49
+ np.arange(nc)[iko],
50
+ "r*",
51
+ )
34
52
  ax[1].plot(psd_hf_threshold * np.array([1, 1]), [0, nc], "r--")
35
53
  ax[1].set(yticklabels=[], xlabel="PSD", ylim=[0, nc], title="b) noisy channel")
36
54
  ax[1].sharey(ax[0])
@@ -41,18 +59,77 @@ def show_channels_labels(
41
59
  ax[2].plot([-0.75, -0.75], [0, nc], "y--")
42
60
  ax[2].set(yticklabels=[], xlabel="LF coherence", ylim=[0, nc], title="c) outside")
43
61
  ax[2].sharey(ax[0])
44
- im = ax[3].imshow(
45
- raw[:, :ns_plot] * 1e6,
46
- origin="lower",
47
- cmap="PuOr",
48
- aspect="auto",
49
- vmin=-vaxis_uv,
50
- vmax=vaxis_uv,
51
- extent=[0, ns_plot / fs * 1e3, 0, nc],
52
- )
53
- ax[3].set(yticklabels=[], title="d) Raw data", xlabel="time (ms)", ylim=[0, nc])
54
- ax[3].grid(False)
62
+ voltageshow(raw[:, :ns_plot], fs, ax=ax[3], cax=ax[4])
55
63
  ax[3].sharey(ax[0])
56
- plt.colorbar(im, cax=ax[4], shrink=0.8).ax.set(ylabel="(uV)")
57
64
  fig.tight_layout()
58
65
  return fig, ax
66
+
67
+
68
+ def voltageshow(
69
+ raw,
70
+ fs,
71
+ cmap="PuOr",
72
+ ax=None,
73
+ cax=None,
74
+ cbar_label="Voltage (uV)",
75
+ scaling=1e6,
76
+ vrange=None,
77
+ **axis_kwargs,
78
+ ):
79
+ """
80
+ Visualizes electrophysiological voltage data as a heatmap.
81
+
82
+ This function displays raw voltage data as a color-coded image with appropriate
83
+ scaling based on the sampling frequency. It automatically selects voltage range
84
+ based on whether the data is low-frequency (LF, fs < 2600 Hz) or action potential (AP) data.
85
+
86
+ Parameters
87
+ ----------
88
+ raw : numpy.ndarray
89
+ Raw voltage data array with shape (channels, samples), in Volts
90
+ fs : float
91
+ Sampling frequency in Hz, used to determine time axis scaling and voltage range.
92
+ cmap : str, optional
93
+ Matplotlib colormap name for the heatmap. Default is 'PuOr'.
94
+ ax : matplotlib.axes.Axes, optional
95
+ Axes object to plot on. If None, a new figure and axes are created.
96
+ cax : matplotlib.axes.Axes, optional
97
+ Axes object for the colorbar. If None and ax is None, a new colorbar axes is created.
98
+ cbar_label : str, optional
99
+ Label for the colorbar. Default is 'Voltage (uV)'.
100
+ vrange: float, optional
101
+ Voltage range for the colorbar. Defaults to +/- 75 uV for AP and +/- 250 uV for LF.
102
+ scaling: float, optional
103
+ Unit transform: default is 1e6: we expect Volts but plot uV.
104
+ **axis_kwargs: optional
105
+ Additional keyword arguments for the axis properties, fed to the ax.set() method.
106
+ Returns
107
+ -------
108
+ matplotlib.image.AxesImage
109
+ The image object created by imshow, which can be used for further customization.
110
+ """
111
+ if ax is None:
112
+ fig, axs = plt.subplots(1, 2, gridspec_kw={"width_ratios": [1, 0.05]})
113
+ ax, cax = axs
114
+ nc, ns = raw.shape
115
+ default_vrange = LF_RANGE_UV if fs < 2600 else AP_RANGE_UV
116
+ vrange = vrange if vrange is not None else default_vrange
117
+ im = ax.imshow(
118
+ raw * scaling,
119
+ origin="lower",
120
+ cmap=cmap,
121
+ aspect="auto",
122
+ vmin=-vrange,
123
+ vmax=vrange,
124
+ extent=[0, ns / fs, 0, nc],
125
+ )
126
+ # set the axis properties: we use defaults values that can be overridden by user-provided ones
127
+ axis_kwargs = (
128
+ dict(ylim=[0, nc], xlabel="Time (s)", ylabel="Depth (μm)") | axis_kwargs
129
+ )
130
+ ax.set(**axis_kwargs)
131
+ ax.grid(False)
132
+ if cax is not None:
133
+ plt.colorbar(im, cax=cax, shrink=0.8).ax.set(ylabel=cbar_label)
134
+
135
+ return im
ibldsp/utils.py CHANGED
@@ -89,7 +89,7 @@ def parabolic_max(x):
89
89
  # for 2D arrays, operate along the last dimension
90
90
  ns = x.shape[-1]
91
91
  axis = -1
92
- imax = np.argmax(x, axis=axis)
92
+ imax = np.nanargmax(x, axis=axis)
93
93
 
94
94
  if x.ndim == 1:
95
95
  v010 = x[np.maximum(np.minimum(imax + np.array([-1, 0, 1]), ns - 1), 0)]
@@ -268,12 +268,64 @@ def make_channel_index(geom, radius=200.0, pad_val=None):
268
268
 
269
269
  class WindowGenerator(object):
270
270
  """
271
- `wg = WindowGenerator(ns, nswin, overlap)`
271
+ A utility class for generating sliding windows for signal processing applications.
272
272
 
273
- Provide sliding windows indices generator for signal processing applications.
274
- For straightforward spectrogram / periodogram implementation, prefer scipy methods !
273
+ WindowGenerator provides various methods to iterate through windows of a signal
274
+ with configurable window size and overlap. It's particularly useful for operations
275
+ like spectrograms, filtering, or any processing that requires windowed analysis.
275
276
 
276
- Example of implementations in test_dsp.py.
277
+ Parameters
278
+ ----------
279
+ ns : int
280
+ Total number of samples in the signal to be windowed.
281
+ nswin : int
282
+ Number of samples in each window.
283
+ overlap : int
284
+ Number of samples that overlap between consecutive windows.
285
+
286
+ Attributes
287
+ ----------
288
+ ns : int
289
+ Total number of samples in the signal.
290
+ nswin : int
291
+ Number of samples in each window.
292
+ overlap : int
293
+ Number of samples that overlap between consecutive windows.
294
+ nwin : int
295
+ Total number of windows.
296
+ iw : int or None
297
+ Current window index during iteration.
298
+
299
+ Notes
300
+ -----
301
+ For straightforward spectrogram or periodogram implementation,
302
+ scipy methods are recommended over this class.
303
+
304
+ Examples
305
+ --------
306
+ # straight windowing without overlap
307
+ >>> wg = WindowGenerator(ns=1000, nswin=111, overlap=0)
308
+ >>> signal = np.random.randn(1000)
309
+ >>> for window_slice in wg.slice:
310
+ ... window_data = signal[window_slice]
311
+ ... # Process window_data
312
+
313
+ # windowing with overlap (ie. buffers for apodization)
314
+ >>> for win_slice, valid_slice, win_valid_slice in wg.slices_valid:
315
+ ... window = signal[win_slice]
316
+ ... # Process window
317
+ ... processed = some_function_with_edge_effect(window)
318
+ ... # Only use the valid portion for reconstruction
319
+ ... recons[valid_slice] = processed[win_valid_slice]
320
+
321
+ # splicing add a fade-in / fade-out in the overlap so that reconstruction has unit amplitude
322
+ >>> recons = np.zeros_like(signal)
323
+ >>> for win_slice, amplitude in wg.splice:
324
+ ... window = signal[win_slice]
325
+ ... # Process window
326
+ ... processed = some_function(window)
327
+ ... # The processed windows is weighted with the amplitude and added to the reconstructed signal
328
+ ... recons[win_slice] = recons[win_slice] + processed * amplitude
277
329
  """
278
330
 
279
331
  def __init__(self, ns, nswin, overlap):
@@ -289,14 +341,35 @@ class WindowGenerator(object):
289
341
  self.iw = None
290
342
 
291
343
  @property
292
- def firstlast_splicing(self):
344
+ def splice(self):
293
345
  """
294
- Generator that yields the indices as well as an amplitude function that can be used
295
- to splice the windows together.
296
- In the overlap, the amplitude function gradually transitions the amplitude from one window
297
- to the next. The amplitudes always sum to one (ie. windows are symmetrical)
346
+ Generator that yields slices and amplitude arrays for windowed signal processing with splicing.
347
+
348
+ This property provides a convenient way to iterate through all windows with their
349
+ corresponding amplitude arrays for proper signal reconstruction. The amplitude arrays
350
+ contain tapering values (from a Hann window) at the overlapping regions to ensure
351
+ unit amplitude of all samples of the original signal.
352
+
353
+ Yields
354
+ ------
355
+ tuple
356
+ A tuple containing:
357
+ - slice: A Python slice object representing the current window
358
+ - amp: A numpy array containing amplitude values for proper splicing/tapering
359
+ at overlap regions
360
+
361
+ Notes
362
+ -----
363
+ This is particularly useful for overlap-add methods where windows need to be
364
+ properly weighted before being combined in the reconstruction process.
365
+ """
366
+ for first, last, amp in self.firstlast_splicing:
367
+ yield slice(first, last), amp
298
368
 
299
- :return: tuple of (first_index, last_index, amplitude_vector]
369
+ @property
370
+ def firstlast_splicing(self):
371
+ """
372
+ cf. self.splice
300
373
  """
301
374
  w = scipy.signal.windows.hann((self.overlap + 1) * 2 + 1, sym=True)[
302
375
  1 : self.overlap + 1
@@ -310,7 +383,7 @@ class WindowGenerator(object):
310
383
  yield (first, last, amp)
311
384
 
312
385
  @property
313
- def firstlast_valid(self):
386
+ def firstlast_valid(self, discard_edges=False):
314
387
  """
315
388
  Generator that yields a tuple of first, last, first_valid, last_valid index of windows
316
389
  The valid indices span up to half of the overlap
@@ -318,12 +391,18 @@ class WindowGenerator(object):
318
391
  """
319
392
  assert self.overlap % 2 == 0, "Overlap must be even"
320
393
  for first, last in self.firstlast:
321
- first_valid = 0 if first == 0 else first + self.overlap // 2
322
- last_valid = last if last == self.ns else last - self.overlap // 2
394
+ first_valid = (
395
+ 0 if first == 0 and not discard_edges else first + self.overlap // 2
396
+ )
397
+ last_valid = (
398
+ last
399
+ if last == self.ns and not discard_edges
400
+ else last - self.overlap // 2
401
+ )
323
402
  yield (first, last, first_valid, last_valid)
324
403
 
325
404
  @property
326
- def firstlast(self, return_valid=False):
405
+ def firstlast(self):
327
406
  """
328
407
  Generator that yields first and last index of windows
329
408
 
@@ -343,13 +422,51 @@ class WindowGenerator(object):
343
422
  @property
344
423
  def slice(self):
345
424
  """
346
- Generator that yields slices of windows
347
-
348
- :return: a slice of the window
425
+ Generator that yields slice objects for each window in the signal.
426
+
427
+ This property provides a convenient way to iterate through all windows
428
+ defined by the WindowGenerator parameters. Each yielded slice can be
429
+ used directly to index into the original signal array.
430
+
431
+ Yields
432
+ ------
433
+ slice
434
+ A Python slice object representing the current window, defined by
435
+ its first and last indices. The slice can be used to extract the
436
+ corresponding window from the original signal.
349
437
  """
350
438
  for first, last in self.firstlast:
351
439
  yield slice(first, last)
352
440
 
441
+ @property
442
+ def slices_valid(self):
443
+ """
444
+ Generator that yields slices for windowed signal processing with valid regions.
445
+
446
+ This method generates tuples of slice objects that can be used to extract windows
447
+ from a signal and identify the valid portion of each window (trimmed by half the overlap on its interior edges).
448
+ It's particularly useful for reconstruction operations where overlapping regions
449
+ need special handling.
450
+
451
+ Yields
452
+ ------
453
+ tuple
454
+ A tuple containing three slice objects:
455
+ - slice(first, last): The full window slice
456
+ - slice(first_valid, last_valid): The valid portion of the signal in absolute indices
457
+ - slice_window_valid: The valid portion relative to the window (for use within the window)
458
+
459
+ Notes
460
+ -----
461
+ This generator relies on the firstlast_valid property which provides the
462
+ indices for both the full windows and their valid regions.
463
+ """
464
+ for first, last, first_valid, last_valid in self.firstlast_valid:
465
+ slice_window_valid = slice(
466
+ first_valid - first, None if (lv := -(last - last_valid)) == 0 else lv
467
+ )
468
+ yield slice(first, last), slice(first_valid, last_valid), slice_window_valid
469
+
353
470
  def slice_array(self, sig, axis=-1):
354
471
  """
355
472
  Provided an array or sliceable object, generator that yields
ibldsp/voltage.py CHANGED
@@ -3,6 +3,8 @@ Module to work with raw voltage traces. Spike sorting pre-processing functions.
3
3
  """
4
4
 
5
5
  import inspect
6
+ import joblib
7
+ import tqdm
6
8
  from pathlib import Path
7
9
 
8
10
  import numpy as np
@@ -217,6 +219,7 @@ def kfilt(
217
219
  xf, gain = agc(x, wl=lagc, si=1.0, gpu=gpu)
218
220
  if ntr_pad > 0:
219
221
  # pad the array with a mirrored version of itself and apply a cosine taper
222
+ ntr_pad = np.min([ntr_pad, xf.shape[0]])
220
223
  xf = gp.r_[gp.flipud(xf[:ntr_pad]), xf, gp.flipud(xf[-ntr_pad:])]
221
224
  if ntr_tap > 0:
222
225
  taper = fourier.fcn_cosine([0, ntr_tap], gpu=gpu)(gp.arange(nxp)) # taper up
@@ -266,6 +269,120 @@ def saturation(
266
269
  return saturation, mute
267
270
 
268
271
 
272
+ def saturation_samples_to_intervals(
273
+ _saturation: np.ndarray, output_file: Path = None
274
+ ) -> pd.DataFrame:
275
+ """
276
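+ Convert a boolean per-sample saturation array into a table of saturation intervals.
277
+ :param _saturation: np.ndarray: Boolean array with saturation samples set as True; the first sample must not be saturated
278
+ :return: pd.DataFrame with columns `start_sample` and `stop_sample` (also written as parquet to `output_file` if provided)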
279
+ """
280
+ assert not _saturation[0]
281
+ ind, pol = ibldsp.utils.fronts(_saturation.astype(np.int8))
282
+ # if the last sample is positive, make sure the interval is closed by providing an even number of events
283
+ if len(pol) > 0 and pol[-1] == 1:
284
+ pol = np.r_[pol, -1]
285
+ ind = np.r_[ind, _saturation.shape[0] - 1]
286
+ df_saturation = pd.DataFrame(
287
+ np.c_[ind[::2], ind[1::2]], columns=["start_sample", "stop_sample"]
288
+ )
289
+ if output_file is not None:
290
+ df_saturation.to_parquet(output_file)
291
+ return df_saturation
292
+
293
+
294
+ def saturation_cbin(
295
+ bin_file_path: Path,
296
+ file_saturation: Path = None,
297
+ max_voltage=None,
298
+ n_jobs=4,
299
+ v_per_sec=1e-8,
300
+ proportion=0.2,
301
+ mute_window_samples=7,
302
+ ) -> Path:
303
+ """
304
+ Detect saturation in a compressed binary (cbin) electrophysiology file and save the results.
305
+
306
+ This function processes a SpikeGLX binary file in chunks to identify saturated samples
307
+ and saves the results as a memory-mapped boolean array. Processing is parallelized
308
+ for improved performance.
309
+
310
+ Parameters
311
+ ----------
312
+ bin_file_path : Path | spikeglx.Reader
313
+ Path to the SpikeGLX binary file to be processed or spikeglx.Reader object
314
+ file_saturation : Path, optional
315
+ Path where the saturation data will be saved. If None, defaults to
316
+ "_iblqc_ephysSaturation.samples.npy" in the same directory as the input file
317
+ max_voltage : float or numpy.ndarray, optional
318
+ one-sided maximum voltage range (V), if not provided will use the spikeglx metadata
319
+ n_jobs : int, optional
320
+ Number of parallel jobs to use for processing, defaults to 4
321
+ v_per_sec : float, optional
322
+ Maximum derivative of the voltage in V/s (or units/s), defaults to 1e-8
323
+ proportion : float, optional
324
+ Threshold proportion (0-1) of channels that must be above threshold to consider
325
+ a sample as saturated, defaults to 0.2
326
+ mute_window_samples : int, optional
327
+ Number of samples for the cosine taper applied to the saturation, defaults to 7
328
+
329
+ Returns
330
+ -------
331
+ Path
332
+ Path to the parquet file of saturation intervals (the per-sample boolean array is written to `file_saturation`)
333
+ """
334
+ if isinstance(bin_file_path, spikeglx.Reader):
335
+ sr = bin_file_path
336
+ bin_file_path = sr.file_bin
337
+ else:
338
+ sr = spikeglx.Reader(bin_file_path)
339
+ file_saturation = (
340
+ file_saturation
341
+ if file_saturation is not None
342
+ else bin_file_path.parent.joinpath("_iblqc_ephysSaturation.samples.npy")
343
+ )
344
+ max_voltage = max_voltage if max_voltage is not None else sr.range_volts[:-1]
345
+ # Create a memory-mapped array
346
+ _saturation = np.lib.format.open_memmap(
347
+ file_saturation, dtype=bool, mode="w+", shape=(sr.ns,)
348
+ )
349
+ _saturation[:] = False # Initialize all values to False
350
+ _saturation.flush() # Make sure to flush to disk
351
+
352
+ wg = ibldsp.utils.WindowGenerator(ns=sr.ns, nswin=2**17, overlap=16)
353
+
354
+ # we can parallelize this as there is no conflict on output
355
+ def _saturation_slice(slice_win, slice_valid, slice_relative_valid):
356
+ sr = spikeglx.Reader(bin_file_path)
357
+ data = sr[slice_win, : sr.nc - sr.nsync].T
358
+ satwin, _ = ibldsp.voltage.saturation(
359
+ data,
360
+ max_voltage=max_voltage,
361
+ fs=sr.fs,
362
+ v_per_sec=v_per_sec,
363
+ proportion=proportion,
364
+ mute_window_samples=mute_window_samples,
365
+ )
366
+ _saturation[slice_valid] = satwin[slice_relative_valid]
367
+ _saturation.flush()
368
+ # running joblib with return_as="generator" lets tqdm monitor progress as jobs complete
369
+
370
+ jobs = [
371
+ joblib.delayed(_saturation_slice)(slw, slv, slrv)
372
+ for (slw, slv, slrv) in wg.slices_valid
373
+ ]
374
+ list(
375
+ tqdm.tqdm(
376
+ joblib.Parallel(return_as="generator", n_jobs=n_jobs)(jobs), total=wg.nwin
377
+ )
378
+ )
379
+
380
+ _ = saturation_samples_to_intervals(
381
+ _saturation, output_file=file_saturation.with_suffix(".pqt")
382
+ )
383
+ return file_saturation.with_suffix(".pqt")
384
+
385
+
269
386
  def interpolate_bad_channels(
270
387
  data, channel_labels=None, x=None, y=None, p=1.3, kriging_distance_um=20, gpu=False
271
388
  ):
@@ -655,6 +772,9 @@ def decompress_destripe_cbin(
655
772
  saturation_data = np.load(file_saturation)
656
773
  assert rms_data.shape[0] == time_data.shape[0] * ncv
657
774
  rms_data = rms_data.reshape(time_data.shape[0], ncv)
775
+ # Save the rms data using the original channel index
776
+ unsort = np.argsort(sr.raw_channel_order)[: -sr.nsync]
777
+ rms_data = rms_data[:, unsort]
658
778
  output_qc_path = (
659
779
  output_file.parent if output_qc_path is None else output_qc_path
660
780
  )
@@ -662,10 +782,11 @@ def decompress_destripe_cbin(
662
782
  np.save(
663
783
  output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data
664
784
  )
665
- np.save(
666
- output_qc_path.joinpath("_iblqc_ephysSaturation.samples.npy"),
785
+ saturation_samples_to_intervals(
667
786
  saturation_data,
787
+ output_file=output_qc_path.joinpath("_iblqc_ephysSaturation.samples.pqt"),
668
788
  )
789
+ file_saturation.unlink()
669
790
 
670
791
 
671
792
  def detect_bad_channels(
@@ -781,7 +902,7 @@ def detect_bad_channels(
781
902
  window_size = 25 # Choose based on desired smoothing (e.g., 25 samples)
782
903
  kernel = np.ones(window_size) / window_size
783
904
  # Apply convolution
784
- signal_filtered = np.convolve(signal_noisy, kernel, mode='same')
905
+ signal_filtered = np.convolve(signal_noisy, kernel, mode="same")
785
906
 
786
907
  diff_x = np.diff(signal_filtered)
787
908
  indx = np.where(diff_x < -0.02)[0] # hardcoded threshold
@@ -934,16 +1055,39 @@ def stack(data, word, fcn_agg=np.nanmean, header=None):
934
1055
 
935
1056
  def current_source_density(lfp, h, n=2, method="diff", sigma=1 / 3):
936
1057
  """
937
- Compute the current source density (CSD) of a given LFP signal recorded on neuropixel 1 or 2
938
- :param data: LFP signal (n_channels, n_samples)
939
- :param h: trace header dictionary
940
- :param n: the n derivative
941
- :param method: diff (straight double difference) or kernel CSD (needs the KCSD python package)
942
- :param sigma: conductivity, defaults to 1/3 S.m-1
943
- :return:
1058
+ Compute the current source density (CSD) of a given LFP signal recorded on Neuropixel probes.
1059
+
1060
+ The CSD estimates the location of current sources and sinks in neural tissue based on
1061
+ the spatial distribution of local field potentials (LFPs). This implementation supports
1062
+ both the standard double-derivative method and kernel CSD method.
1063
+
1064
+ The CSD is computed for each column of the Neuropixel probe layout separately.
1065
+
1066
+ Parameters
1067
+ ----------
1068
+ lfp : numpy.ndarray
1069
+ LFP signal array with shape (n_channels, n_samples)
1070
+ h : dict
1071
+ Trace header dictionary containing probe geometry information with keys:
1072
+ 'x', 'y' for electrode coordinates in micrometers, 'col' for column indices, and 'row' for row indices
1073
+ n : int, optional
1074
+ Order of the derivative for the 'diff' method, defaults to 2
1075
+ method : str, optional
1076
+ Method to compute CSD:
1077
+ - 'diff': standard finite difference method (default)
1078
+ - 'kcsd': kernel CSD method (requires the KCSD Python package)
1079
+ sigma : float, optional
1080
+ Tissue conductivity in Siemens per meter, defaults to 1/3 S.m-1
1081
+
1082
+ Returns
1083
+ -------
1084
+ numpy.ndarray
1085
+ Current source density with the same shape as the input LFP array.
1086
+ Positive values indicate current sources, negative values indicate sinks.
1087
+ Units are in A.m-3 (amperes per cubic meter).
944
1088
  """
945
1089
  csd = np.zeros(lfp.shape, dtype=np.float64) * np.nan
946
- xy = h["x"] + 1j * h["y"]
1090
+ xy = (h["x"] + 1j * h["y"]) / 1e6
947
1091
  for col in np.unique(h["col"]):
948
1092
  ind = np.where(h["col"] == col)[0]
949
1093
  isort = np.argsort(h["row"][ind])
@@ -990,7 +1134,6 @@ def _svd_denoise(datr, rank):
990
1134
 
991
1135
  def svd_denoise_npx(datr, rank=None, collection=None):
992
1136
  """
993
-
994
1137
  :param datr: [nc, ns]
995
1138
  :param rank:
996
1139
  :param collection:
@@ -280,6 +280,7 @@ def extract_wfs_cbin(
280
280
  chunksize_samples=int(30_000),
281
281
  reader_kwargs=None,
282
282
  n_jobs=None,
283
+ wfs_dtype=np.float32,
283
284
  preprocess_steps=None,
284
285
  seed=None,
285
286
  scratch_dir=None,
spikeglx.py CHANGED
@@ -144,8 +144,8 @@ class Reader:
144
144
  sglx_file = str(self.file_bin)
145
145
  if self.is_mtscomp:
146
146
  self._raw = mtscomp.Reader()
147
- ch_file = self.ch_file or _get_companion_file(sglx_file, ".ch")
148
- self._raw.open(self.file_bin, ch_file)
147
+ self.ch_file = self._parse_ch_file()
148
+ self._raw.open(self.file_bin, self.ch_file)
149
149
  if self._raw.shape != (self.ns, self.nc):
150
150
  ftsec = self._raw.shape[0] / self.fs
151
151
  if not self.ignore_warnings: # avoid the checks for streaming data
@@ -392,10 +392,8 @@ class Reader:
392
392
  if "out" not in kwargs:
393
393
  kwargs["out"] = self.file_bin.with_suffix(".bin")
394
394
  assert self.is_mtscomp
395
- if file_ch is None:
396
- file_ch = self.file_bin.with_suffix(".ch")
397
-
398
- r = mtscomp.decompress(self.file_bin, file_ch, **kwargs)
395
+ ch_file = self._parse_ch_file(file_ch)
396
+ r = mtscomp.decompress(self.file_bin, ch_file, **kwargs)
399
397
  r.close()
400
398
  if not keep_original:
401
399
  self.close()
@@ -411,14 +409,15 @@ class Reader:
411
409
  """
412
410
  if file_meta is None:
413
411
  file_meta = Path(self.file_bin).with_suffix(".meta")
414
-
412
+ file_ch = file_ch if file_ch is not None else self.ch_file
415
413
  if scratch_dir is None:
416
414
  bin_file = Path(self.file_bin).with_suffix(".bin")
417
415
  else:
418
416
  scratch_dir.mkdir(exist_ok=True, parents=True)
419
- bin_file = scratch_dir / Path(self.file_bin).with_suffix(".bin").name
420
- file_meta_scratch = scratch_dir / file_meta.name
421
- shutil.copy(self.file_meta_data, file_meta_scratch)
417
+ bin_file = (
418
+ Path(scratch_dir).joinpath(self.file_bin.name).with_suffix(".bin")
419
+ )
420
+ shutil.copy(self.file_meta_data, bin_file.parent / self.file_meta_data.name)
422
421
  if not bin_file.exists():
423
422
  t0 = time.time()
424
423
  _logger.info("File is compressed, decompressing to a temporary file...")
@@ -460,6 +459,12 @@ class Reader:
460
459
  log_func(f"SHA1 computed: {sc}")
461
460
  return sm == sc
462
461
 
462
+ def _parse_ch_file(self, ch_file=None):
463
+ ch_file = (
464
+ _get_companion_file(self.file_bin, ".ch") if ch_file is None else ch_file
465
+ )
466
+ return ch_file
467
+
463
468
 
464
469
  class OnlineReader(Reader):
465
470
  @property
@@ -995,7 +1000,7 @@ def _mock_spikeglx_file(
995
1000
  meta_file,
996
1001
  ns,
997
1002
  nc,
998
- sync_depth,
1003
+ sync_depth=16,
999
1004
  random=False,
1000
1005
  int2volts=0.6 / 32768,
1001
1006
  corrupt=False,
@@ -4,6 +4,7 @@ import logging
4
4
  import shutil
5
5
  import unittest
6
6
  from pathlib import Path
7
+ import pandas as pd
7
8
 
8
9
  import neuropixel
9
10
  import spikeglx
@@ -84,11 +85,11 @@ class TestEphysSpikeSortingMultiProcess(unittest.TestCase):
84
85
  shutil.rmtree(self.file_path.parent)
85
86
 
86
87
  def _assert_qc(self):
87
- sr = spikeglx.Reader(self.file_path)
88
- saturated = np.load(
89
- self.file_path.parent.joinpath("_iblqc_ephysSaturation.samples.npy")
88
+ df_saturated = pd.read_parquet(
89
+ self.file_path.parent.joinpath("_iblqc_ephysSaturation.samples.pqt")
90
90
  )
91
- self.assertEqual(sr.ns, saturated.size)
91
+ self.assertTrue(df_saturated.shape[1] == 2)
92
+
92
93
  self.assertTrue(
93
94
  self.file_path.parent.joinpath("_iblqc_ephysTimeRmsAP.rms.npy").exists()
94
95
  )
@@ -0,0 +1,30 @@
1
+ import unittest
2
+
3
+ import numpy as np
4
+
5
+ import ibldsp.plots
6
+ import ibldsp.voltage
7
+
8
+
9
+ class TestPlots(unittest.TestCase):
10
+ def test_voltage(self):
11
+ ibldsp.plots.voltageshow(
12
+ (np.random.rand(384, 2000) - 0.5) / 1e6 * 20, fs=30_000
13
+ )
14
+
15
+ def test_bad_channels(self):
16
+ np.random.seed(0)
17
+ raw = np.random.randn(384, 2000) / 1e6 * 15
18
+ raw += np.random.randn(1, 2000) / 1e6 * 2
19
+ raw[66] *= 2
20
+ raw[166] = 0
21
+ fs = 30_000
22
+ labels, features = ibldsp.voltage.detect_bad_channels(raw, fs)
23
+ ibldsp.plots.show_channels_labels(
24
+ raw=raw,
25
+ fs=30_000,
26
+ channel_labels=labels,
27
+ xfeats=features,
28
+ )
29
+ np.testing.assert_array_equal(np.argwhere(labels == 2), 66)
30
+ np.testing.assert_array_equal(np.argwhere(labels == 1), 166)
@@ -2,6 +2,7 @@ from pathlib import Path
2
2
  import shutil
3
3
  import tempfile
4
4
  import unittest
5
+ import uuid
5
6
 
6
7
  import numpy as np
7
8
  from iblutil.io import hashfile
@@ -243,6 +244,8 @@ class TestsSpikeGLX_compress(unittest.TestCase):
243
244
  with spikeglx.Reader(self.file_cbin, open=False) as sc:
244
245
  self.assertTrue(sc.is_mtscomp)
245
246
  compare_data(sr_ref, sc)
247
+ # here we make sure the chunks file has been registered as a property
248
+ self.assertEqual(sc.ch_file, self.file_cbin.with_suffix(".ch"))
246
249
 
247
250
  # test decompression in-place
248
251
  sc.decompress_file(keep_original=False, overwrite=True)
@@ -669,9 +672,29 @@ class TestsBasicReader(unittest.TestCase):
669
672
  Tests the basic usage where there is a flat binary and no metadata associated
670
673
  """
671
674
 
672
- def test_get_companion_file(self):
673
- import uuid
675
+ def test_integration_companion_files_cbin(self):
676
+ with tempfile.TemporaryDirectory() as td:
677
+ bin_file_orig = Path(td) / "toto.ap.bin"
678
+ meta_file = Path(td) / f"toto.ap.{str(uuid.uuid4())}.meta"
679
+ ch_file = Path(td) / f"toto.ap.{str(uuid.uuid4())}.ch"
680
+ spikeglx._mock_spikeglx_file(
681
+ bin_file_orig,
682
+ meta_file=Path(TEST_PATH).joinpath("sample3B_g0_t0.imec1.ap.meta"),
683
+ ns=90_000,
684
+ nc=385,
685
+ )
686
+ sr = spikeglx.Reader(bin_file_orig)
687
+ sr.compress_file(keep_original=False)
688
+ cbin_file = Path(td) / f"toto.ap.{str(uuid.uuid4())}.cbin"
689
+ shutil.move(bin_file_orig.with_suffix(".cbin"), cbin_file)
690
+ shutil.move(bin_file_orig.with_suffix(".ch"), ch_file)
691
+ shutil.move(bin_file_orig.with_suffix(".meta"), meta_file)
692
+ sr = spikeglx.Reader(cbin_file)
693
+ self.assertEqual(sr.file_bin, cbin_file)
694
+ self.assertEqual(sr.file_meta_data, meta_file)
695
+ self.assertEqual(sr.ch_file, ch_file)
674
696
 
697
+ def test_get_companion_file(self):
675
698
  with tempfile.TemporaryDirectory() as td:
676
699
  sglx_file = Path(td) / f"sample3A_g0_t0.imec.ap.{str(uuid.uuid4())}.bin"
677
700
  meta_file = Path(td) / f"sample3A_g0_t0.imec.ap.{str(uuid.uuid4())}.meta"
@@ -7,7 +7,6 @@ import spikeglx
7
7
  import ibldsp.fourier as fourier
8
8
  import ibldsp.utils as utils
9
9
  import ibldsp.voltage as voltage
10
- import ibldsp.cadzow as cadzow
11
10
  import ibldsp.smooth as smooth
12
11
  import ibldsp.spiketrains as spiketrains
13
12
  import ibldsp.raw_metrics as raw_metrics
@@ -70,8 +69,8 @@ class TestSyncTimestamps(unittest.TestCase):
70
69
 
71
70
  class TestParabolicMax(unittest.TestCase):
72
71
  # expected values
73
- maxi = np.array([np.nan, 0, 3.04166667, 3.04166667, 5, 5])
74
- ipeak = np.array([np.nan, 0, 5.166667, 2.166667, 0, 7])
72
+ maxi = np.array([0.0, 0.0, 3.04166667, 3.04166667, 5, 5])
73
+ ipeak = np.array([0.0, 0.0, 5.166667, 2.166667, 0, 7])
75
74
  # input
76
75
  x = np.array(
77
76
  [
@@ -368,6 +367,13 @@ class TestWindowGenerator(unittest.TestCase):
368
367
  for first, last, amp in wg.firstlast_splicing:
369
368
  sig_out[first:last] = sig_out[first:last] + amp * sig_in[first:last]
370
369
  np.testing.assert_allclose(sig_out, sig_in)
370
+ # now performs the same operation with the new interface
371
+ sig_in = np.random.randn(600)
372
+ sig_out = np.zeros_like(sig_in)
373
+ wg = utils.WindowGenerator(ns=600, nswin=100, overlap=20)
374
+ for slicewin, amp in wg.splice:
375
+ sig_out[slicewin] = sig_out[slicewin] + amp * sig_in[slicewin]
376
+ np.testing.assert_allclose(sig_out, sig_in)
371
377
 
372
378
  def test_firstlast_valid(self):
373
379
  sig_in = np.random.randn(600)
@@ -377,6 +383,15 @@ class TestWindowGenerator(unittest.TestCase):
377
383
  sig_out[first_valid:last_valid] = sig_in[first_valid:last_valid]
378
384
  np.testing.assert_array_equal(sig_out, sig_in)
379
385
 
386
+ def test_slices_valid(self):
387
+ sig_in = np.random.randn(600)
388
+ sig_out = np.zeros_like(sig_in)
389
+ wg = utils.WindowGenerator(ns=600, nswin=39, overlap=20)
390
+ for slice_win, slice_valid, slice_win_valid in wg.slices_valid:
391
+ win = sig_in[slice_win]
392
+ sig_out[slice_valid] = win[slice_win_valid]
393
+ np.testing.assert_array_equal(sig_out, sig_in)
394
+
380
395
  def test_tscale(self):
381
396
  wg = utils.WindowGenerator(ns=500, nswin=100, overlap=50)
382
397
  ts = wg.tscale(fs=1000)
@@ -430,103 +445,6 @@ class TestFrontDetection(unittest.TestCase):
430
445
  np.testing.assert_array_equal(utils.rises(a, step=3, analog=True), 283)
431
446
 
432
447
 
433
- class TestVoltage(unittest.TestCase):
434
- def test_destripe_parameters(self):
435
- import inspect
436
-
437
- _, _, spatial_fcn = voltage._get_destripe_parameters(
438
- 30_000, None, None, k_filter=True
439
- )
440
- assert "kfilt" in inspect.getsource(spatial_fcn)
441
- _, _, spatial_fcn = voltage._get_destripe_parameters(
442
- 2_500, None, None, k_filter=False
443
- )
444
- assert "car" in inspect.getsource(spatial_fcn)
445
- _, _, spatial_fcn = voltage._get_destripe_parameters(
446
- 2_500, None, None, k_filter=None
447
- )
448
- assert "dat: dat" in inspect.getsource(spatial_fcn)
449
- _, _, spatial_fcn = voltage._get_destripe_parameters(
450
- 2_500, None, None, k_filter=lambda dat: 3 * dat
451
- )
452
- assert "lambda dat: 3 * dat" in inspect.getsource(spatial_fcn)
453
-
454
- def test_fk(self):
455
- """
456
- creates a couple of plane waves and separate them using the velocity HP filter
457
- """
458
- ntr, ns, sr, dx, v1, v2 = (500, 2000, 0.002, 5, 2000, 1000)
459
- data = np.zeros((ntr, ns), np.float32)
460
- data[:, :100] = utils.ricker(100, 4)
461
- offset = np.arange(ntr) * dx
462
- offset = np.abs(offset - np.mean(offset))
463
- data_v1 = fourier.fshift(data, offset / v1 / sr)
464
- data_v2 = fourier.fshift(data, offset / v2 / sr)
465
-
466
- noise = np.random.randn(ntr, ns) / 60
467
- fk = voltage.fk(
468
- data_v1 + data_v2 + noise,
469
- si=sr,
470
- dx=dx,
471
- vbounds=[1200, 1500],
472
- ntr_pad=10,
473
- ntr_tap=15,
474
- lagc=0.25,
475
- )
476
- fknoise = voltage.fk(
477
- noise, si=sr, dx=dx, vbounds=[1200, 1500], ntr_pad=10, ntr_tap=15, lagc=0.25
478
- )
479
- # at least 90% of the traces should be below 50dB and 98% below 40 dB
480
- assert np.mean(20 * np.log10(utils.rms(fk - data_v1 - fknoise)) < -50) > 0.9
481
- assert np.mean(20 * np.log10(utils.rms(fk - data_v1 - fknoise)) < -40) > 0.98
482
- # test the K option
483
- kbands = np.sin(np.arange(ns) / ns * 8 * np.pi) / 10
484
- fkk = voltage.fk(
485
- data_v1 + data_v2 + kbands,
486
- si=sr,
487
- dx=dx,
488
- vbounds=[1200, 1500],
489
- ntr_pad=40,
490
- ntr_tap=15,
491
- lagc=0.25,
492
- kfilt={"bounds": [0, 0.01], "btype": "hp"},
493
- )
494
- assert np.mean(20 * np.log10(utils.rms(fkk - data_v1)) < -40) > 0.9
495
- # from easyqc.gui import viewseis
496
- # a = viewseis(data_v1 + data_v2 + kbands, .002, title='input')
497
- # b = viewseis(fkk, .002, title='output')
498
- # c = viewseis(data_v1 - fkk, .002, title='test')
499
-
500
- def test_saturation(self):
501
- np.random.seed(7654)
502
- data = (np.random.randn(384, 30_000).astype(np.float32) + 20) * 1e-6
503
- saturated, mute = voltage.saturation(data, max_voltage=1200)
504
- np.testing.assert_array_equal(saturated, 0)
505
- np.testing.assert_array_equal(mute, 1.0)
506
- # now we stick a big waveform in the middle of the recorder and expect some saturation
507
- w = utils.ricker(100, 4)
508
- w = np.minimum(1200, w / w.max() * 1400)
509
- data[:, 13_600:13700] = data[0, 13_600:13700] + w * 1e-6
510
- saturated, mute = voltage.saturation(
511
- data,
512
- max_voltage=np.ones(
513
- 384,
514
- )
515
- * 1200
516
- * 1e-6,
517
- )
518
- self.assertGreater(np.sum(saturated), 5)
519
- self.assertGreater(np.sum(mute == 0), np.sum(saturated))
520
-
521
-
522
- class TestCadzow(unittest.TestCase):
523
- def test_trajectory_matrixes(self):
524
- assert np.all(
525
- cadzow.traj_matrix_indices(4) == np.array([[1, 0], [2, 1], [3, 2]])
526
- )
527
- assert np.all(cadzow.traj_matrix_indices(3) == np.array([[1, 0], [2, 1]]))
528
-
529
-
530
448
  class TestStack(unittest.TestCase):
531
449
  def test_simple_stack(self):
532
450
  ntr, ns = (24, 400)
@@ -0,0 +1,160 @@
1
+ import numpy as np
2
+ import tempfile
3
+ from pathlib import Path
4
+ import unittest
5
+
6
+ import pandas as pd
7
+
8
+ import spikeglx
9
+ import ibldsp.voltage
10
+ import ibldsp.fourier
11
+ import ibldsp.utils
12
+ import ibldsp.cadzow
13
+
14
+
15
+ class TestDestripe(unittest.TestCase):
16
+ def test_destripe_parameters(self):
17
+ import inspect
18
+
19
+ _, _, spatial_fcn = ibldsp.voltage._get_destripe_parameters(
20
+ 30_000, None, None, k_filter=True
21
+ )
22
+ assert "kfilt" in inspect.getsource(spatial_fcn)
23
+ _, _, spatial_fcn = ibldsp.voltage._get_destripe_parameters(
24
+ 2_500, None, None, k_filter=False
25
+ )
26
+ assert "car" in inspect.getsource(spatial_fcn)
27
+ _, _, spatial_fcn = ibldsp.voltage._get_destripe_parameters(
28
+ 2_500, None, None, k_filter=None
29
+ )
30
+ assert "dat: dat" in inspect.getsource(spatial_fcn)
31
+ _, _, spatial_fcn = ibldsp.voltage._get_destripe_parameters(
32
+ 2_500, None, None, k_filter=lambda dat: 3 * dat
33
+ )
34
+ assert "lambda dat: 3 * dat" in inspect.getsource(spatial_fcn)
35
+
36
+ def test_fk(self):
37
+ """
38
+ creates a couple of plane waves and separate them using the velocity HP filter
39
+ """
40
+ ntr, ns, sr, dx, v1, v2 = (500, 2000, 0.002, 5, 2000, 1000)
41
+ data = np.zeros((ntr, ns), np.float32)
42
+ data[:, :100] = ibldsp.utils.ricker(100, 4)
43
+ offset = np.arange(ntr) * dx
44
+ offset = np.abs(offset - np.mean(offset))
45
+ data_v1 = ibldsp.fourier.fshift(data, offset / v1 / sr)
46
+ data_v2 = ibldsp.fourier.fshift(data, offset / v2 / sr)
47
+
48
+ noise = np.random.randn(ntr, ns) / 60
49
+ fk = ibldsp.voltage.fk(
50
+ data_v1 + data_v2 + noise,
51
+ si=sr,
52
+ dx=dx,
53
+ vbounds=[1200, 1500],
54
+ ntr_pad=10,
55
+ ntr_tap=15,
56
+ lagc=0.25,
57
+ )
58
+ fknoise = ibldsp.voltage.fk(
59
+ noise, si=sr, dx=dx, vbounds=[1200, 1500], ntr_pad=10, ntr_tap=15, lagc=0.25
60
+ )
61
+ # at least 90% of the traces should be below 50dB and 98% below 40 dB
62
+ assert (
63
+ np.mean(20 * np.log10(ibldsp.utils.rms(fk - data_v1 - fknoise)) < -50) > 0.9
64
+ )
65
+ assert (
66
+ np.mean(20 * np.log10(ibldsp.utils.rms(fk - data_v1 - fknoise)) < -40)
67
+ > 0.98
68
+ )
69
+ # test the K option
70
+ kbands = np.sin(np.arange(ns) / ns * 8 * np.pi) / 10
71
+ fkk = ibldsp.voltage.fk(
72
+ data_v1 + data_v2 + kbands,
73
+ si=sr,
74
+ dx=dx,
75
+ vbounds=[1200, 1500],
76
+ ntr_pad=40,
77
+ ntr_tap=15,
78
+ lagc=0.25,
79
+ kfilt={"bounds": [0, 0.01], "btype": "hp"},
80
+ )
81
+ assert np.mean(20 * np.log10(ibldsp.utils.rms(fkk - data_v1)) < -40) > 0.9
82
+ # from easyqc.gui import viewseis
83
+ # a = viewseis(data_v1 + data_v2 + kbands, .002, title='input')
84
+ # b = viewseis(fkk, .002, title='output')
85
+ # c = viewseis(data_v1 - fkk, .002, title='test')
86
+
87
+
88
+ class TestSaturation(unittest.TestCase):
89
+ def test_saturation_cbin(self):
90
+ nsat = 252
91
+ ns, nc = (350_072, 384)
92
+ s2v = np.float32(2.34375e-06)
93
+ sat = ibldsp.utils.fcn_cosine([0, 100])(
94
+ np.arange(nsat)
95
+ ) - ibldsp.utils.fcn_cosine([150, 250])(np.arange(nsat))
96
+ range_volt = 0.0012
97
+ sat = (sat / s2v * 0.0012).astype(np.int16)
98
+
99
+ with tempfile.TemporaryDirectory() as temp_dir:
100
+ file_bin = Path(temp_dir) / "binary.bin"
101
+ data = np.memmap(file_bin, dtype=np.int16, mode="w+", shape=(ns, nc))
102
+ data[50_000 : 50_000 + nsat, :] = sat[:, np.newaxis]
103
+
104
+ _sr = spikeglx.Reader(
105
+ file_bin, fs=30_000, dtype=np.int16, nc=nc, nsync=0, s2v=s2v
106
+ )
107
+ file_saturation = ibldsp.voltage.saturation_cbin(
108
+ _sr, max_voltage=range_volt, n_jobs=1
109
+ )
110
+ df_sat = pd.read_parquet(file_saturation)
111
+ assert np.sum(df_sat["stop_sample"] - df_sat["start_sample"]) == 67
112
+
113
+ def test_saturation(self):
114
+ np.random.seed(7654)
115
+ data = (np.random.randn(384, 30_000).astype(np.float32) + 20) * 1e-6
116
+ saturated, mute = ibldsp.voltage.saturation(data, max_voltage=1200)
117
+ np.testing.assert_array_equal(saturated, 0)
118
+ np.testing.assert_array_equal(mute, 1.0)
119
+ # now we stick a big waveform in the middle of the recorder and expect some saturation
120
+ w = ibldsp.utils.ricker(100, 4)
121
+ w = np.minimum(1200, w / w.max() * 1400)
122
+ data[:, 13_600:13700] = data[0, 13_600:13700] + w * 1e-6
123
+ saturated, mute = ibldsp.voltage.saturation(
124
+ data,
125
+ max_voltage=np.ones(
126
+ 384,
127
+ )
128
+ * 1200
129
+ * 1e-6,
130
+ )
131
+ self.assertGreater(np.sum(saturated), 5)
132
+ self.assertGreater(np.sum(mute == 0), np.sum(saturated))
133
+
134
+ def test_saturation_intervals_output(self):
135
+ saturation = np.zeros(50_000, dtype=bool)
136
+ # we test empty files, make sure we can read/write from empty parquet
137
+ with tempfile.TemporaryDirectory() as temp_dir:
138
+ # Create a file path within the temporary directory
139
+ temp_file = Path(temp_dir).joinpath("saturation.pqt")
140
+ df_nothing = ibldsp.voltage.saturation_samples_to_intervals(
141
+ saturation, output_file=Path(temp_dir).joinpath("saturation.pqt")
142
+ )
143
+ df_nothing2 = pd.read_parquet(temp_file)
144
+ self.assertEqual(df_nothing.shape[0], 0)
145
+ self.assertEqual(df_nothing2.shape[0], 0)
146
+ # for the case with saturation intervals, we simply test the number of rows correspond to the events
147
+ saturation[3441:3509] = True
148
+ saturation[45852:45865] = True
149
+ df_sat = ibldsp.voltage.saturation_samples_to_intervals(saturation)
150
+ self.assertEqual(81, np.sum(df_sat["stop_sample"] - df_sat["start_sample"]))
151
+
152
+
153
+ class TestCadzow(unittest.TestCase):
154
+ def test_trajectory_matrixes(self):
155
+ assert np.all(
156
+ ibldsp.cadzow.traj_matrix_indices(4) == np.array([[1, 0], [2, 1], [3, 2]])
157
+ )
158
+ assert np.all(
159
+ ibldsp.cadzow.traj_matrix_indices(3) == np.array([[1, 0], [2, 1]])
160
+ )