canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
canns/analyzer/data/asa/embedding.py
@@ -0,0 +1,269 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+ import numpy as np
+ from scipy.ndimage import gaussian_filter1d
+
+ from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
+ from .filters import _gaussian_filter1d
+
+
+ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
+     """
+     Embed spike train data into a time-binned spike matrix.
+
+     This function converts raw spike times into a time-binned spike matrix,
+     optionally applying Gaussian smoothing and filtering based on animal movement speed.
+
+     Parameters
+     ----------
+     spike_trains : dict
+         Dictionary containing ``'spike'`` and ``'t'``, and optionally ``'x'``/``'y'``.
+         ``'spike'`` can be a dict of neuron->spike_times, a list/array of arrays, or
+         a numpy object array from ``np.load``.
+     config : SpikeEmbeddingConfig, optional
+         Configuration object controlling binning, smoothing, and speed filtering.
+     **kwargs : Any
+         Legacy keyword parameters (``res``, ``dt``, ``sigma``, ``smooth0``, ``speed0``,
+         ``min_speed``). Prefer ``config`` in new code.
+
+     Returns
+     -------
+     tuple
+         ``(spikes_bin, xx, yy, tt)`` where:
+         - ``spikes_bin`` is a (T, N) binned spike matrix.
+         - ``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
+           otherwise ``None``.
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import SpikeEmbeddingConfig, embed_spike_trains
+     >>> cfg = SpikeEmbeddingConfig(smooth=False, speed_filter=False)
+     >>> spikes, xx, yy, tt = embed_spike_trains(mock_data, config=cfg)  # doctest: +SKIP
+     >>> spikes.ndim  # doctest: +SKIP
+     2
+     """
+     # Handle backward compatibility and configuration
+     if config is None:
+         config = SpikeEmbeddingConfig(
+             res=kwargs.get("res", 100000),
+             dt=kwargs.get("dt", 1000),
+             sigma=kwargs.get("sigma", 5000),
+             smooth=kwargs.get("smooth0", True),
+             speed_filter=kwargs.get("speed0", True),
+             min_speed=kwargs.get("min_speed", 2.5),
+         )
+
+     try:
+         # Step 1: Extract and filter spike data
+         spikes_filtered = _extract_spike_data(spike_trains, config)
+
+         # Step 2: Create time bins
+         time_bins = _create_time_bins(spike_trains["t"], config)
+
+         # Step 3: Bin spike data
+         spikes_bin = _bin_spike_data(spikes_filtered, time_bins, config)
+
+         # Step 4: Apply temporal smoothing if requested
+         if config.smooth:
+             spikes_bin = _apply_temporal_smoothing(spikes_bin, config)
+
+         # Step 5: Apply speed filtering if requested
+         if config.speed_filter:
+             return _apply_speed_filtering(spikes_bin, spike_trains, config)
+
+         return spikes_bin, None, None, None
+
+     except Exception as e:
+         raise ProcessingError(f"Failed to embed spike trains: {e}") from e
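+
+
+ # Illustrative sketch (hypothetical data, not part of this file): the
+ # ``mock_data`` referenced in the docstring above could be a plain dict of
+ # per-neuron spike times plus a time axis, e.g.
+ #
+ #     mock_data = {
+ #         "spike": {0: np.array([0.01, 0.42, 0.90]), 1: np.array([0.20, 0.55])},
+ #         "t": np.array([0.0, 0.5, 1.0]),
+ #     }
+ #     cfg = SpikeEmbeddingConfig(smooth=False, speed_filter=False)
+ #     spikes, _, _, _ = embed_spike_trains(mock_data, config=cfg)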
+
+
+ def _extract_spike_data(
+     spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
+ ) -> dict[int, np.ndarray]:
+     """Extract and filter spike data within the recording's time window."""
+     try:
+         # Handle different spike data formats
+         spike_data = spike_trains["spike"]
+         if hasattr(spike_data, "item") and callable(spike_data.item):
+             # numpy array (e.g. an object array from np.load); [()] unwraps a
+             # 0-d array and leaves other arrays unchanged
+             spikes_all = spike_data[()]
+         else:
+             # dict, list, or array of per-neuron spike-time arrays is used directly
+             spikes_all = spike_data
+
+         t = spike_trains["t"]
+
+         min_time0 = np.min(t)
+         max_time0 = np.max(t)
+
+         # Extract spike intervals for each cell
+         spikes = {}
+         if isinstance(spikes_all, dict):
+             # Dictionary format
+             for i, key in enumerate(spikes_all.keys()):
+                 s = np.array(spikes_all[key])
+                 spikes[i] = s[(s >= min_time0) & (s < max_time0)]
+         else:
+             # List/array format
+             for i in range(len(spikes_all)):
+                 s = np.array(spikes_all[i]) if len(spikes_all[i]) > 0 else np.array([])
+                 # Keep only spikes within the time window
+                 if len(s) > 0:
+                     spikes[i] = s[(s >= min_time0) & (s < max_time0)]
+                 else:
+                     spikes[i] = np.array([])
+
+         return spikes
+
+     except KeyError as e:
+         raise DataLoadError(f"Missing required data key: {e}") from e
+     except Exception as e:
+         raise ProcessingError(f"Error extracting spike data: {e}") from e
+
+
+ def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
+     """Create time bins for spike discretization."""
+     min_time0 = np.min(t)
+     max_time0 = np.max(t)
+
+     # Scale to integer time units (res ticks per second), then bin by dt
+     min_time = min_time0 * config.res
+     max_time = max_time0 * config.res
+
+     return np.arange(np.floor(min_time), np.ceil(max_time) + 1, config.dt)
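+
+ # Illustrative arithmetic (legacy defaults, not part of this file): with
+ # res=100000 and dt=1000, a 10 s recording spans 10 * 100000 = 1,000,000 ticks,
+ # so _create_time_bins returns about 1,000,000 / 1000 = 1000 bins of 10 ms each.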
+
+
+ def _bin_spike_data(
+     spikes: dict[int, np.ndarray], time_bins: np.ndarray, config: SpikeEmbeddingConfig
+ ) -> np.ndarray:
+     """Convert spike times to a binned spike-count matrix."""
+     min_time = time_bins[0]
+     max_time = time_bins[-1]
+
+     spikes_bin = np.zeros((len(time_bins), len(spikes)), dtype=int)
+
+     for n in spikes:
+         # Shift spike times to tick units relative to the first bin
+         spike_times = np.array(spikes[n] * config.res - min_time, dtype=int)
+         # Keep only spike times that fall inside the binned range
+         spike_times = spike_times[(spike_times < (max_time - min_time)) & (spike_times > 0)]
+         # Convert tick offsets to bin indices
+         spike_times = np.array(spike_times / config.dt, int)
+
+         # Accumulate spike counts per bin
+         for j in spike_times:
+             if j < len(time_bins):
+                 spikes_bin[j, n] += 1
+
+     return spikes_bin
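+
+ # Illustrative note (not part of this file): the inner accumulation loop is
+ # equivalent to a single vectorized call, assuming spike_times has already been
+ # clipped to valid bin indices:
+ #
+ #     np.add.at(spikes_bin[:, n], spike_times, 1)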
+
+
+ def _apply_temporal_smoothing(spikes_bin: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
+     """Apply Gaussian temporal smoothing to the binned spike matrix."""
+     # The legacy implementation used a custom kernel; this version uses scipy's
+     # gaussian_filter1d for better performance.
+     smoothed = np.zeros((spikes_bin.shape[0], spikes_bin.shape[1]))
+
+     # Convert sigma from time units (ticks) to bin units
+     sigma_bins = config.sigma / config.dt
+
+     for n in range(spikes_bin.shape[1]):
+         smoothed[:, n] = gaussian_filter1d(
+             spikes_bin[:, n].astype(float), sigma=sigma_bins, mode="constant"
+         )
+
+     # Rescale by the Gaussian peak amplitude (sigma expressed in seconds)
+     normalization_factor = 1 / np.sqrt(2 * np.pi * (config.sigma / config.res) ** 2)
+     return smoothed * normalization_factor
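+
+ # Illustrative arithmetic (legacy defaults, not part of this file): with
+ # sigma=5000, dt=1000, and res=100000, each column is smoothed with
+ # sigma_bins = 5000 / 1000 = 5 bins, and the normalization factor uses
+ # sigma = 5000 / 100000 = 0.05 s, i.e. 1 / sqrt(2 * pi * 0.05**2) ≈ 7.98.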
+
+
+ def _apply_speed_filtering(
+     spikes_bin: np.ndarray, spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     """Apply speed-based filtering to spike data."""
+     try:
+         xx, yy, tt_pos, speed = _load_pos(
+             spike_trains["t"], spike_trains["x"], spike_trains["y"], res=config.res, dt=config.dt
+         )
+
+         # Keep only time bins where the animal moves faster than min_speed
+         valid = speed > config.min_speed
+
+         return (spikes_bin[valid, :], xx[valid], yy[valid], tt_pos[valid])
+
+     except KeyError as e:
+         raise DataLoadError(f"Missing position data for speed filtering: {e}") from e
+     except Exception as e:
+         raise ProcessingError(f"Error in speed filtering: {e}") from e
+
+
+ def _load_pos(t, x, y, res=100000, dt=1000):
+     """
+     Compute animal position and speed from tracking data.
+
+     Interpolates animal positions to match spike time bins and computes smoothed
+     velocity vectors and speed.
+
+     Parameters:
+         t (ndarray): Time points of the position samples (in seconds).
+         x (ndarray): X coordinates of the animal's position.
+         y (ndarray): Y coordinates of the animal's position.
+         res (int): Time scaling factor to align with spike resolution.
+         dt (int): Temporal bin size in units of ``1/res`` seconds.
+
+     Returns:
+         xx (ndarray): Interpolated x positions.
+         yy (ndarray): Interpolated y positions.
+         tt (ndarray): Corresponding time points (in seconds).
+         speed (ndarray): Speed at each time point (in cm/s).
+     """
+
+     min_time0 = np.min(t)
+     max_time0 = np.max(t)
+
+     # Restrict to the recorded time window
+     times = np.where((t >= min_time0) & (t < max_time0))
+     x = x[times]
+     y = y[times]
+     t = t[times]
+
+     min_time = min_time0 * res
+     max_time = max_time0 * res
+
+     # Uniform time grid (in seconds) matching the spike bins
+     tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res
+
+     # Map each raw sample onto the grid, then linearly interpolate between samples
+     idt = np.concatenate(([0], np.digitize(t[1:-1], tt[:]) - 1, [len(tt) + 1]))
+     idtt = np.digitize(np.arange(len(tt)), idt) - 1
+
+     idx = np.concatenate((np.unique(idtt), [np.max(idtt) + 1]))
+     divisor = np.bincount(idtt)
+     steps = 1.0 / divisor[divisor > 0]
+     N = np.max(divisor)
+     ranges = np.multiply(np.arange(N)[np.newaxis, :], steps[:, np.newaxis])
+     ranges[ranges >= 1] = np.nan
+
+     rangesx = x[idx[:-1], np.newaxis] + np.multiply(
+         ranges, (x[idx[1:]] - x[idx[:-1]])[:, np.newaxis]
+     )
+     xx = rangesx[~np.isnan(ranges)]
+
+     rangesy = y[idx[:-1], np.newaxis] + np.multiply(
+         ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
+     )
+     yy = rangesy[~np.isnan(ranges)]
+
+     # Smooth positions, then differentiate to get speed in cm/s
+     xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
+     yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
+     dx = (xxs[1:] - xxs[:-1]) * 100
+     dy = (yys[1:] - yys[:-1]) * 100
+     speed = np.sqrt(dx**2 + dy**2) / 0.01
+     speed = np.concatenate(([speed[0]], speed))
+     return xx, yy, tt, speed
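
The embedding module above is the heart of the new ASA data path: spike times go in, a smoothed, speed-filtered (T, N) matrix comes out. A minimal end-to-end sketch, assuming a hypothetical dict session_data with 'spike', 't', 'x', and 'y' keys (the import path follows the module's own docstring):

    from canns.analyzer.data import SpikeEmbeddingConfig, embed_spike_trains

    cfg = SpikeEmbeddingConfig(smooth=True, speed_filter=True, min_speed=2.5)
    spikes_bin, xx, yy, tt = embed_spike_trains(session_data, config=cfg)
    # spikes_bin: (T, N) smoothed spike counts, restricted to bins where the
    # animal moved faster than min_speed (cm/s)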
canns/analyzer/data/asa/filters.py
@@ -0,0 +1,208 @@
+ from __future__ import annotations
+
+ import numbers
+
+ import numpy as np
+ from numpy.exceptions import AxisError
+ from scipy.ndimage import _nd_image, _ni_support
+ from scipy.ndimage._filters import _invalid_origin
+
+
+ def _gaussian_filter1d(
+     input,
+     sigma,
+     axis=-1,
+     order=0,
+     output=None,
+     mode="reflect",
+     cval=0.0,
+     truncate=4.0,
+     *,
+     radius=None,
+ ):
+     """1-D Gaussian filter.
+
+     Parameters
+     ----------
+     %(input)s
+     sigma : scalar
+         standard deviation for Gaussian kernel
+     %(axis)s
+     order : int, optional
+         An order of 0 corresponds to convolution with a Gaussian
+         kernel. A positive order corresponds to convolution with
+         that derivative of a Gaussian.
+     %(output)s
+     %(mode_reflect)s
+     %(cval)s
+     truncate : float, optional
+         Truncate the filter at this many standard deviations.
+         Default is 4.0.
+     radius : None or int, optional
+         Radius of the Gaussian kernel. If specified, the size of
+         the kernel will be ``2*radius + 1``, and `truncate` is ignored.
+         Default is None.
+
+     Returns
+     -------
+     gaussian_filter1d : ndarray
+
+     Notes
+     -----
+     The Gaussian kernel will have size ``2*radius + 1`` along each axis. If
+     `radius` is None, a default ``radius = round(truncate * sigma)`` will be
+     used.
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 1)
+     array([ 1.42704095,  2.06782203,  3.        ,  3.93217797,  4.57295905])
+     >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 4)
+     array([ 2.91948343,  2.95023502,  3.        ,  3.04976498,  3.08051657])
+     >>> import matplotlib.pyplot as plt
+     >>> rng = np.random.default_rng()
+     >>> x = rng.standard_normal(101).cumsum()
+     >>> y3 = _gaussian_filter1d(x, 3)
+     >>> y6 = _gaussian_filter1d(x, 6)
+     >>> plt.plot(x, 'k', label='original data')
+     >>> plt.plot(y3, '--', label='filtered, sigma=3')
+     >>> plt.plot(y6, ':', label='filtered, sigma=6')
+     >>> plt.legend()
+     >>> plt.grid()
+     >>> plt.show()
+
+     """
+     sd = float(sigma)
+     # make the radius of the filter equal to truncate standard deviations
+     lw = int(truncate * sd + 0.5)
+     if radius is not None:
+         lw = radius
+         if not isinstance(lw, numbers.Integral) or lw < 0:
+             raise ValueError("Radius must be a nonnegative integer.")
+     # Since we are calling correlate, not convolve, revert the kernel
+     weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
+     return _correlate1d(input, weights, axis, output, mode, cval, 0)
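+
+ # Illustrative check (not part of this file): for real-valued input this
+ # vendored copy should track scipy's public API, e.g.
+ #
+ #     from scipy.ndimage import gaussian_filter1d
+ #     a = np.random.default_rng(0).standard_normal(50)
+ #     assert np.allclose(_gaussian_filter1d(a, 3), gaussian_filter1d(a, 3))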
+
+
+ def _gaussian_kernel1d(sigma, order, radius):
+     """
+     Computes a 1-D Gaussian convolution kernel.
+     """
+     if order < 0:
+         raise ValueError("order must be non-negative")
+     exponent_range = np.arange(order + 1)
+     sigma2 = sigma * sigma
+     x = np.arange(-radius, radius + 1)
+     phi_x = np.exp(-0.5 / sigma2 * x**2)
+     phi_x = phi_x / phi_x.sum()
+
+     if order == 0:
+         return phi_x
+     else:
+         # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
+         # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
+         # p'(x) = -1 / sigma ** 2
+         # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
+         # coefficients of q(x)
+         q = np.zeros(order + 1)
+         q[0] = 1
+         D = np.diag(exponent_range[1:], 1)  # D @ q(x) = q'(x)
+         P = np.diag(np.ones(order) / -sigma2, -1)  # P @ q(x) = q(x) * p'(x)
+         Q_deriv = D + P
+         for _ in range(order):
+             q = Q_deriv.dot(q)
+         q = (x[:, None] ** exponent_range).dot(q)
+         return q * phi_x
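+
+ # Illustrative arithmetic (not part of this file): for order=1 the recurrence
+ # yields q(x) = -x / sigma**2, so the kernel is (-x / sigma**2) * phi(x), the
+ # first derivative of a Gaussian; its entries sum to 0 rather than 1.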
+
+
+ def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
+     """Calculate a 1-D correlation along the given axis.
+
+     The lines of the array along the given axis are correlated with the
+     given weights.
+
+     Parameters
+     ----------
+     %(input)s
+     weights : array
+         1-D sequence of numbers.
+     %(axis)s
+     %(output)s
+     %(mode_reflect)s
+     %(cval)s
+     %(origin)s
+
+     Returns
+     -------
+     result : ndarray
+         Correlation result. Has the same shape as `input`.
+
+     Examples
+     --------
+     >>> _correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
+     array([ 8, 26,  8, 12,  7, 28, 36,  9])
+     """
+     input = np.asarray(input)
+     weights = np.asarray(weights)
+     complex_input = input.dtype.kind == "c"
+     complex_weights = weights.dtype.kind == "c"
+     if complex_input or complex_weights:
+         if complex_weights:
+             weights = weights.conj()
+             weights = weights.astype(np.complex128, copy=False)
+         kwargs = dict(axis=axis, mode=mode, origin=origin)
+         output = _ni_support._get_output(output, input, complex_output=True)
+         return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)
+
+     output = _ni_support._get_output(output, input)
+     weights = np.asarray(weights, dtype=np.float64)
+     if weights.ndim != 1 or weights.shape[0] < 1:
+         raise RuntimeError("no filter weights given")
+     if not weights.flags.contiguous:
+         weights = weights.copy()
+     axis = _normalize_axis_index(axis, input.ndim)
+     if _invalid_origin(origin, len(weights)):
+         raise ValueError(
+             "Invalid origin; origin must satisfy "
+             "-(len(weights) // 2) <= origin <= "
+             "(len(weights)-1) // 2"
+         )
+     mode = _ni_support._extend_mode_to_code(mode)
+     _nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
+     return output
+
+
+ def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
+     """Complex convolution via a linear combination of real convolutions."""
+     complex_input = input.dtype.kind == "c"
+     complex_weights = weights.dtype.kind == "c"
+     if complex_input and complex_weights:
+         # real component of the output
+         func(input.real, weights.real, output=output.real, cval=np.real(cval), **kwargs)
+         output.real -= func(input.imag, weights.imag, output=None, cval=np.imag(cval), **kwargs)
+         # imaginary component of the output
+         func(input.real, weights.imag, output=output.imag, cval=np.real(cval), **kwargs)
+         output.imag += func(input.imag, weights.real, output=None, cval=np.imag(cval), **kwargs)
+     elif complex_input:
+         func(input.real, weights, output=output.real, cval=np.real(cval), **kwargs)
+         func(input.imag, weights, output=output.imag, cval=np.imag(cval), **kwargs)
+     else:
+         if np.iscomplexobj(cval):
+             raise ValueError("Cannot provide a complex-valued cval when the input is real.")
+         func(input, weights.real, output=output.real, cval=cval, **kwargs)
+         func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
+     return output
+
+
+ def _normalize_axis_index(axis, ndim):
+     # Check if `axis` is in the correct range and normalize it
+     if axis < -ndim or axis >= ndim:
+         msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
+         raise AxisError(msg)
+
+     if axis < 0:
+         axis = axis + ndim
+     return axis
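
A quick behavioral sketch of the axis helper (values here are illustrative):

    _normalize_axis_index(-1, ndim=2)  # returns 1
    _normalize_axis_index(1, ndim=2)   # returns 1
    _normalize_axis_index(3, ndim=2)   # raises AxisError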