canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,633 @@
1
+ """
2
+ Grid Cell Classification
3
+
4
+ Implementation of gridness score algorithm for identifying and characterizing grid cells.
5
+
6
+ Based on the MATLAB gridnessScore.m implementation.
7
+ """
8
+
9
+ import warnings
10
+ from dataclasses import dataclass
11
+
12
+ import numpy as np
13
+
14
+ from ..utils.circular_stats import circ_dist2
15
+ from ..utils.correlation import autocorrelation_2d, pearson_correlation
16
+ from ..utils.geometry import fit_ellipse, polyarea, squared_distance, wrap_to_pi
17
+ from ..utils.image_processing import (
18
+ dilate_image,
19
+ find_contours_at_level,
20
+ find_regional_maxima,
21
+ gaussian_filter_2d,
22
+ label_connected_components,
23
+ regionprops,
24
+ rotate_image,
25
+ )
26
+
27
+
28
@dataclass
class GridnessResult:
    """
    Container for the outputs of a gridness score computation.

    Instances are produced by ``GridnessAnalyzer.compute_gridness_score``.
    When the analysis fails, the numeric fields are filled with NaN.

    Attributes
    ----------
    score : float
        Gridness score (range -2 to 2, typical grid cells: 0.3-1.3).
        NaN when the analysis failed.
    spacing : np.ndarray
        Array of 3 grid field spacings (distances from center)
    orientation : np.ndarray
        Array of 3 grid field orientations (angles in degrees)
    ellipse : np.ndarray
        Fitted ellipse parameters [cx, cy, rx, ry, theta]
    ellipse_theta_deg : float
        Ellipse orientation in degrees. Documented upstream as [0, 180];
        NOTE(review): the exact range depends on ``wrap_to_pi`` and
        ``fit_ellipse`` -- confirm.
    center_radius : float
        Radius of the central autocorrelation field. Set to 0 in the
        NaN result returned for failed analyses.
    optimal_radius : float
        Radius at which gridness score is maximized
    peak_locations : np.ndarray or None
        Coordinates of detected grid peaks (N x 2 array), or None when
        peak statistics could not be extracted.
    """

    score: float
    spacing: np.ndarray
    orientation: np.ndarray
    ellipse: np.ndarray
    ellipse_theta_deg: float
    center_radius: float
    optimal_radius: float
    # Optional: stays None for failed analyses and when too few peaks are found.
    peak_locations: np.ndarray | None = None
