canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/visualization/core/backend.py +1 -1
- canns/analyzer/visualization/core/config.py +77 -0
- canns/analyzer/visualization/core/rendering.py +10 -6
- canns/analyzer/visualization/energy_plots.py +22 -8
- canns/analyzer/visualization/spatial_plots.py +31 -11
- canns/analyzer/visualization/theta_sweep_plots.py +15 -6
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/open_loop_navigation.py +3 -1
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.ndimage import gaussian_filter1d
|
|
7
|
+
|
|
8
|
+
from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
|
|
9
|
+
from .filters import _gaussian_filter1d
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
    """
    Bin spike trains into a (time, neuron) matrix, optionally smoothing and speed-filtering.

    Spike times are discretized into bins of ``config.dt`` ticks, where one
    second is ``config.res`` ticks. The binned matrix is optionally
    Gaussian-smoothed along time. It is also optionally restricted to bins
    where the animal moves faster than ``config.min_speed``.

    Parameters
    ----------
    spike_trains : dict
        Mapping containing ``'spike'`` and ``'t'``, plus ``'x'``/``'y'`` when speed
        filtering is enabled. ``'spike'`` can be a dict of neuron->spike_times, a
        list/array of arrays, or a numpy object array from ``np.load``.
    config : SpikeEmbeddingConfig, optional
        Configuration object controlling binning, smoothing, and speed filtering.
    **kwargs : Any
        Legacy keyword parameters, consulted only when ``config`` is None. They are
        ``res`` (default 100000), ``dt`` (1000), ``sigma`` (5000), ``smooth0`` (True),
        ``speed0`` (True) and ``min_speed`` (2.5). They are ignored when ``config``
        is supplied. Prefer ``config`` in new code.

    Returns
    -------
    tuple
        ``(spikes_bin, xx, yy, tt)`` where:
        - ``spikes_bin`` is a (T, N) binned spike matrix.
        - ``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
          otherwise ``None``.

    Raises
    ------
    ProcessingError
        Any exception raised during processing is re-raised as ``ProcessingError``.
        This includes ``DataLoadError`` from missing keys; the original exception
        is chained as ``__cause__``.

    Examples
    --------
    >>> from canns.analyzer.data import SpikeEmbeddingConfig, embed_spike_trains
    >>> cfg = SpikeEmbeddingConfig(smooth=False, speed_filter=False)
    >>> spikes, xx, yy, tt = embed_spike_trains(mock_data, config=cfg)  # doctest: +SKIP
    >>> spikes.ndim
    2
    """
    if config is None:
        # Legacy call style: translate the old keyword names into a config object.
        legacy_options = {
            "res": kwargs.get("res", 100000),
            "dt": kwargs.get("dt", 1000),
            "sigma": kwargs.get("sigma", 5000),
            "smooth": kwargs.get("smooth0", True),
            "speed_filter": kwargs.get("speed0", True),
            "min_speed": kwargs.get("min_speed", 2.5),
        }
        config = SpikeEmbeddingConfig(**legacy_options)

    try:
        per_cell = _extract_spike_data(spike_trains, config)
        edges = _create_time_bins(spike_trains["t"], config)
        binned = _bin_spike_data(per_cell, edges, config)

        if config.smooth:
            binned = _apply_temporal_smoothing(binned, config)

        if not config.speed_filter:
            return binned, None, None, None
        return _apply_speed_filtering(binned, spike_trains, config)

    except Exception as err:
        raise ProcessingError(f"Failed to embed spike trains: {err}") from err
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _extract_spike_data(
    spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
) -> dict[int, np.ndarray]:
    """Collect each neuron's spike times restricted to the recording time window.

    Parameters
    ----------
    spike_trains : dict
        Must contain ``'spike'`` and ``'t'``. ``'spike'`` may be a dict of
        neuron -> spike times, a list/array of per-neuron sequences, or a numpy
        array wrapping either (e.g. a 0-d object array from ``np.load``).
    config : SpikeEmbeddingConfig
        Not currently read by this function.

    Returns
    -------
    dict[int, np.ndarray]
        Neurons renumbered ``0 .. N-1`` in input order. Each maps to its spike
        times ``s`` satisfying ``min(t) <= s < max(t)``.

    Raises
    ------
    DataLoadError
        If a required key (``'spike'`` or ``'t'``) is missing.
    ProcessingError
        For any other failure while extracting spikes.
    """
    try:
        source = spike_trains["spike"]
        # Objects exposing a callable ``.item`` (numpy arrays) are indexed with ``()``.
        # This unwraps a 0-d object array (e.g. a dict stored in an npz file) and
        # is just a full view for ordinary arrays.
        if hasattr(source, "item") and callable(source.item):
            source = source[()]

        window = spike_trains["t"]
        t_lo = np.min(window)
        t_hi = np.max(window)

        def in_window(train):
            # Half-open window: spikes exactly at max(t) are excluded.
            return train[(train >= t_lo) & (train < t_hi)]

        if isinstance(source, dict):
            return {
                cell: in_window(np.array(source[name]))
                for cell, name in enumerate(source.keys())
            }

        trains = {}
        for cell in range(len(source)):
            entry = source[cell]
            train = np.array(entry) if len(entry) > 0 else np.array([])
            trains[cell] = in_window(train) if len(train) > 0 else np.array([])
        return trains

    except KeyError as e:
        raise DataLoadError(f"Missing required data key: {e}") from e
    except Exception as e:
        raise ProcessingError(f"Error extracting spike data: {e}") from e
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
    """Return evenly spaced bin start positions covering the span of ``t``.

    ``t`` is scaled by ``config.res`` into integer-like ticks. The grid starts at
    ``floor(min(t) * res)`` and steps by ``config.dt``. It runs up to and including
    ``ceil(max(t) * res)`` when that value lies on the grid.
    """
    first_tick = np.floor(np.min(t) * config.res)
    # +1 makes the (exclusive) arange stop include ceil(max * res) itself.
    stop_tick = np.ceil(np.max(t) * config.res) + 1
    return np.arange(first_tick, stop_tick, config.dt)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _bin_spike_data(
    spikes: dict[int, np.ndarray], time_bins: np.ndarray, config: SpikeEmbeddingConfig
) -> np.ndarray:
    """Convert spike times to a binned (time x neuron) spike-count matrix.

    Parameters
    ----------
    spikes : dict[int, np.ndarray]
        Spike times per neuron, in the same units as ``t`` (seconds). Keys are
        used directly as column indices, so they must be ``0 .. len(spikes) - 1``.
    time_bins : np.ndarray
        Evenly spaced bin start positions in ticks (``time * config.res``), spaced
        ``config.dt`` apart.
    config : SpikeEmbeddingConfig
        Provides ``res`` (ticks per time unit) and ``dt`` (bin width in ticks).

    Returns
    -------
    np.ndarray
        Integer array of shape ``(len(time_bins), len(spikes))`` holding spike counts.

    Notes
    -----
    Each spike is converted to an integer tick offset from ``time_bins[0]``
    (rounded down) and counted in bin ``offset // dt``. Offsets must lie in
    ``[0, time_bins[-1] - time_bins[0])``. A spike exactly at the first bin
    edge (offset 0) is therefore counted in bin 0. Spikes in the last bin are
    not counted.
    """
    n_bins = len(time_bins)
    spikes_bin = np.zeros((n_bins, len(spikes)), dtype=int)
    if n_bins == 0:
        return spikes_bin

    min_time = time_bins[0]
    span = time_bins[-1] - min_time

    for n, times in spikes.items():
        # Floor (rather than truncating toward zero) so that spikes slightly before
        # the first bin become negative offsets instead of collapsing onto 0.
        ticks = np.floor(np.asarray(times, dtype=float) * config.res - min_time).astype(np.int64)
        # Keep offset 0: a spike exactly at the first bin edge belongs to bin 0.
        ticks = ticks[(ticks >= 0) & (ticks < span)]
        bin_idx = np.array(ticks / config.dt, dtype=int)

        # Vectorized count per bin; slicing mirrors the old "j < len(time_bins)" guard.
        spikes_bin[:, n] = np.bincount(bin_idx, minlength=n_bins)[:n_bins]

    return spikes_bin
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _apply_temporal_smoothing(spikes_bin: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
    """Gaussian-smooth each neuron's binned counts along the time axis.

    Each column of ``spikes_bin`` is filtered with ``scipy.ndimage.gaussian_filter1d``.
    The filter uses ``sigma = config.sigma / config.dt`` bins and zero padding
    (``mode="constant"``). The result is then multiplied by
    ``1 / sqrt(2 * pi * (config.sigma / config.res) ** 2)``.

    Returns a float array with the same (time, neuron) shape as ``spikes_bin``.
    """
    n_rows, n_cells = spikes_bin.shape[0], spikes_bin.shape[1]
    kernel_sigma = config.sigma / config.dt  # smoothing width expressed in bins

    out = np.zeros((n_rows, n_cells))
    for cell in range(n_cells):
        trace = spikes_bin[:, cell].astype(float)
        out[:, cell] = gaussian_filter1d(trace, sigma=kernel_sigma, mode="constant")

    # NOTE(review): gaussian_filter1d's kernel already sums to 1 over bins. This
    # scale factor is therefore not the same as summing a unit-area Gaussian density
    # in (sigma/res) time units, which would correspond to multiplying by res/dt.
    # Downstream code may only need relative values; confirm the intended units.
    sigma_time = config.sigma / config.res
    scale = 1 / np.sqrt(2 * np.pi * sigma_time**2)
    return out * scale
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _apply_speed_filtering(
    spikes_bin: np.ndarray, spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the time bins where the animal's speed exceeds ``config.min_speed``.

    Position is interpolated onto the bin grid by ``_load_pos`` using
    ``spike_trains['t']``, ``['x']`` and ``['y']``. The same boolean mask then
    selects rows of ``spikes_bin`` and entries of the returned position/time arrays.

    Returns
    -------
    tuple
        ``(spikes_bin[mask], xx[mask], yy[mask], tt[mask])``.

    Raises
    ------
    DataLoadError
        If ``'t'``, ``'x'`` or ``'y'`` is missing from ``spike_trains``.
    ProcessingError
        For any other failure, e.g. a mask/matrix length mismatch.
    """
    try:
        pos_t = spike_trains["t"]
        pos_x = spike_trains["x"]
        pos_y = spike_trains["y"]
        xx, yy, tt_pos, speed = _load_pos(pos_t, pos_x, pos_y, res=config.res, dt=config.dt)

        moving = speed > config.min_speed
        return spikes_bin[moving, :], xx[moving], yy[moving], tt_pos[moving]

    except KeyError as e:
        raise DataLoadError(f"Missing position data for speed filtering: {e}") from e
    except Exception as e:
        raise ProcessingError(f"Error in speed filtering: {e}") from e
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _load_pos(t, x, y, res=100000, dt=1000):
    """
    Interpolate tracked position onto the spike time-bin grid and compute speed.

    Position samples are linearly interpolated so that one ``(x, y)`` value is
    produced per time bin of width ``dt / res`` (seconds when ``t`` is in seconds).
    The interpolated trajectory is then Gaussian-smoothed and differentiated to
    obtain speed.

    Parameters:
        t (ndarray): Timestamps of the position samples (in seconds).
        x (ndarray): X coordinates of the animal's position, one per entry of ``t``.
        y (ndarray): Y coordinates of the animal's position, one per entry of ``t``.
        res (int): Number of ticks per unit of ``t`` used to discretize time.
        dt (int): Bin width in ticks, so one bin spans ``dt / res`` units of ``t``
            (0.01 s with the defaults).

    Returns:
        xx (ndarray): Interpolated x positions, one per time bin.
        yy (ndarray): Interpolated y positions, one per time bin.
        tt (ndarray): Bin start times (in units of ``t``).
        speed (ndarray): Speed at each bin, per unit of ``t``. It is in cm/s only if
            positions are in metres, because of the ``* 100`` factor. TODO: confirm
            the position units.
    """

    min_time0 = np.min(t)
    max_time0 = np.max(t)

    # Half-open window [min, max): samples exactly at max_time0 are dropped.
    times = np.where((t >= min_time0) & (t < max_time0))
    x = x[times]
    y = y[times]
    t = t[times]

    min_time = min_time0 * res
    max_time = max_time0 * res

    # Bin start times, spaced dt / res apart.
    tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res

    # Piecewise-linear interpolation of the samples onto the tt grid, in four steps:
    # - idt holds the bin index of each sample, with sentinels at both ends.
    # - idtt assigns every bin to the latest sample whose bin index is <= it.
    # - Each sample's bins get evenly spaced fractions in [0, 1) ("ranges").
    # - Those fractions position the bins between that sample and the next owning sample.
    idt = np.concatenate(([0], np.digitize(t[1:-1], tt[:]) - 1, [len(tt) + 1]))
    idtt = np.digitize(np.arange(len(tt)), idt) - 1

    idx = np.concatenate((np.unique(idtt), [np.max(idtt) + 1]))
    divisor = np.bincount(idtt)
    steps = 1.0 / divisor[divisor > 0]
    N = np.max(divisor)
    ranges = np.multiply(np.arange(N)[np.newaxis, :], steps[:, np.newaxis])
    ranges[ranges >= 1] = np.nan

    rangesx = x[idx[:-1], np.newaxis] + np.multiply(
        ranges, (x[idx[1:]] - x[idx[:-1]])[:, np.newaxis]
    )
    xx = rangesx[~np.isnan(ranges)]

    rangesy = y[idx[:-1], np.newaxis] + np.multiply(
        ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
    )
    yy = rangesy[~np.isnan(ranges)]

    # Smooth over 100 bins (1 s at the default dt / res) before differentiating.
    xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
    yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
    # ``* 100`` presumably converts metres to centimetres; TODO confirm position units.
    dx = (xxs[1:] - xxs[:-1]) * 100
    dy = (yys[1:] - yys[:-1]) * 100
    # Divide by the actual bin width. The previous hard-coded 0.01 was only correct
    # for dt / res == 0.01; this gives the same result at the defaults.
    bin_width = dt / res
    speed = np.sqrt(dx**2 + dy**2) / bin_width
    # Differences are one element shorter than the trajectory. Repeating the first
    # value aligns speed with xx / yy / tt.
    speed = np.concatenate(([speed[0]], speed))
    return xx, yy, tt, speed
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numbers
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.exceptions import AxisError
|
|
7
|
+
from scipy.ndimage import _nd_image, _ni_support
|
|
8
|
+
from scipy.ndimage._filters import _invalid_origin
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _gaussian_filter1d(
    input,
    sigma,
    axis=-1,
    order=0,
    output=None,
    mode="reflect",
    cval=0.0,
    truncate=4.0,
    *,
    radius=None,
):
    """1-D Gaussian filter.

    Local equivalent of ``scipy.ndimage.gaussian_filter1d``. It builds a sampled
    Gaussian (or Gaussian-derivative) kernel with ``_gaussian_kernel1d`` and
    applies it along `axis` with ``_correlate1d``.

    Parameters
    ----------
    input : array_like
        The input array.
    sigma : scalar
        Standard deviation for the Gaussian kernel, in samples.
    axis : int, optional
        Axis of `input` along which to filter. Default is -1.
    order : int, optional
        An order of 0 corresponds to convolution with a Gaussian
        kernel. A positive order corresponds to convolution with
        that derivative of a Gaussian.
    output : array or dtype, optional
        Array in which to place the output, or the dtype of the returned
        array. It is resolved as in ``scipy.ndimage``; by default an array of
        the same dtype as `input` is created.
    mode : str, optional
        Boundary handling mode as accepted by ``scipy.ndimage``, e.g.
        ``'reflect'``, ``'constant'``, ``'nearest'``, ``'mirror'`` or ``'wrap'``.
        Default is ``'reflect'``.
    cval : scalar, optional
        Value used past the edges of `input` when ``mode='constant'``.
        Default is 0.0.
    truncate : float, optional
        Truncate the filter at this many standard deviations.
        Default is 4.0.
    radius : None or int, optional
        Radius of the Gaussian kernel. If specified, the size of
        the kernel will be ``2*radius + 1``, and `truncate` is ignored.
        Default is None.

    Returns
    -------
    gaussian_filter1d : ndarray
        The filtered array, with the same shape as `input`.

    Raises
    ------
    ValueError
        If `radius` (or the radius derived from `truncate`) is not a
        nonnegative integer, or if `order` is negative.

    Notes
    -----
    The Gaussian kernel will have size ``2*radius + 1`` along `axis`. If
    `radius` is None, a default ``radius = int(truncate * sigma + 0.5)`` will be
    used.

    Examples
    --------
    >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 1)
    array([1.42704095, 2.06782203, 3.        , 3.93217797, 4.57295905])
    >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 4)
    array([2.91948343, 2.95023502, 3.        , 3.04976498, 3.08051657])
    """
    sd = float(sigma)
    # make the radius of the filter equal to truncate standard deviations
    lw = int(truncate * sd + 0.5)
    if radius is not None:
        lw = radius
    if not isinstance(lw, numbers.Integral) or lw < 0:
        raise ValueError("Radius must be a nonnegative integer.")
    # Since we are calling correlate, not convolve, revert the kernel
    weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
    return _correlate1d(input, weights, axis, output, mode, cval, 0)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _gaussian_kernel1d(sigma, order, radius):
    """
    Build a sampled 1-D Gaussian kernel, or its `order`-th derivative.

    The kernel is evaluated at integer offsets ``-radius .. radius``. The
    zeroth-order Gaussian is normalized to sum to 1. Derivatives are computed
    as ``q(x) * phi(x)``, where ``phi`` is that normalized Gaussian and ``q``
    is a polynomial obtained by applying the derivative operator `order` times.

    Raises ValueError if `order` is negative.
    """
    if order < 0:
        raise ValueError("order must be non-negative")

    variance = sigma * sigma
    offsets = np.arange(-radius, radius + 1)
    gauss = np.exp(-0.5 / variance * offsets**2)
    gauss = gauss / gauss.sum()
    if order == 0:
        return gauss

    powers = np.arange(order + 1)
    # Coefficients of q(x), lowest power first; q starts as the constant 1.
    coeffs = np.zeros(order + 1)
    coeffs[0] = 1
    # d/dx [q(x) phi(x)] = (q'(x) + q(x) * p'(x)) phi(x), with p'(x) = -x / sigma**2.
    # The super-diagonal term differentiates q; the sub-diagonal term multiplies by -x / sigma**2.
    step = np.diag(powers[1:], 1) + np.diag(np.ones(order) / -variance, -1)
    for _ in range(order):
        coeffs = step.dot(coeffs)
    poly = (offsets[:, None] ** powers).dot(coeffs)
    return poly * gauss
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
    """Calculate a 1-D correlation along the given axis.

    The lines of the array along the given axis are correlated with the
    given weights. This mirrors ``scipy.ndimage.correlate1d`` but calls SciPy's
    private ``_nd_image`` / ``_ni_support`` helpers directly. It may therefore
    need updating if those private modules change between SciPy releases.

    Parameters
    ----------
    input : array_like
        The input array.
    weights : array
        1-D sequence of numbers.
    axis : int, optional
        Axis of `input` along which to correlate. Default is -1.
    output : array or dtype, optional
        Array in which to place the output, or the dtype of the returned
        array; resolved by ``_ni_support._get_output``.
    mode : str, optional
        Boundary handling mode, translated by ``_ni_support._extend_mode_to_code``.
        Examples are ``'reflect'``, ``'constant'``, ``'nearest'``, ``'mirror'``
        and ``'wrap'``. Default is ``'reflect'``.
    cval : scalar, optional
        Value used past the edges of `input` when ``mode='constant'``.
        Default is 0.0.
    origin : int, optional
        Shift of the filter placement. It must satisfy
        ``-(len(weights) // 2) <= origin <= (len(weights) - 1) // 2``.
        Default is 0.

    Returns
    -------
    result : ndarray
        Correlation result. Has the same shape as `input`.

    Raises
    ------
    RuntimeError
        If `weights` is not a non-empty 1-D sequence.
    ValueError
        If `origin` is out of range, or if complex `weights` are combined with
        real `input` and a complex `cval`.
    numpy.exceptions.AxisError
        If `axis` is out of bounds for `input`.

    Examples
    --------
    >>> _correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
    array([ 8, 26,  8, 12,  7, 28, 36,  9])
    """
    input = np.asarray(input)
    weights = np.asarray(weights)
    complex_input = input.dtype.kind == "c"
    complex_weights = weights.dtype.kind == "c"
    if complex_input or complex_weights:
        if complex_weights:
            # Correlating with complex weights uses their complex conjugate.
            weights = weights.conj()
            weights = weights.astype(np.complex128, copy=False)
        kwargs = dict(axis=axis, mode=mode, origin=origin)
        output = _ni_support._get_output(output, input, complex_output=True)
        # Decompose into real-valued correlations (recursing into this function)
        # and recombine them into the complex output.
        return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)

    output = _ni_support._get_output(output, input)
    weights = np.asarray(weights, dtype=np.float64)
    if weights.ndim != 1 or weights.shape[0] < 1:
        raise RuntimeError("no filter weights given")
    # Presumably the C routine expects a contiguous weights buffer.
    if not weights.flags.contiguous:
        weights = weights.copy()
    axis = _normalize_axis_index(axis, input.ndim)
    if _invalid_origin(origin, len(weights)):
        raise ValueError(
            "Invalid origin; origin must satisfy "
            "-(len(weights) // 2) <= origin <= "
            "(len(weights)-1) // 2"
        )
    mode = _ni_support._extend_mode_to_code(mode)
    # Writes the result into ``output`` in place.
    _nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
    return output
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
    """Evaluate a complex-valued correlation as a combination of real ones.

    `func` is a real-valued filter accepting ``(input, weights, output=..., cval=...,
    **kwargs)``. When ``output=None`` it returns a fresh result; otherwise it writes
    into the given array. The real and imaginary parts of the complex `output` are
    filled in place, and `output` is returned.

    Raises ValueError if `input` is real but `cval` is complex.
    """
    input_is_complex = input.dtype.kind == "c"
    weights_are_complex = weights.dtype.kind == "c"

    if not input_is_complex:
        # Real signal: filter once with each component of the kernel.
        if np.iscomplexobj(cval):
            raise ValueError("Cannot provide a complex-valued cval when the input is real.")
        func(input, weights.real, output=output.real, cval=cval, **kwargs)
        func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
        return output

    cval_re = np.real(cval)
    cval_im = np.imag(cval)

    if not weights_are_complex:
        # Real kernel: filter the real and imaginary parts of the signal separately.
        func(input.real, weights, output=output.real, cval=cval_re, **kwargs)
        func(input.imag, weights, output=output.imag, cval=cval_im, **kwargs)
        return output

    # Both complex: (a + ib)(c + id) = (ac - bd) + i(ad + bc).
    func(input.real, weights.real, output=output.real, cval=cval_re, **kwargs)
    output.real -= func(input.imag, weights.imag, output=None, cval=cval_im, **kwargs)
    func(input.real, weights.imag, output=output.imag, cval=cval_re, **kwargs)
    output.imag += func(input.imag, weights.real, output=None, cval=cval_im, **kwargs)
    return output
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _normalize_axis_index(axis, ndim):
|
|
201
|
+
# Check if `axis` is in the correct range and normalize it
|
|
202
|
+
if axis < -ndim or axis >= ndim:
|
|
203
|
+
msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
|
|
204
|
+
raise AxisError(msg)
|
|
205
|
+
|
|
206
|
+
if axis < 0:
|
|
207
|
+
axis = axis + ndim
|
|
208
|
+
return axis
|