canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,383 @@
1
+ """
2
+ Circular Statistics Utilities
3
+
4
+ Python port of CircStat MATLAB toolbox functions for circular statistics.
5
+
6
+ References:
7
+ - Statistical analysis of circular data, N.I. Fisher
8
+ - Topics in circular statistics, S.R. Jammalamadaka et al.
9
+ - Biostatistical Analysis, J. H. Zar
10
+ - CircStat MATLAB toolbox by Philipp Berens (2009)
11
+ """
12
+
13
+ import numpy as np
14
+
15
+
16
+ def circ_r(
17
+ alpha: np.ndarray, w: np.ndarray | None = None, d: float = 0.0, axis: int = 0
18
+ ) -> float | np.ndarray:
19
+ """
20
+ Compute mean resultant vector length for circular data.
21
+
22
+ This is a measure of circular variance (concentration). Values near 1 indicate
23
+ high concentration, values near 0 indicate uniform distribution.
24
+
25
+ Parameters
26
+ ----------
27
+ alpha : np.ndarray
28
+ Sample of angles in radians
29
+ w : np.ndarray, optional
30
+ Weights for each angle (e.g., for binned data). If None, uniform weights assumed.
31
+ d : float, optional
32
+ Spacing of bin centers for binned data. If supplied, correction factor is used
33
+ to correct for bias in estimation of r (in radians). Default is 0 (no correction).
34
+ axis : int, optional
35
+ Compute along this dimension. Default is 0.
36
+
37
+ Returns
38
+ -------
39
+ r : float or np.ndarray
40
+ Mean resultant vector length
41
+
42
+ Examples
43
+ --------
44
+ >>> angles = np.array([0, 0.1, 0.2, -0.1, -0.2]) # Concentrated around 0
45
+ >>> r = circ_r(angles)
46
+ >>> print(f"MVL: {r:.3f}") # Should be close to 1
47
+
48
+ >>> angles = np.linspace(0, 2*np.pi, 100) # Uniform distribution
49
+ >>> r = circ_r(angles)
50
+ >>> print(f"MVL: {r:.3f}") # Should be close to 0
51
+
52
+ Notes
53
+ -----
54
+ Based on CircStat toolbox circ_r.m by Philipp Berens (2009)
55
+ """
56
+ if w is None:
57
+ w = np.ones_like(alpha)
58
+
59
+ # Compute weighted sum of complex exponentials
60
+ r = np.sum(w * np.exp(1j * alpha), axis=axis)
61
+
62
+ # Obtain length normalized by sum of weights
63
+ r = np.abs(r) / np.sum(w, axis=axis)
64
+
65
+ # Apply correction factor for binned data if spacing is provided
66
+ if d != 0:
67
+ c = d / (2 * np.sin(d / 2))
68
+ r = c * r
69
+
70
+ return r
71
+
72
+
73
+ def circ_mean(alpha: np.ndarray, w: np.ndarray | None = None, axis: int = 0) -> float | np.ndarray:
74
+ """
75
+ Compute mean direction for circular data.
76
+
77
+ Parameters
78
+ ----------
79
+ alpha : np.ndarray
80
+ Sample of angles in radians
81
+ w : np.ndarray, optional
82
+ Weights for each angle (e.g., for binned data). If None, uniform weights assumed.
83
+ axis : int, optional
84
+ Compute along this dimension. Default is 0.
85
+
86
+ Returns
87
+ -------
88
+ mu : float or np.ndarray
89
+ Mean direction in radians, range [-π, π]
90
+
91
+ Examples
92
+ --------
93
+ >>> angles = np.array([0, 0.1, 0.2, -0.1, -0.2])
94
+ >>> mean_angle = circ_mean(angles)
95
+ >>> print(f"Mean direction: {mean_angle:.3f} rad")
96
+
97
+ >>> # Weighted mean
98
+ >>> angles = np.array([0, np.pi])
99
+ >>> weights = np.array([3, 1]) # 3x more weight on 0
100
+ >>> mean_angle = circ_mean(angles, w=weights)
101
+
102
+ Notes
103
+ -----
104
+ Based on CircStat toolbox circ_mean.m by Philipp Berens (2009)
105
+ """
106
+ if w is None:
107
+ w = np.ones_like(alpha)
108
+ else:
109
+ if alpha.ndim > 0 and w.shape != alpha.shape:
110
+ raise ValueError("Input dimensions do not match")
111
+
112
+ # Compute weighted sum of complex exponentials
113
+ r = np.sum(w * np.exp(1j * alpha), axis=axis)
114
+
115
+ # Obtain mean angle from complex sum
116
+ mu = np.angle(r)
117
+
118
+ return mu
119
+
120
+
121
def circ_std(
    alpha: np.ndarray, w: np.ndarray | None = None, d: float = 0.0, axis: int = 0
) -> tuple[float | np.ndarray, float | np.ndarray]:
    """
    Compute angular deviation and circular standard deviation.

    Parameters
    ----------
    alpha : np.ndarray
        Sample of angles in radians.
    w : np.ndarray, optional
        Weights per angle. Uniform weights are assumed if None.
    d : float, optional
        Spacing of bin centers for binned data (bias correction passed
        through to the MVL computation). Default is 0.
    axis : int, optional
        Axis along which to compute. Default is 0.

    Returns
    -------
    s : float or np.ndarray
        Angular deviation (equation 26.20, Zar).
    s0 : float or np.ndarray
        Circular standard deviation (equation 26.21, Zar).

    Raises
    ------
    ValueError
        If explicit weights do not match the shape of ``alpha``.

    Examples
    --------
    >>> s, s0 = circ_std(np.array([0, 0.1, 0.2, -0.1, -0.2]))

    Notes
    -----
    Port of circ_std.m from the CircStat toolbox by Philipp Berens (2009).
    Reference: Biostatistical Analysis, J. H. Zar.
    """
    if w is None:
        w = np.ones_like(alpha)
    elif w.shape != alpha.shape:
        raise ValueError("Input dimensions do not match")

    # Both dispersion measures are functions of the mean resultant length.
    mvl_value = circ_r(alpha, w=w, d=d, axis=axis)

    angular_dev = np.sqrt(2 * (1 - mvl_value))  # Zar eq. 26.20
    circ_sd = np.sqrt(-2 * np.log(mvl_value))  # Zar eq. 26.21

    return angular_dev, circ_sd
