canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +13 -9
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +705 -0
- canns/pipeline/gallery/runner.py +790 -0
- canns/pipeline/gallery/state.py +51 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- canns-0.14.0.dist-info/entry_points.txt +5 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/RECORD +0 -89
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,633 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Grid Cell Classification
|
|
3
|
+
|
|
4
|
+
Implementation of gridness score algorithm for identifying and characterizing grid cells.
|
|
5
|
+
|
|
6
|
+
Based on the MATLAB gridnessScore.m implementation.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import warnings
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from ..utils.circular_stats import circ_dist2
|
|
15
|
+
from ..utils.correlation import autocorrelation_2d, pearson_correlation
|
|
16
|
+
from ..utils.geometry import fit_ellipse, polyarea, squared_distance, wrap_to_pi
|
|
17
|
+
from ..utils.image_processing import (
|
|
18
|
+
dilate_image,
|
|
19
|
+
find_contours_at_level,
|
|
20
|
+
find_regional_maxima,
|
|
21
|
+
gaussian_filter_2d,
|
|
22
|
+
label_connected_components,
|
|
23
|
+
regionprops,
|
|
24
|
+
rotate_image,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
|
|
29
|
+
class GridnessResult:
|
|
30
|
+
"""
|
|
31
|
+
Results from gridness score computation.
|
|
32
|
+
|
|
33
|
+
Attributes
|
|
34
|
+
----------
|
|
35
|
+
score : float
|
|
36
|
+
Gridness score (range -2 to 2, typical grid cells: 0.3-1.3)
|
|
37
|
+
spacing : np.ndarray
|
|
38
|
+
Array of 3 grid field spacings (distances from center)
|
|
39
|
+
orientation : np.ndarray
|
|
40
|
+
Array of 3 grid field orientations (angles in degrees)
|
|
41
|
+
ellipse : np.ndarray
|
|
42
|
+
Fitted ellipse parameters [cx, cy, rx, ry, theta]
|
|
43
|
+
ellipse_theta_deg : float
|
|
44
|
+
Ellipse orientation in degrees [0, 180]
|
|
45
|
+
center_radius : float
|
|
46
|
+
Radius of the central autocorrelation field
|
|
47
|
+
optimal_radius : float
|
|
48
|
+
Radius at which gridness score is maximized
|
|
49
|
+
peak_locations : np.ndarray
|
|
50
|
+
Coordinates of detected grid peaks (N x 2 array)
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
score: float
|
|
54
|
+
spacing: np.ndarray
|
|
55
|
+
orientation: np.ndarray
|
|
56
|
+
ellipse: np.ndarray
|
|
57
|
+
ellipse_theta_deg: float
|
|
58
|
+
center_radius: float
|
|
59
|
+
optimal_radius: float
|
|
60
|
+
peak_locations: np.ndarray | None = None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class GridnessAnalyzer:
    """
    Analyzer for computing gridness scores from spatial autocorrelograms.

    This implements the rotation-correlation method for quantifying hexagonal
    grid patterns in neural firing rate maps.

    Parameters
    ----------
    threshold : float, optional
        Normalized threshold for contour detection (0-1). Default is 0.2.
    min_orientation : float, optional
        Minimum angular difference between fields (degrees). Default is 15.
    min_center_radius : int, optional
        Minimum center field radius in pixels. Default is 2.
    num_gridness_radii : int, optional
        Number of adjacent radii to average for gridness score. Default is 3.

    Examples
    --------
    >>> analyzer = GridnessAnalyzer()
    >>> # Assume we have a 2D rate map
    >>> autocorr = compute_2d_autocorrelation(rate_map)
    >>> result = analyzer.compute_gridness_score(autocorr)
    >>> print(f"Gridness score: {result.score:.3f}")
    >>> print(f"Grid spacing: {result.spacing}")

    Notes
    -----
    Based on gridnessScore.m from the MATLAB codebase.

    The gridness score algorithm computes correlations between the autocorrelogram
    and rotated versions at 30°, 60°, 90°, 120°, and 150°. The score is:

        min(r_60°, r_120°) - max(r_30°, r_90°, r_150°)

    This exploits the 60° rotational symmetry of hexagonal grids.
    """

    def __init__(
        self,
        threshold: float = 0.2,
        min_orientation: float = 15.0,
        min_center_radius: int = 2,
        num_gridness_radii: int = 3,
    ):
        self.threshold = threshold
        self.min_orientation = min_orientation
        self.min_center_radius = min_center_radius
        self.num_gridness_radii = num_gridness_radii

    def compute_gridness_score(self, autocorr: np.ndarray) -> "GridnessResult":
        """
        Compute gridness score from a 2D autocorrelogram.

        Parameters
        ----------
        autocorr : np.ndarray
            2D autocorrelogram of a firing rate map. Non-finite entries are
            ignored when normalizing (as MATLAB's ``max`` ignores NaN).

        Returns
        -------
        result : GridnessResult
            Complete gridness analysis results. A NaN-filled result is returned
            for degenerate input (a dimension of size <= 1, or no finite positive
            maximum) and when the central field is not detected or its radius is
            out of range.

        Raises
        ------
        ValueError
            If ``autocorr`` is not two-dimensional.
        """
        autocorr = np.asarray(autocorr)
        if autocorr.ndim != 2:
            raise ValueError(f"autocorr must be a 2D array, got ndim={autocorr.ndim}")
        if min(autocorr.shape) <= 1:
            # Degenerate case: a single row/column (or empty map) has no 2D structure.
            return self._create_nan_result()

        # Normalize autocorrelogram so its peak is 1. Use the finite maximum so
        # NaN bins, if present, do not turn the whole map into NaN.
        finite = np.isfinite(autocorr)
        if not finite.any():
            return self._create_nan_result()
        peak = autocorr[finite].max()
        if peak <= 0:
            # Division would produce inf/NaN or flip the sign of the map.
            return self._create_nan_result()
        autocorr = autocorr / peak

        # Find central field radius using contour detection
        center_radius = self._find_center_radius(autocorr, self.threshold)
        if center_radius < self.min_center_radius:
            return self._create_nan_result()

        # Get autocorr dimensions (1-based "half" sizes; centre index is half - 1)
        half_height, half_width = np.array(autocorr.shape) // 2 + 1
        autocorr_rad = min(half_height, half_width)
        if center_radius >= autocorr_rad:
            return self._create_nan_result()

        # Compute rotation correlations and gridness score
        gridness_scores, rad_steps = self._compute_rotation_correlations(
            autocorr, center_radius, autocorr_rad, half_height, half_width
        )

        # Find optimal radius with maximum gridness
        score, optimal_radius = self._find_optimal_gridness(gridness_scores, rad_steps)

        # Extract grid statistics (spacing, orientation, ellipse fit)
        grid_stats = self._extract_grid_statistics(
            autocorr, optimal_radius, half_height, half_width
        )

        return GridnessResult(
            score=score,
            spacing=grid_stats["spacing"],
            orientation=grid_stats["orientation"],
            ellipse=grid_stats["ellipse"],
            ellipse_theta_deg=grid_stats["ellipse_theta_deg"],
            center_radius=center_radius,
            optimal_radius=optimal_radius,
            peak_locations=grid_stats.get("peak_locations"),
        )

    def _find_center_radius(self, autocorr: np.ndarray, threshold: float) -> float:
        """
        Find the radius of the central autocorrelation field.

        Uses contour detection at the specified threshold level and converts the
        area of the contour nearest the centre into an equal-area circle radius.

        Parameters
        ----------
        autocorr : np.ndarray
            Normalized autocorrelogram
        threshold : float
            Contour detection threshold

        Returns
        -------
        radius : float
            Radius of central field in pixels (floored), or -1.0 if no contour
            could be found.
        """
        half_height, half_width = np.array(autocorr.shape) // 2 + 1

        # Find contours at threshold level
        contours = find_contours_at_level(autocorr, threshold)
        if len(contours) == 0:
            return -1.0

        # Find contour closest to center
        # Note: find_contours returns (row, col) = (y, x)
        center_point = np.array([half_height - 1, half_width - 1])  # -1 for 0-indexing

        min_dist = np.inf
        center_contour = None
        for contour in contours:
            # Compare contours by the mean position of their vertices
            mean_pos = np.mean(contour, axis=0)
            dist = np.linalg.norm(mean_pos - center_point)
            if dist < min_dist:
                min_dist = dist
                center_contour = contour

        if center_contour is None:
            return -1.0

        # Compute area of central field contour
        # Note: contour is (row, col) = (y, x), need to swap for polyarea
        area = polyarea(center_contour[:, 1], center_contour[:, 0])

        # Radius from area: r = sqrt(area / pi)
        radius = np.floor(np.sqrt(area / np.pi))
        return float(radius)

    def _compute_rotation_correlations(
        self,
        autocorr: np.ndarray,
        center_radius: float,
        autocorr_rad: float,
        half_height: int,
        half_width: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute correlations between autocorr and rotated versions.

        Rotates at 30°, 60°, 90°, 120°, 150° and computes Pearson correlation
        for expanding circular regions (central field excluded).

        Parameters
        ----------
        autocorr : np.ndarray
            Normalized autocorrelogram
        center_radius : float
            Radius of central field to exclude
        autocorr_rad : float
            Radius of autocorrelogram
        half_height : int
            Half height of autocorr
        half_width : int
            Half width of autocorr

        Returns
        -------
        gridness_scores : np.ndarray
            ``(num_steps, 2)`` array: column 0 is the gridness score at each
            radius, column 1 the radius itself.
        rad_steps : np.ndarray
            Radius values tested
        """
        # Define rotation angles
        rot_angles_deg = np.array([30, 60, 90, 120, 150])
        n_rot = len(rot_angles_deg)

        # Create distance mask from center (rr = column index, cc = row index)
        rr, cc = np.meshgrid(np.arange(autocorr.shape[1]), np.arange(autocorr.shape[0]))
        dist_from_center = np.sqrt((cc - (half_height - 1)) ** 2 + (rr - (half_width - 1)) ** 2)

        # Mask for excluding central field
        center_exclusion_mask = dist_from_center > center_radius

        # Define radius steps
        rad_steps = np.arange(center_radius + 1, autocorr_rad + 1)
        num_steps = len(rad_steps)

        if num_steps == 0:
            # Keep the (N, 2) layout the caller indexes with [:, 0].
            return np.array([[np.nan, center_radius]]), np.array([center_radius])

        # Pre-compute all rotated versions
        rotated_autocorr = np.zeros((autocorr.shape[0], autocorr.shape[1], n_rot))
        for i, angle in enumerate(rot_angles_deg):
            rotated_autocorr[:, :, i] = rotate_image(
                autocorr, angle, output_shape=autocorr.shape, method="bilinear"
            )

        # Vectorize autocorr and rotated versions
        autocorr_vec = autocorr[center_exclusion_mask].ravel()
        rotated_vec = rotated_autocorr[center_exclusion_mask, :].reshape(-1, n_rot)
        dist_vec = dist_from_center[center_exclusion_mask].ravel()

        # Compute correlations at each radius
        gridness_scores = np.zeros((num_steps, 2))
        gridness_scores[:, 1] = rad_steps

        for i, radius in enumerate(rad_steps):
            # Select points within this radius
            in_radius = dist_vec < radius
            if np.sum(in_radius) < 10:  # Need minimum points
                gridness_scores[i, 0] = np.nan
                continue

            ref_circle = autocorr_vec[in_radius]
            rot_circles = rotated_vec[in_radius, :]

            # Compute Pearson correlations
            rot_corr = pearson_correlation(ref_circle[:, np.newaxis], rot_circles)

            # Gridness score: min(r_60, r_120) - max(r_30, r_90, r_150)
            # rot_corr indices: [30°, 60°, 90°, 120°, 150°] = [0, 1, 2, 3, 4]
            score = min(rot_corr[1], rot_corr[3]) - max(rot_corr[0], rot_corr[2], rot_corr[4])
            gridness_scores[i, 0] = score

        return gridness_scores, rad_steps

    def _find_optimal_gridness(
        self, gridness_scores: np.ndarray, rad_steps: np.ndarray
    ) -> tuple[float, float]:
        """
        Find the radius with maximum gridness score.

        Averages over ``num_gridness_radii`` adjacent radii for stability.

        Parameters
        ----------
        gridness_scores : np.ndarray
            Gridness scores at each radius (N x 2 array; column 0 is the score)
        rad_steps : np.ndarray
            Radius values

        Returns
        -------
        max_score : float
            Maximum (window-averaged) gridness score, NaN if no score is valid
        optimal_radius : float
            Radius at maximum score (centre of the best window)
        """
        scores = gridness_scores[:, 0]
        num_steps = len(scores)
        n_avg = max(1, int(self.num_gridness_radii))

        if num_steps < n_avg:
            # Not enough radii for a full window: just take the best single radius.
            if np.isnan(scores).all():
                fallback = float(rad_steps[0]) if len(rad_steps) > 0 else 0.0
                return np.nan, fallback
            max_idx = int(np.nanargmax(scores))
            return float(scores[max_idx]), float(rad_steps[max_idx])

        # Average over adjacent radii; all-NaN windows stay NaN (no RuntimeWarning).
        num_windows = num_steps - n_avg + 1
        mean_scores = np.full(num_windows, np.nan)
        for i in range(num_windows):
            window = scores[i : i + n_avg]
            if not np.isnan(window).all():
                mean_scores[i] = np.nanmean(window)

        # Optimal radius is at center of window
        center_offset = (n_avg - 1) // 2
        if np.isnan(mean_scores).all():
            # np.nanargmax would raise "All-NaN slice encountered" here.
            return np.nan, float(rad_steps[center_offset])

        max_idx = int(np.nanargmax(mean_scores))
        max_score = mean_scores[max_idx]
        optimal_radius = rad_steps[max_idx + center_offset]
        return float(max_score), float(optimal_radius)

    def _extract_grid_statistics(
        self, autocorr: np.ndarray, optimal_radius: float, half_height: int, half_width: int
    ) -> dict:
        """
        Extract grid field statistics (spacing, orientation, ellipse).

        Finds peaks in the autocorrelogram and fits an ellipse to them.

        Parameters
        ----------
        autocorr : np.ndarray
            Autocorrelogram
        optimal_radius : float
            Optimal radius for analysis
        half_height : int
            Half height of autocorr
        half_width : int
            Half width of autocorr

        Returns
        -------
        stats : dict
            Dictionary with spacing, orientation, ellipse parameters and peak
            locations (NaN-filled if too few peaks are found).
        """
        # Create distance mask
        rr, cc = np.meshgrid(np.arange(autocorr.shape[1]), np.arange(autocorr.shape[0]))
        dist_from_center = np.sqrt((cc - (half_height - 1)) ** 2 + (rr - (half_width - 1)) ** 2)

        # Define search window around optimal radius
        w = optimal_radius / 4
        mask_outer = dist_from_center < (optimal_radius + w)

        # Smooth autocorr to eliminate spurious maxima
        autocorr_sm = gaussian_filter_2d(
            autocorr, sigma=optimal_radius / (2 * np.pi) / 2, mode="reflect"
        )

        # Apply mask
        masked_autocorr = mask_outer * autocorr_sm

        # Find regional maxima
        maxima_map = find_regional_maxima(masked_autocorr, connectivity=1)

        # Dilate to eliminate fragmentation
        maxima_map_dilated = dilate_image(maxima_map, selem_type="square", selem_size=3)

        # Label connected components
        labels, num_labels = label_connected_components(maxima_map_dilated, connectivity=2)

        if num_labels < 5:
            # stacklevel=3 attributes the warning to the caller of compute_gridness_score.
            warnings.warn("Not enough grid peaks found for statistics", stacklevel=3)
            return self._create_nan_stats()

        # Get region properties
        props = regionprops(labels)

        # Extract centroids
        centroids = np.array([prop.centroid for prop in props])  # (N, 2) array of (row, col)

        # Convert to (x, y) coordinates
        centers_of_mass = centroids[:, ::-1]  # Swap to (col, row) = (x, y)

        # Compute orientations relative to center
        center_point = np.array([half_width - 1, half_height - 1])
        dx = centers_of_mass[:, 0] - center_point[0]
        dy = centers_of_mass[:, 1] - center_point[1]
        orientations = np.arctan2(dy, dx)

        # Compute distances to center
        peaks_to_center = squared_distance(centers_of_mass.T, center_point[:, np.newaxis]).ravel()

        # Remove the central peak. Identify it by distance rather than by
        # orientation == 0: arctan2 is also exactly 0 for a genuine field lying on
        # the positive x-axis, which must be kept.
        not_central = np.hypot(dx, dy) >= 1.0
        if not not_central.all():
            orientations = orientations[not_central]
            centers_of_mass = centers_of_mass[not_central]
            peaks_to_center = peaks_to_center[not_central]

        # Filter fields with similar orientations
        orient_dist = circ_dist2(orientations)
        close_fields = np.abs(orient_dist) < np.deg2rad(self.min_orientation)
        np.fill_diagonal(close_fields, False)
        close_fields = np.triu(close_fields)  # Keep upper triangle only

        rows, cols = np.where(close_fields)
        # For each close pair, drop the field farther from the centre.
        to_delete = [
            row if peaks_to_center[row] > peaks_to_center[col] else col
            for row, col in zip(rows, cols, strict=True)
        ]
        to_delete = np.unique(np.asarray(to_delete, dtype=int))
        if to_delete.size > 0:
            mask = np.ones(len(orientations), dtype=bool)
            mask[to_delete] = False
            orientations = orientations[mask]
            centers_of_mass = centers_of_mass[mask]
            peaks_to_center = peaks_to_center[mask]

        if len(centers_of_mass) < 4:
            warnings.warn("Not enough grid peaks after filtering", stacklevel=3)
            return self._create_nan_stats()

        # Sort by distance to center and keep 6 closest
        sort_idx = np.argsort(peaks_to_center)
        centers_of_mass = centers_of_mass[sort_idx][:6]

        # Compute final orientations and spacings
        orientations_deg = np.rad2deg(
            np.arctan2(
                centers_of_mass[:, 1] - center_point[1], centers_of_mass[:, 0] - center_point[0]
            )
        )
        spacings = np.sqrt(
            (centers_of_mass[:, 0] - center_point[0]) ** 2
            + (centers_of_mass[:, 1] - center_point[1]) ** 2
        )

        # Fit ellipse to peaks (best-effort: a failed fit yields NaN parameters)
        try:
            ellipse = fit_ellipse(centers_of_mass[:, 0], centers_of_mass[:, 1])
            ellipse_theta_deg = np.rad2deg(wrap_to_pi(ellipse[4]) + np.pi)
        except Exception:
            ellipse = np.full(5, np.nan)
            ellipse_theta_deg = np.nan

        # Select the field with the smallest |orientation| and the two fields whose
        # orientation is closest to it
        abs_orient = np.abs(orientations_deg)
        orient_sort_idx = np.argsort(abs_orient)
        orient_sort_idx2 = np.argsort(
            np.abs(orientations_deg - orientations_deg[orient_sort_idx[0]])
        )

        final_idx = orient_sort_idx2[:3]
        orientations_deg = orientations_deg[final_idx]
        spacings = spacings[final_idx]

        # Sort by orientation
        sort_idx = np.argsort(orientations_deg)
        orientations_deg = orientations_deg[sort_idx]
        spacings = spacings[sort_idx]

        return {
            "spacing": spacings,
            "orientation": orientations_deg,
            "ellipse": ellipse,
            "ellipse_theta_deg": ellipse_theta_deg,
            "peak_locations": centers_of_mass,
        }

    def _create_nan_result(self) -> "GridnessResult":
        """Create a result with NaN values for failed analysis."""
        return GridnessResult(
            score=np.nan,
            spacing=np.full(3, np.nan),
            orientation=np.full(3, np.nan),
            ellipse=np.full(5, np.nan),
            ellipse_theta_deg=np.nan,
            center_radius=0,
            optimal_radius=np.nan,
        )

    def _create_nan_stats(self) -> dict:
        """Create statistics dict with NaN values (same keys as a successful run)."""
        return {
            "spacing": np.full(3, np.nan),
            "orientation": np.full(3, np.nan),
            "ellipse": np.full(5, np.nan),
            "ellipse_theta_deg": np.nan,
            "peak_locations": None,
        }
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def compute_2d_autocorrelation(rate_map: np.ndarray, overlap: float = 0.8) -> np.ndarray:
    """
    Return the normalized 2D spatial autocorrelogram of a firing rate map.

    Convenience wrapper that forwards to ``autocorrelation_2d`` from the
    correlation utilities, always requesting a normalized result.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map.
    overlap : float, optional
        Overlap fraction in [0, 1] passed through to ``autocorrelation_2d``.
        Default is 0.8.

    Returns
    -------
    autocorr : np.ndarray
        2D autocorrelogram as produced by ``autocorrelation_2d``.

    Examples
    --------
    >>> rate_map = np.random.rand(50, 50)
    >>> autocorr = compute_2d_autocorrelation(rate_map)

    Notes
    -----
    Based on autocorrelation.m from the MATLAB codebase.
    """
    normalized_autocorr = autocorrelation_2d(rate_map, overlap=overlap, normalize=True)
    return normalized_autocorr
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
if __name__ == "__main__":
    # Smoke test on a synthetic hexagonal pattern. Because this module uses
    # relative imports, run it as part of the package (``python -m ...``).
    print("Testing GridnessAnalyzer...")

    print("\nCreating synthetic grid pattern...")
    axis = np.linspace(-2, 2, 100)
    xx, yy = np.meshgrid(axis, axis)

    # Superpose three plane waves whose wave vectors are 60° apart; their sum
    # has hexagonal symmetry. Each wave has period 0.4 in ``axis`` units.
    wavenumber = 2 * np.pi / 0.4
    wave_angles = (0, np.pi / 3, 2 * np.pi / 3)
    grid_pattern = (
        sum(np.cos(wavenumber * (xx * np.cos(a) + yy * np.sin(a))) for a in wave_angles) / 3
    )

    # Shift and scale into a strictly positive rate map. The pattern lies in
    # [-0.5, 1], so the resulting rates span roughly 4-10 Hz.
    rate_map = (grid_pattern + 1.5) / 2.5 * 10

    print(f"Rate map shape: {rate_map.shape}")
    print(f"Rate map range: [{rate_map.min():.2f}, {rate_map.max():.2f}] Hz")

    print("\nComputing autocorrelation...")
    autocorr = compute_2d_autocorrelation(rate_map)
    print(f"Autocorr shape: {autocorr.shape}")

    print("\nComputing gridness score...")
    analyzer = GridnessAnalyzer()
    result = analyzer.compute_gridness_score(autocorr)

    report_lines = [
        "\nResults:",
        f" Gridness score: {result.score:.3f}",
        f" Grid spacing: {result.spacing}",
        f" Grid orientation: {result.orientation}°",
        f" Center radius: {result.center_radius}",
        f" Optimal radius: {result.optimal_radius}",
    ]
    print("\n".join(report_lines))

    ellipse_is_valid = not np.any(np.isnan(result.ellipse))
    if ellipse_is_valid:
        cx, cy, rx, ry = result.ellipse[:4]
        print(f" Ellipse center: ({cx:.1f}, {cy:.1f})")
        print(f" Ellipse radii: ({rx:.1f}, {ry:.1f})")
        print(f" Ellipse angle: {result.ellipse_theta_deg:.1f}°")

    print("\nGridnessAnalyzer test completed!")
|