canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +13 -9
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +705 -0
- canns/pipeline/gallery/runner.py +790 -0
- canns/pipeline/gallery/state.py +51 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- canns-0.14.0.dist-info/entry_points.txt +5 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/RECORD +0 -89
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Spatial Analysis Functions
|
|
3
|
+
|
|
4
|
+
Functions for computing spatial firing rate maps and related metrics.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from scipy import ndimage
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def compute_rate_map(
|
|
12
|
+
spike_times: np.ndarray,
|
|
13
|
+
positions: np.ndarray,
|
|
14
|
+
time_stamps: np.ndarray,
|
|
15
|
+
spatial_bins: int = 20,
|
|
16
|
+
position_range: tuple[float, float] | None = None,
|
|
17
|
+
smoothing_sigma: float = 2.0,
|
|
18
|
+
min_occupancy: float = 0.0,
|
|
19
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
|
20
|
+
"""
|
|
21
|
+
Compute 2D spatial firing rate map.
|
|
22
|
+
|
|
23
|
+
Parameters
|
|
24
|
+
----------
|
|
25
|
+
spike_times : np.ndarray
|
|
26
|
+
Spike times in seconds
|
|
27
|
+
positions : np.ndarray
|
|
28
|
+
Animal positions, shape (N, 2) where columns are (x, y) coordinates
|
|
29
|
+
time_stamps : np.ndarray
|
|
30
|
+
Time stamps for position samples
|
|
31
|
+
spatial_bins : int or tuple, optional
|
|
32
|
+
Number of spatial bins. If int, uses same for both dimensions.
|
|
33
|
+
Default is 20.
|
|
34
|
+
position_range : tuple of float, optional
|
|
35
|
+
(min, max) for position coordinates. If None, inferred from data.
|
|
36
|
+
smoothing_sigma : float, optional
|
|
37
|
+
Standard deviation of Gaussian smoothing kernel. Default is 2.0.
|
|
38
|
+
min_occupancy : float, optional
|
|
39
|
+
Minimum occupancy (seconds) for valid bins. Default is 0.0.
|
|
40
|
+
|
|
41
|
+
Returns
|
|
42
|
+
-------
|
|
43
|
+
rate_map : np.ndarray
|
|
44
|
+
2D firing rate map (Hz), shape (spatial_bins, spatial_bins)
|
|
45
|
+
occupancy_map : np.ndarray
|
|
46
|
+
Time spent in each bin (seconds)
|
|
47
|
+
x_edges : np.ndarray
|
|
48
|
+
Bin edges for x coordinate
|
|
49
|
+
y_edges : np.ndarray
|
|
50
|
+
Bin edges for y coordinate
|
|
51
|
+
|
|
52
|
+
Examples
|
|
53
|
+
--------
|
|
54
|
+
>>> # Simulate data
|
|
55
|
+
>>> time_stamps = np.linspace(0, 100, 10000)
|
|
56
|
+
>>> positions = np.column_stack([
|
|
57
|
+
... np.sin(time_stamps * 0.1),
|
|
58
|
+
... np.cos(time_stamps * 0.1)
|
|
59
|
+
... ])
|
|
60
|
+
>>> spike_times = time_stamps[::50] # Some spikes
|
|
61
|
+
>>> rate_map, occ, x_edges, y_edges = compute_rate_map(
|
|
62
|
+
... spike_times, positions, time_stamps
|
|
63
|
+
... )
|
|
64
|
+
"""
|
|
65
|
+
# Handle bin specification
|
|
66
|
+
if isinstance(spatial_bins, int):
|
|
67
|
+
n_bins_x = n_bins_y = spatial_bins
|
|
68
|
+
else:
|
|
69
|
+
n_bins_x, n_bins_y = spatial_bins
|
|
70
|
+
|
|
71
|
+
# Determine position range
|
|
72
|
+
if position_range is None:
|
|
73
|
+
x_min, x_max = positions[:, 0].min(), positions[:, 0].max()
|
|
74
|
+
y_min, y_max = positions[:, 1].min(), positions[:, 1].max()
|
|
75
|
+
else:
|
|
76
|
+
x_min, x_max = position_range
|
|
77
|
+
y_min, y_max = position_range
|
|
78
|
+
|
|
79
|
+
# Create bin edges
|
|
80
|
+
x_edges = np.linspace(x_min, x_max, n_bins_x + 1)
|
|
81
|
+
y_edges = np.linspace(y_min, y_max, n_bins_y + 1)
|
|
82
|
+
|
|
83
|
+
# Compute occupancy map
|
|
84
|
+
dt = np.median(np.diff(time_stamps))
|
|
85
|
+
occupancy_map, _, _ = np.histogram2d(positions[:, 0], positions[:, 1], bins=[x_edges, y_edges])
|
|
86
|
+
occupancy_map = occupancy_map.T * dt # Transpose to match image orientation
|
|
87
|
+
|
|
88
|
+
# Get spike positions
|
|
89
|
+
spike_positions = np.column_stack(
|
|
90
|
+
[
|
|
91
|
+
np.interp(spike_times, time_stamps, positions[:, 0]),
|
|
92
|
+
np.interp(spike_times, time_stamps, positions[:, 1]),
|
|
93
|
+
]
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Compute spike count map
|
|
97
|
+
spike_map, _, _ = np.histogram2d(
|
|
98
|
+
spike_positions[:, 0], spike_positions[:, 1], bins=[x_edges, y_edges]
|
|
99
|
+
)
|
|
100
|
+
spike_map = spike_map.T # Transpose
|
|
101
|
+
|
|
102
|
+
# Compute rate map
|
|
103
|
+
rate_map = np.zeros_like(occupancy_map)
|
|
104
|
+
valid = occupancy_map > min_occupancy
|
|
105
|
+
rate_map[valid] = spike_map[valid] / occupancy_map[valid]
|
|
106
|
+
|
|
107
|
+
# Apply Gaussian smoothing
|
|
108
|
+
if smoothing_sigma > 0:
|
|
109
|
+
rate_map = ndimage.gaussian_filter(rate_map, sigma=smoothing_sigma)
|
|
110
|
+
|
|
111
|
+
return rate_map, occupancy_map, x_edges, y_edges
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def compute_rate_map_from_binned(
    x: np.ndarray,
    y: np.ndarray,
    spike_counts: np.ndarray,
    bins: int = 35,
    min_occupancy: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute a 2D rate map from binned spike counts aligned to positions.

    Each time bin contributes one unit of occupancy to the spatial bin at
    its (x, y) position, plus its spike count to the same spatial bin.

    Parameters
    ----------
    x : np.ndarray
        X positions aligned to spike_counts (same length).
    y : np.ndarray
        Y positions aligned to spike_counts (same length).
    spike_counts : np.ndarray
        Spike counts per time bin (same length as x/y).
    bins : int, optional
        Number of spatial bins per dimension (passed to ``np.histogram2d``).
        Default is 35.
    min_occupancy : float, optional
        Spatial bins whose occupancy count is not strictly greater than this
        value get a rate of 0. Default is 0.

    Returns
    -------
    rate_map : np.ndarray
        2D rate map (spikes per time bin), rows indexing y and columns x.
    occupancy_map : np.ndarray
        Occupancy counts per bin, same layout as ``rate_map``.
    x_edges : np.ndarray
        Bin edges for x coordinate.
    y_edges : np.ndarray
        Bin edges for y coordinate.

    Raises
    ------
    ValueError
        If the shortest of x, y and spike_counts is empty.

    Notes
    -----
    Inputs of unequal length are truncated to the shortest one, and samples
    where any of x, y or spike_counts is non-finite are discarded.
    """
    xs, ys, counts = (
        np.asarray(arr, dtype=float).ravel() for arr in (x, y, spike_counts)
    )

    n_samples = min(xs.size, ys.size, counts.size)
    if not n_samples:
        raise ValueError("x, y, and spike_counts must be non-empty and aligned.")

    # Align to the common length, then keep only fully finite samples.
    xs, ys, counts = xs[:n_samples], ys[:n_samples], counts[:n_samples]
    finite = np.isfinite(xs) & np.isfinite(ys) & np.isfinite(counts)
    xs, ys, counts = xs[finite], ys[finite], counts[finite]

    occ_xy, x_edges, y_edges = np.histogram2d(xs, ys, bins=bins)
    spikes_xy = np.histogram2d(xs, ys, bins=[x_edges, y_edges], weights=counts)[0]

    # histogram2d indexes [x, y]; transpose so rows follow y (image layout).
    occupancy_map = occ_xy.T
    spike_map = spikes_xy.T

    rate_map = np.zeros_like(occupancy_map)
    visited = occupancy_map > float(min_occupancy)
    np.divide(spike_map, occupancy_map, out=rate_map, where=visited)

    return rate_map, occupancy_map, x_edges, y_edges
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def compute_spatial_information(
|
|
179
|
+
rate_map: np.ndarray, occupancy_map: np.ndarray, mean_rate: float | None = None
|
|
180
|
+
) -> float:
|
|
181
|
+
"""
|
|
182
|
+
Compute spatial information score (bits per spike).
|
|
183
|
+
|
|
184
|
+
Spatial information quantifies how much information about the animal's
|
|
185
|
+
location is conveyed by each spike.
|
|
186
|
+
|
|
187
|
+
Parameters
|
|
188
|
+
----------
|
|
189
|
+
rate_map : np.ndarray
|
|
190
|
+
2D firing rate map (Hz)
|
|
191
|
+
occupancy_map : np.ndarray
|
|
192
|
+
Time spent in each bin (seconds)
|
|
193
|
+
mean_rate : float, optional
|
|
194
|
+
Mean firing rate. If None, computed from rate_map and occupancy_map.
|
|
195
|
+
|
|
196
|
+
Returns
|
|
197
|
+
-------
|
|
198
|
+
spatial_info : float
|
|
199
|
+
Spatial information in bits per spike
|
|
200
|
+
|
|
201
|
+
Examples
|
|
202
|
+
--------
|
|
203
|
+
>>> rate_map = np.random.rand(20, 20) * 10
|
|
204
|
+
>>> occupancy_map = np.ones((20, 20))
|
|
205
|
+
>>> info = compute_spatial_information(rate_map, occupancy_map)
|
|
206
|
+
|
|
207
|
+
Notes
|
|
208
|
+
-----
|
|
209
|
+
Formula: I = Σ_i p_i * (r_i / r_mean) * log2(r_i / r_mean)
|
|
210
|
+
where:
|
|
211
|
+
- p_i is probability of occupancy in bin i
|
|
212
|
+
- r_i is firing rate in bin i
|
|
213
|
+
- r_mean is mean firing rate
|
|
214
|
+
|
|
215
|
+
References
|
|
216
|
+
----------
|
|
217
|
+
Skaggs et al. (1993). "An information-theoretic approach to deciphering
|
|
218
|
+
the hippocampal code." NIPS.
|
|
219
|
+
"""
|
|
220
|
+
# Compute occupancy probability
|
|
221
|
+
total_time = np.sum(occupancy_map)
|
|
222
|
+
if total_time == 0:
|
|
223
|
+
return 0.0
|
|
224
|
+
|
|
225
|
+
prob_occupancy = occupancy_map / total_time
|
|
226
|
+
|
|
227
|
+
# Compute mean firing rate if not provided
|
|
228
|
+
if mean_rate is None:
|
|
229
|
+
mean_rate = np.sum(rate_map * occupancy_map) / total_time
|
|
230
|
+
|
|
231
|
+
if mean_rate == 0:
|
|
232
|
+
return 0.0
|
|
233
|
+
|
|
234
|
+
# Compute spatial information
|
|
235
|
+
spatial_info = 0.0
|
|
236
|
+
for i in range(rate_map.shape[0]):
|
|
237
|
+
for j in range(rate_map.shape[1]):
|
|
238
|
+
if prob_occupancy[i, j] > 0 and rate_map[i, j] > 0:
|
|
239
|
+
ratio = rate_map[i, j] / mean_rate
|
|
240
|
+
spatial_info += prob_occupancy[i, j] * ratio * np.log2(ratio)
|
|
241
|
+
|
|
242
|
+
return spatial_info
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def compute_field_statistics(
    rate_map: np.ndarray, threshold: float = 0.2, min_area: int = 9
) -> dict:
    """
    Extract firing field statistics from a rate map.

    Bins whose rate is strictly above ``threshold * max(rate_map)`` are
    grouped into connected components; components with at least
    ``min_area`` bins are reported as firing fields.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map (Hz)
    threshold : float, optional
        Threshold as fraction of peak rate. Default is 0.2 (20% of peak).
    min_area : int, optional
        Minimum field size in pixels (bins). Default is 9.

    Returns
    -------
    stats : dict
        Dictionary with:
        - num_fields: number of detected fields
        - field_sizes: list of field areas, as reported by ``regionprops``
        - field_peaks: list of peak firing rates within each field
        - field_centers: list of field centroids as reported by
          ``regionprops`` (documented here as (x, y); NOTE(review): confirm
          the helper's ordering, image-style centroids are often (row, col))

    Examples
    --------
    >>> rate_map = np.random.rand(50, 50) * 10
    >>> stats = compute_field_statistics(rate_map)
    >>> print(f"Found {stats['num_fields']} firing fields")
    """
    from ..utils.image_processing import label_connected_components, regionprops

    cutoff = threshold * np.max(rate_map)
    labelled, _n_components = label_connected_components(rate_map > cutoff)

    # Keep only components large enough to count as a firing field.
    fields = [
        region
        for region in regionprops(labelled, intensity_image=rate_map)
        if region.area >= min_area
    ]

    sizes, peaks, centers = [], [], []
    for region in fields:
        rows, cols = region.coords[:, 0], region.coords[:, 1]
        sizes.append(region.area)
        peaks.append(np.max(rate_map[rows, cols]))
        centers.append(region.centroid)

    return {
        "num_fields": len(fields),
        "field_sizes": sizes,
        "field_peaks": peaks,
        "field_centers": centers,
    }
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def compute_grid_spacing(rate_map: np.ndarray, method: str = "autocorr") -> float | None:
|
|
303
|
+
"""
|
|
304
|
+
Estimate grid spacing from a rate map.
|
|
305
|
+
|
|
306
|
+
Parameters
|
|
307
|
+
----------
|
|
308
|
+
rate_map : np.ndarray
|
|
309
|
+
2D firing rate map
|
|
310
|
+
method : str, optional
|
|
311
|
+
Method for estimation: 'autocorr' (default) or 'fft'
|
|
312
|
+
|
|
313
|
+
Returns
|
|
314
|
+
-------
|
|
315
|
+
spacing : float or None
|
|
316
|
+
Estimated grid spacing in bins, or None if cannot be determined
|
|
317
|
+
|
|
318
|
+
Notes
|
|
319
|
+
-----
|
|
320
|
+
This is a simplified implementation. For full grid analysis,
|
|
321
|
+
use GridnessAnalyzer.
|
|
322
|
+
"""
|
|
323
|
+
if method == "autocorr":
|
|
324
|
+
from ..utils.image_processing import find_regional_maxima
|
|
325
|
+
from .grid_cells import compute_2d_autocorrelation
|
|
326
|
+
|
|
327
|
+
# Compute autocorrelation
|
|
328
|
+
autocorr = compute_2d_autocorrelation(rate_map)
|
|
329
|
+
|
|
330
|
+
# Find peaks
|
|
331
|
+
maxima = find_regional_maxima(autocorr)
|
|
332
|
+
|
|
333
|
+
# Find distances from center
|
|
334
|
+
center = np.array(autocorr.shape) // 2
|
|
335
|
+
coords = np.argwhere(maxima)
|
|
336
|
+
|
|
337
|
+
if len(coords) < 2:
|
|
338
|
+
return None
|
|
339
|
+
|
|
340
|
+
# Remove center peak
|
|
341
|
+
distances = np.linalg.norm(coords - center, axis=1)
|
|
342
|
+
non_center = distances > 5 # Exclude central peak
|
|
343
|
+
if np.sum(non_center) == 0:
|
|
344
|
+
return None
|
|
345
|
+
|
|
346
|
+
# Median distance to peaks
|
|
347
|
+
spacing = np.median(distances[non_center])
|
|
348
|
+
return float(spacing)
|
|
349
|
+
|
|
350
|
+
elif method == "fft":
|
|
351
|
+
# FFT-based spacing estimation
|
|
352
|
+
fft = np.fft.fft2(rate_map)
|
|
353
|
+
fft_shift = np.fft.fftshift(fft)
|
|
354
|
+
power = np.abs(fft_shift) ** 2
|
|
355
|
+
|
|
356
|
+
# Find peak in power spectrum (excluding DC component)
|
|
357
|
+
center = np.array(power.shape) // 2
|
|
358
|
+
power[center[0] - 2 : center[0] + 3, center[1] - 2 : center[1] + 3] = 0
|
|
359
|
+
|
|
360
|
+
peak_idx = np.unravel_index(np.argmax(power), power.shape)
|
|
361
|
+
distance = np.linalg.norm(np.array(peak_idx) - center)
|
|
362
|
+
|
|
363
|
+
if distance > 0:
|
|
364
|
+
spacing = power.shape[0] / distance
|
|
365
|
+
return float(spacing)
|
|
366
|
+
|
|
367
|
+
return None
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
if __name__ == "__main__":
    print("Testing spatial analysis functions...")

    # Simulate trajectory and spikes
    print("\nSimulating data...")
    # Fixed seed so repeated runs of this demo print identical numbers.
    rng = np.random.default_rng(0)
    time_stamps = np.linspace(0, 100, 10000)  # 100 seconds
    t = time_stamps

    # Circular trajectory of radius 0.5 around the origin
    positions = np.column_stack([0.5 * np.sin(t * 0.1), 0.5 * np.cos(t * 0.1)])

    # Place cell: fires at (0, 0.5)
    place_field_center = np.array([0.0, 0.5])
    place_field_width = 0.15

    distances = np.linalg.norm(positions - place_field_center, axis=1)
    firing_prob = np.exp(-(distances**2) / (2 * place_field_width**2))
    firing_prob = firing_prob * 0.1  # Max 10% per time bin

    spike_mask = rng.random(len(t)) < firing_prob
    spike_times = t[spike_mask]

    print(f"Generated {len(spike_times)} spikes")

    # Compute rate map
    print("\nComputing rate map...")
    rate_map, occupancy, x_edges, y_edges = compute_rate_map(
        spike_times,
        positions,
        time_stamps,
        spatial_bins=20,
        position_range=(-0.75, 0.75),
        smoothing_sigma=1.5,
    )

    print(f"Rate map shape: {rate_map.shape}")
    print(f"Peak firing rate: {rate_map.max():.2f} Hz")
    # Spikes per second over the session. An unweighted average over visited
    # bins would over-weight rarely visited bins.
    session_duration = time_stamps[-1] - time_stamps[0]
    print(f"Mean firing rate: {len(spike_times) / session_duration:.2f} Hz")

    # Compute spatial information
    print("\nComputing spatial information...")
    spatial_info = compute_spatial_information(rate_map, occupancy)
    print(f"Spatial information: {spatial_info:.3f} bits/spike")

    # Extract field statistics
    print("\nExtracting firing fields...")
    field_stats = compute_field_statistics(rate_map, threshold=0.3)
    print(f"Number of fields: {field_stats['num_fields']}")
    for i, (size, peak, center) in enumerate(
        zip(
            field_stats["field_sizes"],
            field_stats["field_peaks"],
            field_stats["field_centers"],
            strict=False,
        )
    ):
        print(
            f"  Field {i + 1}: size={size} pixels, peak={peak:.2f} Hz, "
            f"center=({center[0]:.1f}, {center[1]:.1f})"
        )

    print("\nSpatial analysis tests completed!")
|