170
+
171
+
172
def circ_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Signed pairwise angular distance (x_i - y_i) on the circle.

    Returns the shortest signed angle from ``y`` to ``x``, respecting
    circular topology (wrapping at ±π).

    Parameters
    ----------
    x : np.ndarray
        First set of angles in radians.
    y : np.ndarray
        Second set of angles in radians (same shape as ``x``, or scalar).

    Returns
    -------
    np.ndarray
        Angular differences in radians, in the range [-π, π].

    Examples
    --------
    >>> circ_dist(np.array([0.1, np.pi]), np.array([0.0, -np.pi]))
    array([0.1, 0. ])  # -π and π denote the same location

    >>> # Distances wrap around at ±π: the gap between π-0.1 and -π+0.1
    >>> # is 0.2, not 2π - 0.2
    >>> circ_dist(np.array([np.pi - 0.1]), np.array([-np.pi + 0.1]))

    Notes
    -----
    Port of circ_dist.m from the CircStat toolbox by Philipp Berens (2009).
    Reference: Biostatistical Analysis, J. H. Zar, p. 651.
    """
    # Dividing unit phasors subtracts their phases, and np.angle returns the
    # principal value, so wrapping to [-π, π] is automatic.
    return np.angle(np.exp(1j * x) / np.exp(1j * y))
214
+
215
+
216
+ def circ_dist2(x: np.ndarray, y: np.ndarray | None = None) -> np.ndarray:
217
+ """
218
+ All pairwise angular distances (x_i - y_j) around the circle.
219
+
220
+ Computes the matrix of all pairwise angular distances between two sets
221
+ of angles, or within one set if y is not provided.
222
+
223
+ Parameters
224
+ ----------
225
+ x : np.ndarray
226
+ First set of angles in radians (will be treated as column vector)
227
+ y : np.ndarray, optional
228
+ Second set of angles in radians (will be treated as column vector).
229
+ If None, computes pairwise distances within x. Default is None.
230
+
231
+ Returns
232
+ -------
233
+ r : np.ndarray
234
+ Matrix of pairwise angular distances, shape (len(x), len(y))
235
+ Element (i, j) contains the distance from y[j] to x[i]
236
+
237
+ Examples
238
+ --------
239
+ >>> x = np.array([0, np.pi/2, np.pi])
240
+ >>> D = circ_dist2(x) # All pairwise distances within x
241
+ >>> print(D.shape) # (3, 3)
242
+
243
+ >>> y = np.array([0, np.pi])
244
+ >>> D = circ_dist2(x, y) # All distances from y to x
245
+ >>> print(D.shape) # (3, 2)
246
+
247
+ Notes
248
+ -----
249
+ Based on CircStat toolbox circ_dist2.m by Philipp Berens (2009)
250
+ """
251
+ if y is None:
252
+ y = x
253
+
254
+ # Ensure column vectors
255
+ x = np.atleast_1d(x)
256
+ y = np.atleast_1d(y)
257
+
258
+ if x.ndim == 1:
259
+ x = x[:, np.newaxis]
260
+ elif x.shape[1] > x.shape[0]:
261
+ x = x.T
262
+
263
+ if y.ndim == 1:
264
+ y = y[:, np.newaxis]
265
+ elif y.shape[1] > y.shape[0]:
266
+ y = y.T
267
+
268
+ # Compute all pairwise distances using broadcasting
269
+ # Shape: (len(x), len(y))
270
+ r = np.angle(np.exp(1j * x) / np.exp(1j * y.T))
271
+
272
+ return r
273
+
274
+
275
def circ_rtest(alpha: np.ndarray, w: np.ndarray | None = None) -> float:
    """
    Rayleigh test for non-uniformity of circular data.

    H0: The population is uniformly distributed around the circle.
    HA: The population is not uniformly distributed.

    Parameters
    ----------
    alpha : np.ndarray
        Sample of angles in radians.
    w : np.ndarray, optional
        Weights for each angle (e.g., bin counts for binned data).
        If None, uniform weights are assumed.

    Returns
    -------
    pval : float
        p-value of the Rayleigh test. Small values (< 0.05) indicate
        significant deviation from uniformity.

    Examples
    --------
    >>> p = circ_rtest(np.random.normal(0, 0.1, 100))   # concentrated: p < 0.05
    >>> p = circ_rtest(np.random.uniform(-np.pi, np.pi, 100))  # uniform: p > 0.05

    Notes
    -----
    Test statistic: Z = n * r^2, where n is the effective sample size and
    r is the mean resultant length. The p-value uses the small-sample
    series approximation (good for n > 50, reasonable for n > 20).

    Fix relative to the earlier port: the effective sample size is the
    total weight ``sum(w)`` (as in CircStat's circ_rtest.m), not the
    number of entries in ``alpha``. For unit weights the two coincide, so
    unweighted results are unchanged.

    References: Topics in Circular Statistics, S.R. Jammalamadaka et al., p. 48
    """
    if w is None:
        w = np.ones_like(alpha)

    # Mean resultant vector length (weighted).
    r = circ_r(alpha, w=w)

    # Effective sample size: for weighted/binned data this is the total
    # weight, matching CircStat circ_rtest.m; equals len(alpha) when w=1.
    n = np.sum(w)

    # Rayleigh test statistic
    Z = n * r**2

    # Series approximation to the p-value.
    pval = np.exp(-Z) * (
        1 + (2 * Z - Z**2) / (4 * n) - (24 * Z - 132 * Z**2 + 76 * Z**3 - 9 * Z**4) / (288 * n**2)
    )

    return pval
332
+
333
+
334
# Convenience aliases for compatibility with naming used elsewhere; each
# binds directly to the canonical circ_* implementation above.
mvl = circ_r  # Mean Vector Length is same as resultant length
angular_mean = circ_mean
angular_std = circ_std
angular_distance = circ_dist
339
+
340
+
341
if __name__ == "__main__":
    # Smoke tests, run only when the module is executed directly.
    # Fix: use a seeded Generator so the printed demo is reproducible;
    # the previous unseeded legacy np.random calls gave different output
    # on every run, making the "should be ..." annotations unverifiable.
    rng = np.random.default_rng(0)

    print("Testing circular statistics functions...")

    # Test 1: Concentrated distribution
    angles = rng.normal(0, 0.1, 100)
    r = circ_r(angles)
    mu = circ_mean(angles)
    s, s0 = circ_std(angles)
    p = circ_rtest(angles)

    print("\nConcentrated distribution (mean=0, std=0.1):")
    print(f"  MVL: {r:.3f} (should be close to 1)")
    print(f"  Mean direction: {mu:.3f} rad (should be close to 0)")
    print(f"  Angular deviation: {s:.3f} rad")
    print(f"  Rayleigh test p-value: {p:.4f} (should be < 0.05)")

    # Test 2: Uniform distribution
    angles = rng.uniform(-np.pi, np.pi, 100)
    r = circ_r(angles)
    p = circ_rtest(angles)

    print("\nUniform distribution:")
    print(f"  MVL: {r:.3f} (should be close to 0)")
    print(f"  Rayleigh test p-value: {p:.4f} (should be > 0.05)")

    # Test 3: Angular distances wrap at ±π
    x = np.array([0.1, np.pi])
    y = np.array([0.0, -np.pi])
    dist = circ_dist(x, y)

    print("\nAngular distances:")
    print(f"  dist([0.1, π], [0, -π]) = {dist}")
    print("  Note: π and -π are the same location, so distance is 0")

    # Test 4: Pairwise distance matrix (diagonal must be zero)
    angles = np.array([0, np.pi / 2, np.pi, -np.pi / 2])
    D = circ_dist2(angles)

    print(f"\nPairwise distance matrix shape: {D.shape}")
    print(f"Diagonal should be all zeros: {np.diag(D)}")

    print("\nAll tests completed!")
@@ -0,0 +1,318 @@
1
+ """
2
+ Correlation Utilities
3
+
4
+ Functions for computing Pearson correlation and normalized cross-correlation,
5
+ optimized for neuroscience data analysis.
6
+ """
7
+
8
+ import numpy as np
9
+ from scipy import signal
10
+
11
+
12
def pearson_correlation(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Pearson correlation of x against each column of y.

    Vectorized so that many correlations are computed in one pass when
    ``y`` carries multiple columns.

    Parameters
    ----------
    x : np.ndarray
        First array, shape (n,) or (n, 1).
    y : np.ndarray
        Second array, shape (n,) or (n, m) with m columns.

    Returns
    -------
    np.ndarray
        Correlation coefficients: scalar for 1-D ``y``, shape (m,) when
        ``y`` has m columns.

    Examples
    --------
    >>> r = pearson_correlation(np.array([1, 2, 3, 4, 5]),
    ...                         np.array([2, 4, 6, 8, 10]))  # 1.0

    >>> y = np.column_stack([[2, 4, 6, 8, 10], [5, 4, 3, 2, 1]])
    >>> pearson_correlation(np.array([1, 2, 3, 4, 5]), y)  # [1.0, -1.0]

    Notes
    -----
    Based on corrPearson.m from the MATLAB codebase. The (n-1)
    normalization factor cancels out and is omitted.
    """
    # Work with 2-D, column-oriented data: observations run down axis 0.
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    if x.shape[0] == 1:
        x = x.T
    if y.shape[0] == 1:
        y = y.T

    n_obs = x.shape[0]

    # Remove the column means.
    x_centered = x - np.sum(x, axis=0) / n_obs
    y_centered = y - np.sum(y, axis=0) / n_obs

    # Cross products of centered columns: shape (k, m).
    cross = x_centered.T @ y_centered

    # Column norms for renormalization to [-1, 1].
    x_norm = np.linalg.norm(x_centered, axis=0, keepdims=True)
    y_norm = np.linalg.norm(y_centered, axis=0, keepdims=True)

    # Divide rows by ||x|| and columns by ||y||, then drop singleton dims.
    return np.squeeze(cross / x_norm.T / y_norm)
84
+
85
+
86
def normalized_xcorr2(
    template: np.ndarray, image: np.ndarray, mode: str = "same", min_overlap: int = 0
) -> np.ndarray:
    """
    Normalized 2D cross-correlation of a template against an image.

    Unlike a raw scipy.signal.correlate call, the output is rescaled so
    values lie in [-1, 1], which is what autocorrelation-based analyses
    (e.g., grid-cell rate maps) expect.

    Parameters
    ----------
    template : np.ndarray
        2D template array.
    image : np.ndarray
        2D image array.
    mode : str, optional
        'full'  - full correlation (use for autocorrelation),
        'same'  - output size matches the image (default),
        'valid' - only where the template fully overlaps the image.
    min_overlap : int, optional
        Reserved for parity with the MATLAB original (minimum overlapping
        pixels per output location). NOTE: currently not applied by this
        simplified implementation.

    Returns
    -------
    np.ndarray
        Normalized cross-correlation, clipped to [-1, 1].

    Raises
    ------
    ValueError
        If either input is not 2-D, or the template is constant.

    Examples
    --------
    >>> image = np.random.rand(50, 50)
    >>> autocorr = normalized_xcorr2(image, image, mode='full')
    >>> # zero-lag (center) value is 1.0 for autocorrelation

    Notes
    -----
    Simplified Python port; the full general version (handling partial
    overlap exactly) is normxcorr2_general.m by Dirk Padfield.

    References
    ----------
    Padfield, D. "Masked FFT registration". CVPR, 2010.
    Lewis, J.P. "Fast Normalized Cross-Correlation". Industrial Light & Magic.
    """
    # Promote to float64 for numerical stability before FFT work.
    template = np.asarray(template, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)

    if template.ndim != 2 or image.ndim != 2:
        raise ValueError("Both template and image must be 2D arrays")
    if np.std(template) == 0:
        raise ValueError("Template cannot have all identical values")

    if mode == "full":
        # FFT-based full correlation; output larger than either input.
        corr = signal.correlate(image, template, mode="full", method="fft")

        # Simplified global normalization (exact only at full overlap):
        # subtract the mean product, divide by n * std_t * std_i.
        n_px = template.size
        t_mean = np.mean(template)
        t_std = np.std(template)
        i_mean = np.mean(image)
        i_std = np.std(image)

        if t_std > 0 and i_std > 0:
            corr = (corr - n_px * t_mean * i_mean) / (n_px * t_std * i_std)
    else:
        # 'same' / 'valid': scipy handles the geometry; normalize by the
        # energy of each input (assumes full overlap in the valid region).
        corr = signal.correlate(image, template, mode=mode, method="fft")

        t_energy = np.sqrt(np.sum(template**2))
        i_energy = np.sqrt(np.sum(image**2))
        if t_energy > 0 and i_energy > 0:
            corr = corr / (t_energy * i_energy)

    # Keep the output inside the valid correlation range.
    return np.clip(corr, -1, 1)
193
+
194
+
195
def autocorrelation_2d(
    array: np.ndarray, overlap: float = 0.8, normalize: bool = True
) -> np.ndarray:
    """
    2D spatial autocorrelation of an array (e.g., a firing rate map).

    Convenience wrapper used for grid-cell analysis: the input is
    demeaned, autocorrelated via FFT, cropped to a central region, and
    optionally normalized by the zero-lag value.

    Parameters
    ----------
    array : np.ndarray
        2D array (e.g., firing rate map). NaNs are treated as 0.
    overlap : float, optional
        Fraction (0-1) controlling the kept central region; the output is
        roughly (1 + overlap) times the input size, forced to odd
        dimensions, which trims edge artifacts. Default is 0.8.
    normalize : bool, optional
        Divide by the zero-lag (center) value so the peak is 1. Default True.

    Returns
    -------
    np.ndarray
        2D autocorrelation array.

    Examples
    --------
    >>> x = np.linspace(0, 4*np.pi, 50)
    >>> xx, yy = np.meshgrid(x, x)
    >>> autocorr = autocorrelation_2d(np.cos(xx) * np.cos(yy))

    Notes
    -----
    Based on autocorrelation.m from the MATLAB codebase.
    """
    # NaNs would poison the FFT; treat them as zero rate.
    array = np.nan_to_num(array, nan=0.0)

    # Target output size: central (1 + overlap) fraction, forced odd so
    # the zero-lag bin sits exactly at the center.
    out_rows = int(np.round(array.shape[0] * (1 + overlap)))
    out_cols = int(np.round(array.shape[1] * (1 + overlap)))
    if out_rows % 2 == 0 and out_rows > 0:
        out_rows -= 1
    if out_cols % 2 == 0 and out_cols > 0:
        out_cols -= 1

    # Degenerate inputs: nothing to correlate.
    if array.size == 0 or not np.any(array):
        return np.zeros((out_rows, out_cols))

    # Demean so spatial periodicity, not the mean rate, drives the result.
    centered = array - np.mean(array)

    acorr = signal.correlate(centered, centered, mode="full", method="fft")

    # Crop the central region (skipped if the request exceeds the full size).
    row_off = (acorr.shape[0] - out_rows) // 2
    col_off = (acorr.shape[1] - out_cols) // 2
    if row_off >= 0 and col_off >= 0:
        acorr = acorr[row_off : row_off + out_rows, col_off : col_off + out_cols]

    # Scale so the zero-lag (center) value equals 1.
    if normalize:
        zero_lag = acorr[acorr.shape[0] // 2, acorr.shape[1] // 2]
        if zero_lag > 0:
            acorr = acorr / zero_lag

    return acorr
279
+
280
+
281
if __name__ == "__main__":
    # Lightweight smoke tests, run only when the module is executed directly.
    print("Testing correlation functions...")

    # Test 1: perfectly correlated vectors give r = 1
    vec_a = np.array([1, 2, 3, 4, 5], dtype=float)
    vec_b = np.array([2, 4, 6, 8, 10], dtype=float)
    corr = pearson_correlation(vec_a, vec_b)
    print(f"\nTest 1 - Perfect positive correlation: r = {corr:.3f} (should be 1.0)")

    # Test 2: several columns correlated against the same vector at once
    columns = np.column_stack(
        [
            [2, 4, 6, 8, 10],  # Perfect positive
            [5, 4, 3, 2, 1],  # Perfect negative
            [1, 1, 1, 1, 1],  # Constant (should be NaN or 0)
        ]
    )
    corr_vec = pearson_correlation(vec_a, columns[:, :2])  # Skip constant
    print(f"\nTest 2 - Multiple correlations: r = {corr_vec} (should be [1.0, -1.0])")

    # Test 3: autocorrelation of a periodic (grid-like) pattern
    axis_vals = np.linspace(0, 4 * np.pi, 30)
    mesh_x, mesh_y = np.meshgrid(axis_vals, axis_vals)
    periodic = np.cos(mesh_x) * np.cos(mesh_y)

    result = autocorrelation_2d(periodic)
    print("\nTest 3 - 2D Autocorrelation:")
    print(f"  Input shape: {periodic.shape}")
    print(f"  Autocorr shape: {result.shape}")
    print(f"  Autocorr max: {np.max(result):.3f} (should be close to 1.0 at center)")

    # The zero-lag (center) bin should carry the maximum correlation
    mid = np.array(result.shape) // 2
    print(f"  Center value: {result[tuple(mid)]:.3f}")

    print("\nAll tests completed!")