61
+
62
+
63
class GridnessAnalyzer:
    """
    Analyzer for computing gridness scores from spatial autocorrelograms.

    This implements the rotation-correlation method for quantifying hexagonal
    grid patterns in neural firing rate maps.

    Parameters
    ----------
    threshold : float, optional
        Normalized threshold for contour detection (0-1). Default is 0.2.
    min_orientation : float, optional
        Minimum angular difference between fields (degrees). Default is 15.
    min_center_radius : int, optional
        Minimum center field radius in pixels. Default is 2.
    num_gridness_radii : int, optional
        Number of adjacent radii to average for gridness score. Default is 3.

    Examples
    --------
    >>> analyzer = GridnessAnalyzer()
    >>> # Assume we have a 2D rate map
    >>> autocorr = compute_2d_autocorrelation(rate_map)
    >>> result = analyzer.compute_gridness_score(autocorr)
    >>> print(f"Gridness score: {result.score:.3f}")
    >>> print(f"Grid spacing: {result.spacing}")

    Notes
    -----
    Based on gridnessScore.m from the MATLAB codebase.

    The gridness score algorithm computes correlations between the autocorrelogram
    and rotated versions at 30°, 60°, 90°, 120°, and 150°. The score is:
    min(r_60°, r_120°) - max(r_30°, r_90°, r_150°)

    This exploits the 60° rotational symmetry of hexagonal grids.
    """

    def __init__(
        self,
        threshold: float = 0.2,
        min_orientation: float = 15.0,
        min_center_radius: int = 2,
        num_gridness_radii: int = 3,
    ):
        self.threshold = threshold
        self.min_orientation = min_orientation
        self.min_center_radius = min_center_radius
        self.num_gridness_radii = num_gridness_radii

    def compute_gridness_score(self, autocorr: np.ndarray) -> "GridnessResult":
        """
        Compute gridness score from a 2D autocorrelogram.

        Parameters
        ----------
        autocorr : np.ndarray
            2D autocorrelogram of a firing rate map

        Returns
        -------
        result : GridnessResult
            Complete gridness analysis results. A NaN-filled result is
            returned when the input is degenerate (a dimension of size 0 or 1,
            or a peak value that is not finite and positive) or when the
            central field radius is below ``min_center_radius`` or not smaller
            than the autocorrelogram radius.

        Raises
        ------
        ValueError
            If autocorr is not 2D
        """
        autocorr = np.asarray(autocorr)

        if autocorr.ndim != 2:
            raise ValueError(f"autocorr must be a 2D array, got {autocorr.ndim}D input")

        if min(autocorr.shape) <= 1:
            # Degenerate case: empty, or effectively a 1D row/column vector
            return self._create_nan_result()

        # Normalize so the maximum is 1. A zero, negative or NaN peak cannot
        # be normalized meaningfully, so bail out instead of dividing by it.
        peak = np.max(autocorr)
        if not np.isfinite(peak) or peak <= 0:
            return self._create_nan_result()
        autocorr = autocorr / peak

        # Find central field radius using contour detection
        center_radius = self._find_center_radius(autocorr, self.threshold)

        # Also catches the -1 "no contour found" sentinel
        if center_radius < self.min_center_radius:
            return self._create_nan_result()

        # Get autocorr dimensions
        half_height, half_width = np.array(autocorr.shape) // 2 + 1
        autocorr_rad = min(half_height, half_width)

        if center_radius >= autocorr_rad:
            return self._create_nan_result()

        # Compute rotation correlations and gridness score
        gridness_scores, rad_steps = self._compute_rotation_correlations(
            autocorr, center_radius, autocorr_rad, half_height, half_width
        )

        # Find optimal radius with maximum gridness
        score, optimal_radius = self._find_optimal_gridness(gridness_scores, rad_steps)

        # Extract grid statistics (spacing, orientation, ellipse fit)
        grid_stats = self._extract_grid_statistics(
            autocorr, optimal_radius, half_height, half_width
        )

        return GridnessResult(
            score=score,
            spacing=grid_stats["spacing"],
            orientation=grid_stats["orientation"],
            ellipse=grid_stats["ellipse"],
            ellipse_theta_deg=grid_stats["ellipse_theta_deg"],
            center_radius=center_radius,
            optimal_radius=optimal_radius,
            peak_locations=grid_stats.get("peak_locations"),
        )

    @staticmethod
    def _distance_from_center(
        shape: tuple[int, ...], half_height: int, half_width: int
    ) -> np.ndarray:
        """Return per-pixel Euclidean distance to the (0-indexed) center pixel."""
        col_idx, row_idx = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        return np.sqrt((row_idx - (half_height - 1)) ** 2 + (col_idx - (half_width - 1)) ** 2)

    def _find_center_radius(self, autocorr: np.ndarray, threshold: float) -> float:
        """
        Find the radius of the central autocorrelation field.

        Uses contour detection at the specified threshold level, picks the
        contour whose mean position is closest to the center pixel, and
        converts its enclosed area to an equivalent circle radius.

        Parameters
        ----------
        autocorr : np.ndarray
            Normalized autocorrelogram
        threshold : float
            Contour detection threshold

        Returns
        -------
        radius : float
            Radius of central field in pixels (floored), or -1 when no
            contour is found
        """
        half_height, half_width = np.array(autocorr.shape) // 2 + 1

        # Find contours at threshold level
        contours = find_contours_at_level(autocorr, threshold)

        if len(contours) == 0:
            return -1.0

        # Find contour closest to center
        # Note: find_contours returns (row, col) = (y, x)
        center_point = np.array([half_height - 1, half_width - 1])  # -1 for 0-indexing

        min_dist = np.inf
        center_contour = None

        for contour in contours:
            # Compute mean position of contour
            mean_pos = np.mean(contour, axis=0)
            dist = np.linalg.norm(mean_pos - center_point)

            if dist < min_dist:
                min_dist = dist
                center_contour = contour

        if center_contour is None:
            return -1.0

        # Compute area of central field contour
        # Note: contour is (row, col) = (y, x), need to swap for polyarea
        area = polyarea(center_contour[:, 1], center_contour[:, 0])

        # Radius from area: r = sqrt(area / pi)
        radius = np.floor(np.sqrt(area / np.pi))

        return float(radius)

    def _compute_rotation_correlations(
        self,
        autocorr: np.ndarray,
        center_radius: float,
        autocorr_rad: float,
        half_height: int,
        half_width: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute correlations between autocorr and rotated versions.

        Rotates at 30°, 60°, 90°, 120°, 150° and computes Pearson correlation
        for expanding circular regions.

        Parameters
        ----------
        autocorr : np.ndarray
            Normalized autocorrelogram
        center_radius : float
            Radius of central field to exclude
        autocorr_rad : float
            Radius of autocorrelogram
        half_height : int
            Half height of autocorr
        half_width : int
            Half width of autocorr

        Returns
        -------
        gridness_scores : np.ndarray
            N x 2 array; column 0 holds the gridness score at each radius
            (NaN where fewer than 10 pixels are available), column 1 the radius
        rad_steps : np.ndarray
            Radius values tested
        """
        # Define rotation angles
        rot_angles_deg = np.array([30, 60, 90, 120, 150])
        n_rot = len(rot_angles_deg)

        # Define radius steps
        rad_steps = np.arange(center_radius + 1, autocorr_rad + 1)
        num_steps = len(rad_steps)

        if num_steps == 0:
            # Keep the documented N x 2 layout so callers can index [:, 0]
            return np.array([[np.nan, center_radius]]), np.array([center_radius])

        # Distance of every pixel from the center, and mask excluding central field
        dist_from_center = self._distance_from_center(autocorr.shape, half_height, half_width)
        center_exclusion_mask = dist_from_center > center_radius

        # Pre-compute all rotated versions
        rotated_autocorr = np.zeros((autocorr.shape[0], autocorr.shape[1], n_rot))
        for i, angle in enumerate(rot_angles_deg):
            rotated_autocorr[:, :, i] = rotate_image(
                autocorr, angle, output_shape=autocorr.shape, method="bilinear"
            )

        # Vectorize autocorr and rotated versions
        autocorr_vec = autocorr[center_exclusion_mask].ravel()
        rotated_vec = rotated_autocorr[center_exclusion_mask, :].reshape(-1, n_rot)
        dist_vec = dist_from_center[center_exclusion_mask].ravel()

        # Compute correlations at each radius
        gridness_scores = np.zeros((num_steps, 2))
        gridness_scores[:, 1] = rad_steps

        for i, radius in enumerate(rad_steps):
            # Select points within this radius
            in_radius = dist_vec < radius

            if np.sum(in_radius) < 10:  # Need minimum points
                gridness_scores[i, 0] = np.nan
                continue

            ref_circle = autocorr_vec[in_radius]
            rot_circles = rotated_vec[in_radius, :]

            # Compute Pearson correlations
            rot_corr = pearson_correlation(ref_circle[:, np.newaxis], rot_circles)

            # Gridness score: min(r_60, r_120) - max(r_30, r_90, r_150)
            # rot_corr indices: [30°, 60°, 90°, 120°, 150°] = [0, 1, 2, 3, 4]
            score = min(rot_corr[1], rot_corr[3]) - max(rot_corr[0], rot_corr[2], rot_corr[4])
            gridness_scores[i, 0] = score

        return gridness_scores, rad_steps

    def _find_optimal_gridness(
        self, gridness_scores: np.ndarray, rad_steps: np.ndarray
    ) -> tuple[float, float]:
        """
        Find the radius with maximum gridness score.

        Averages over num_gridness_radii adjacent radii for stability.
        NaN scores are ignored inside each averaging window.

        Parameters
        ----------
        gridness_scores : np.ndarray
            Gridness scores at each radius (N x 2 array)
        rad_steps : np.ndarray
            Radius values

        Returns
        -------
        max_score : float
            Maximum gridness score, or NaN when no valid score exists
        optimal_radius : float
            Radius at maximum score. When no valid score exists this is the
            first radius tested (or 0 if there is none).
        """
        scores = gridness_scores[:, 0]
        num_steps = len(scores)

        if np.all(np.isnan(scores)):
            # Nothing to maximize (also covers an empty score array);
            # np.nanargmax would raise ValueError on an all-NaN input.
            fallback_radius = float(rad_steps[0]) if len(rad_steps) > 0 else 0.0
            return float("nan"), fallback_radius

        if num_steps < self.num_gridness_radii:
            # Not enough radii for a full window, just take max
            max_idx = np.nanargmax(scores)
            return float(scores[max_idx]), float(rad_steps[max_idx])

        # Average over adjacent radii
        num_windows = num_steps - self.num_gridness_radii + 1
        mean_scores = np.full(num_windows, np.nan)

        for i in range(num_windows):
            window = scores[i : i + self.num_gridness_radii]
            valid = window[~np.isnan(window)]
            # Leave NaN for windows without any valid score (avoids the
            # "Mean of empty slice" RuntimeWarning from np.nanmean)
            if valid.size > 0:
                mean_scores[i] = valid.mean()

        # Find maximum (at least one window is valid because a valid score exists)
        max_idx = np.nanargmax(mean_scores)
        max_score = mean_scores[max_idx]

        # Optimal radius is at center of window
        optimal_idx = max_idx + (self.num_gridness_radii - 1) // 2
        optimal_radius = rad_steps[optimal_idx]

        return float(max_score), float(optimal_radius)

    def _extract_grid_statistics(
        self, autocorr: np.ndarray, optimal_radius: float, half_height: int, half_width: int
    ) -> dict:
        """
        Extract grid field statistics (spacing, orientation, ellipse).

        Finds peaks in the autocorrelogram and fits an ellipse to them.

        Parameters
        ----------
        autocorr : np.ndarray
            Autocorrelogram
        optimal_radius : float
            Optimal radius for analysis
        half_height : int
            Half height of autocorr
        half_width : int
            Half width of autocorr

        Returns
        -------
        stats : dict
            Dictionary with keys "spacing", "orientation", "ellipse",
            "ellipse_theta_deg" and "peak_locations". NaN-filled (with
            "peak_locations" set to None) when too few peaks are found.
        """
        dist_from_center = self._distance_from_center(autocorr.shape, half_height, half_width)

        # Define search window around optimal radius
        w = optimal_radius / 4
        mask_outer = dist_from_center < (optimal_radius + w)

        # Smooth autocorr to eliminate spurious maxima
        autocorr_sm = gaussian_filter_2d(
            autocorr, sigma=optimal_radius / (2 * np.pi) / 2, mode="reflect"
        )

        # Apply mask
        masked_autocorr = mask_outer * autocorr_sm

        # Find regional maxima
        maxima_map = find_regional_maxima(masked_autocorr, connectivity=1)

        # Dilate to eliminate fragmentation
        maxima_map_dilated = dilate_image(maxima_map, selem_type="square", selem_size=3)

        # Label connected components
        labels, num_labels = label_connected_components(maxima_map_dilated, connectivity=2)

        if num_labels < 5:
            warnings.warn("Not enough grid peaks found for statistics", stacklevel=2)
            return self._create_nan_stats()

        # Get region properties
        props = regionprops(labels)

        # Extract centroids
        centroids = np.array([prop.centroid for prop in props])  # (N, 2) array of (row, col)

        # Convert to (x, y) coordinates
        centers_of_mass = centroids[:, ::-1]  # Swap to (col, row) = (x, y)

        # Compute orientations relative to center
        center_point = np.array([half_width - 1, half_height - 1])
        orientations = np.arctan2(
            centers_of_mass[:, 1] - center_point[1], centers_of_mass[:, 0] - center_point[0]
        )

        # Compute distances to center
        peaks_to_center = squared_distance(centers_of_mass.T, center_point[:, np.newaxis]).ravel()

        # Remove zero orientation (central peak if any)
        # NOTE(review): a genuine peak lying exactly on the positive x-axis
        # also has orientation 0 and is dropped here; kept as-is to match the
        # reference implementation -- confirm this is intended.
        zero_idx = np.where(orientations == 0)[0]
        if len(zero_idx) > 0:
            mask = np.ones(len(orientations), dtype=bool)
            mask[zero_idx] = False
            orientations = orientations[mask]
            centers_of_mass = centers_of_mass[mask]
            peaks_to_center = peaks_to_center[mask]

        # Filter fields with similar orientations
        pairwise_orient_diff = circ_dist2(orientations)
        close_fields = np.abs(pairwise_orient_diff) < np.deg2rad(self.min_orientation)
        np.fill_diagonal(close_fields, False)
        close_fields = np.triu(close_fields)  # Keep upper triangle only

        rows, cols = np.where(close_fields)
        to_delete = []
        for row, col in zip(rows, cols, strict=True):
            # Keep the one closer to center
            if peaks_to_center[row] > peaks_to_center[col]:
                to_delete.append(row)
            else:
                to_delete.append(col)

        to_delete = np.unique(to_delete)
        if len(to_delete) > 0:
            mask = np.ones(len(orientations), dtype=bool)
            mask[to_delete] = False
            orientations = orientations[mask]
            centers_of_mass = centers_of_mass[mask]
            peaks_to_center = peaks_to_center[mask]

        if len(centers_of_mass) < 4:
            warnings.warn("Not enough grid peaks after filtering", stacklevel=2)
            return self._create_nan_stats()

        # Sort by distance to center and keep 6 closest
        sort_idx = np.argsort(peaks_to_center)
        centers_of_mass = centers_of_mass[sort_idx]
        if len(centers_of_mass) > 6:
            centers_of_mass = centers_of_mass[:6]

        # Compute final orientations and spacings for the retained peaks
        orientations_deg = np.rad2deg(
            np.arctan2(
                centers_of_mass[:, 1] - center_point[1], centers_of_mass[:, 0] - center_point[0]
            )
        )

        spacings = np.sqrt(
            (centers_of_mass[:, 0] - center_point[0]) ** 2
            + (centers_of_mass[:, 1] - center_point[1]) ** 2
        )

        # Fit ellipse to peaks; the fit is best-effort, so any failure
        # degrades to NaN ellipse parameters instead of aborting the analysis
        try:
            ellipse = fit_ellipse(centers_of_mass[:, 0], centers_of_mass[:, 1])
            ellipse_theta_deg = np.rad2deg(wrap_to_pi(ellipse[4]) + np.pi)
        except Exception:
            ellipse = np.full(5, np.nan)
            ellipse_theta_deg = np.nan

        # Take the orientation with the smallest absolute angle as reference,
        # then keep the 3 peaks whose orientation is closest to it
        # (the reference itself included)
        reference_idx = np.argmin(np.abs(orientations_deg))
        closest_to_reference = np.argsort(
            np.abs(orientations_deg - orientations_deg[reference_idx])
        )

        final_idx = closest_to_reference[:3]
        orientations_deg = orientations_deg[final_idx]
        spacings = spacings[final_idx]

        # Sort by orientation
        sort_idx = np.argsort(orientations_deg)
        orientations_deg = orientations_deg[sort_idx]
        spacings = spacings[sort_idx]

        return {
            "spacing": spacings,
            "orientation": orientations_deg,
            "ellipse": ellipse,
            "ellipse_theta_deg": ellipse_theta_deg,
            "peak_locations": centers_of_mass,
        }

    def _create_nan_result(self) -> "GridnessResult":
        """Create a result with NaN values for failed analysis."""
        return GridnessResult(
            score=np.nan,
            spacing=np.full(3, np.nan),
            orientation=np.full(3, np.nan),
            ellipse=np.full(5, np.nan),
            ellipse_theta_deg=np.nan,
            center_radius=0.0,
            optimal_radius=np.nan,
        )

    def _create_nan_stats(self) -> dict:
        """Create statistics dict with NaN values (and no peak locations)."""
        return {
            "spacing": np.full(3, np.nan),
            "orientation": np.full(3, np.nan),
            "ellipse": np.full(5, np.nan),
            "ellipse_theta_deg": np.nan,
            "peak_locations": None,
        }
553
+
554
+
555
def compute_2d_autocorrelation(rate_map: np.ndarray, overlap: float = 0.8) -> np.ndarray:
    """
    Return the normalized 2D spatial autocorrelogram of a firing rate map.

    Thin convenience wrapper that forwards to ``autocorrelation_2d`` from the
    correlation module with ``normalize=True``.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing rate map
    overlap : float, optional
        Overlap value (0-1) forwarded to ``autocorrelation_2d``.
        Default is 0.8.

    Returns
    -------
    autocorr : np.ndarray
        2D autocorrelogram

    Examples
    --------
    >>> rate_map = np.random.rand(50, 50)
    >>> autocorr = compute_2d_autocorrelation(rate_map)
    >>> print(autocorr.shape)

    Notes
    -----
    Based on autocorrelation.m from the MATLAB codebase.
    """
    autocorr = autocorrelation_2d(rate_map, overlap=overlap, normalize=True)
    return autocorr
585
+
586
+
587
if __name__ == "__main__":
    # Manual smoke test: build a synthetic hexagonal rate map, compute its
    # autocorrelogram and print the resulting gridness statistics.
    print("Testing GridnessAnalyzer...")

    print("\nCreating synthetic grid pattern...")
    axis = np.linspace(-2, 2, 100)
    xx, yy = np.meshgrid(axis, axis)

    # A hexagonal pattern is the average of three plane waves 60° apart
    wave_number = 2 * np.pi / 0.4  # Spatial frequency
    wave_angles = (0, np.pi / 3, 2 * np.pi / 3)
    plane_waves = [
        np.cos(wave_number * (xx * np.cos(angle) + yy * np.sin(angle))) for angle in wave_angles
    ]
    grid_pattern = (plane_waves[0] + plane_waves[1] + plane_waves[2]) / 3

    # Shift and scale so the map is positive, like a firing rate (0-10 Hz range)
    rate_map = (grid_pattern + 1.5) / 2.5 * 10

    print(f"Rate map shape: {rate_map.shape}")
    print(f"Rate map range: [{rate_map.min():.2f}, {rate_map.max():.2f}] Hz")

    print("\nComputing autocorrelation...")
    autocorr = compute_2d_autocorrelation(rate_map)
    print(f"Autocorr shape: {autocorr.shape}")

    print("\nComputing gridness score...")
    result = GridnessAnalyzer().compute_gridness_score(autocorr)

    print("\nResults:")
    for line in (
        f" Gridness score: {result.score:.3f}",
        f" Grid spacing: {result.spacing}",
        f" Grid orientation: {result.orientation}°",
        f" Center radius: {result.center_radius}",
        f" Optimal radius: {result.optimal_radius}",
    ):
        print(line)

    # Ellipse parameters are only meaningful when the fit succeeded
    ellipse_is_valid = not np.isnan(result.ellipse).any()
    if ellipse_is_valid:
        print(f" Ellipse center: ({result.ellipse[0]:.1f}, {result.ellipse[1]:.1f})")
        print(f" Ellipse radii: ({result.ellipse[2]:.1f}, {result.ellipse[3]:.1f})")
        print(f" Ellipse angle: {result.ellipse_theta_deg:.1f}°")

    print("\nGridnessAnalyzer test completed!")