canns 0.13.2__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/pipeline/__init__.py +13 -1
  28. canns/pipeline/asa/runner.py +105 -41
  29. canns/pipeline/asa_gui/__init__.py +68 -0
  30. canns/pipeline/asa_gui/__main__.py +6 -0
  31. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  32. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  33. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  34. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  35. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  36. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  37. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  38. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  39. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  40. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  41. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  42. canns/pipeline/asa_gui/app.py +29 -0
  43. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  44. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  45. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  46. canns/pipeline/asa_gui/core/__init__.py +15 -0
  47. canns/pipeline/asa_gui/core/cache.py +14 -0
  48. canns/pipeline/asa_gui/core/runner.py +1936 -0
  49. canns/pipeline/asa_gui/core/state.py +324 -0
  50. canns/pipeline/asa_gui/core/worker.py +260 -0
  51. canns/pipeline/asa_gui/main_window.py +184 -0
  52. canns/pipeline/asa_gui/models/__init__.py +7 -0
  53. canns/pipeline/asa_gui/models/config.py +14 -0
  54. canns/pipeline/asa_gui/models/job.py +31 -0
  55. canns/pipeline/asa_gui/models/presets.py +21 -0
  56. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  57. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  58. canns/pipeline/asa_gui/resources/light.qss +163 -0
  59. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  60. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  61. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  62. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  63. canns/pipeline/asa_gui/utils/validators.py +41 -0
  64. canns/pipeline/asa_gui/views/__init__.py +1 -0
  65. canns/pipeline/asa_gui/views/help_content.py +171 -0
  66. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  67. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  68. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  69. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  70. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  71. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  72. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  73. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  74. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  75. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  76. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  77. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  78. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  79. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  80. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  81. canns/pipeline/gallery/app.py +5 -7
  82. canns/pipeline/gallery/runner.py +16 -9
  83. canns/pipeline/gallery/state.py +0 -1
  84. {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  85. canns-0.14.0.dist-info/RECORD +163 -0
  86. {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/entry_points.txt +1 -0
  87. canns-0.13.2.dist-info/RECORD +0 -95
  88. {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  89. {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,288 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+
7
+
8
+ def _base_mask(shape: tuple[int, int], center_bins: int = 2) -> np.ndarray:
9
+ """Create a shared mask: center disk + outside circle (corners)."""
10
+ h, w = shape
11
+ cy = (h - 1) / 2.0
12
+ cx = (w - 1) / 2.0
13
+ yy, xx = np.ogrid[:h, :w]
14
+ rr = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
15
+
16
+ # outer circle radius: half-min dimension
17
+ outer_r = min(cy, cx)
18
+ mask_outer = rr > outer_r
19
+
20
+ mask_center = rr <= float(center_bins)
21
+ return mask_outer | mask_center
22
+
23
+
24
+ def _safe_corr(a: np.ndarray, b: np.ndarray) -> float:
25
+ a = np.asarray(a).ravel()
26
+ b = np.asarray(b).ravel()
27
+ if a.size != b.size or a.size == 0:
28
+ return 0.0
29
+ a = a - a.mean()
30
+ b = b - b.mean()
31
+ da = np.sqrt((a * a).mean())
32
+ db = np.sqrt((b * b).mean())
33
+ if da <= 1e-12 or db <= 1e-12:
34
+ return 0.0
35
+ return float((a * b).mean() / (da * db))
36
+
37
+
38
def _vectorize_autocorrs(
    autocorrs: np.ndarray, center_bins: int = 2
) -> tuple[np.ndarray, np.ndarray]:
    """Flatten each autocorrelogram over the unmasked bins.

    Returns the (N, n_features) float32 point-cloud matrix together with
    the shared boolean mask that selected the bins. NaN/inf entries are
    zeroed first so every cell yields the same feature dimensionality.
    """
    if autocorrs.ndim != 3:
        raise ValueError(f"autocorrs must be (N,H,W), got {autocorrs.shape}")
    _, height, width = autocorrs.shape
    mask = _base_mask((height, width), center_bins=center_bins)
    cleaned = np.nan_to_num(autocorrs, nan=0.0, posinf=0.0, neginf=0.0)
    features = cleaned[:, ~mask]
    return features.astype(np.float32, copy=False), mask
50
+
51
+
52
def _build_knn_graph(X: np.ndarray, k: int = 30, metric: str = "manhattan"):
    """Return an undirected igraph kNN graph over the rows of ``X``.

    Each point is linked to its ``k`` nearest neighbors (self excluded),
    duplicate i-j pairs are collapsed, and edges carry inverse-distance
    weights.

    Parameters
    ----------
    X : np.ndarray
        (N, n_features) point cloud.
    k : int
        Requested neighbor count; clamped to the available points.
    metric : str
        Any metric accepted by sklearn NearestNeighbors.
    """
    try:
        from sklearn.neighbors import NearestNeighbors
    except Exception as e:
        raise ImportError(f"scikit-learn is required for kNN graph: {e}") from e

    try:
        import igraph as ig
    except Exception as e:
        raise ImportError(f"python-igraph is required for Leiden clustering: {e}") from e

    N = X.shape[0]

    # Fix: for N == 1 the previous clamp (max(N-1, 1)) still requested two
    # neighbors from a single-sample fit, making sklearn raise. A graph of
    # fewer than two nodes has no edges by definition.
    if N < 2:
        g = ig.Graph(n=N, edges=[], directed=False)
        g.es["weight"] = []
        return g

    k_eff = min(max(int(k), 1), N - 1)

    nbrs = NearestNeighbors(n_neighbors=k_eff + 1, metric=metric)
    nbrs.fit(X)
    dist, ind = nbrs.kneighbors(X, return_distance=True)

    edges: list[tuple[int, int]] = []
    weights: list[float] = []
    eps = 1e-6  # avoid division by zero for coincident points
    seen: set[tuple[int, int]] = set()
    for i in range(N):
        for jj in range(1, k_eff + 1):  # column 0 is the query point itself
            j = int(ind[i, jj])
            if i == j:
                continue
            key = (i, j) if i < j else (j, i)
            if key in seen:
                continue
            seen.add(key)
            edges.append(key)
            weights.append(1.0 / (float(dist[i, jj]) + eps))

    g = ig.Graph(n=N, edges=edges, directed=False)
    g.es["weight"] = weights
    return g
92
+
93
+
94
def _leiden_membership(graph, resolution: float = 1.0) -> np.ndarray:
    """Partition ``graph`` with Leiden (RBConfiguration) and return membership."""
    try:
        import leidenalg
    except Exception as e:
        raise ImportError(f"leidenalg is required for Leiden clustering: {e}") from e

    # Use edge weights only when the graph actually carries them.
    edge_weights = graph.es["weight"] if "weight" in graph.es.attributes() else None
    partition = leidenalg.find_partition(
        graph,
        leidenalg.RBConfigurationVertexPartition,
        weights=edge_weights,
        resolution_parameter=float(resolution),
    )
    return np.asarray(partition.membership, dtype=int)
108
+
109
+
110
+ class _DSU:
111
+ def __init__(self, n: int):
112
+ self.p = list(range(n))
113
+ self.r = [0] * n
114
+
115
+ def find(self, a: int) -> int:
116
+ while self.p[a] != a:
117
+ self.p[a] = self.p[self.p[a]]
118
+ a = self.p[a]
119
+ return a
120
+
121
+ def union(self, a: int, b: int):
122
+ ra, rb = self.find(a), self.find(b)
123
+ if ra == rb:
124
+ return
125
+ if self.r[ra] < self.r[rb]:
126
+ ra, rb = rb, ra
127
+ self.p[rb] = ra
128
+ if self.r[ra] == self.r[rb]:
129
+ self.r[ra] += 1
130
+
131
+
132
def identify_grid_modules_and_stats(
    autocorrs: np.ndarray,
    *,
    gridness_analyzer,
    center_bins: int = 2,
    k: int = 30,
    resolution: float = 1.0,
    score_thr: float = 0.3,
    consistency_thr: float = 0.5,
    min_cells: int = 10,
    merge_corr_thr: float = 0.7,
    metric: str = "manhattan",
) -> dict[str, Any]:
    """Identify grid modules via Leiden clustering of autocorrelogram vectors.

    Parameters
    ----------
    autocorrs : np.ndarray
        Array of shape (N, H, W).
    gridness_analyzer :
        Object providing compute_gridness_score(autocorr) -> GridnessResult.
    center_bins : int
        Radius (in bins) masked around the center peak.
    k : int
        Neighbor count for the kNN graph.
    resolution : float
        Leiden resolution parameter.
    score_thr, consistency_thr, min_cells, merge_corr_thr
        Module acceptance and merging thresholds.

    Returns
    -------
    dict with keys:
        module_id (N,), cluster_id (N,), modules (list of dict), params
    """
    if autocorrs.ndim != 3:
        raise ValueError(f"autocorrs must be (N,H,W). Got {autocorrs.shape}")
    n_units = autocorrs.shape[0]

    # Cluster the masked, flattened autocorrelograms.
    X, base_mask = _vectorize_autocorrs(autocorrs, center_bins=int(center_bins))
    graph = _build_knn_graph(X, k=int(k), metric=str(metric))
    cluster_id = _leiden_membership(graph, resolution=float(resolution))

    clusters: dict[int, np.ndarray] = {
        int(cid): np.where(cluster_id == cid)[0] for cid in np.unique(cluster_id)
    }

    cleaned = np.nan_to_num(autocorrs, nan=0.0, posinf=0.0, neginf=0.0)
    keep = ~base_mask  # shared feature selector

    cluster_stats: dict[int, dict[str, Any]] = {}
    candidate_cids: list[int] = []

    # Per-cluster gridness and within-cluster consistency.
    for cid, members in clusters.items():
        if members.size == 0:
            continue
        avg = cleaned[members].mean(axis=0)
        med = np.median(cleaned[members], axis=0)

        # Gridness is scored on the median template (robust to outliers).
        gr = gridness_analyzer.compute_gridness_score(med)
        grid_score = float(getattr(gr, "score", getattr(gr, "grid_score", np.nan)))

        template = avg[keep]
        member_corrs = [_safe_corr(template, cleaned[i][keep]) for i in members]
        consistency = float(np.median(member_corrs)) if member_corrs else 0.0

        cluster_stats[cid] = {
            "cid": cid,
            "size": int(members.size),
            "grid_score": grid_score,
            "consistency": consistency,
            "gridness_result": gr,
            "avg_autocorr": avg,
        }

        accepted = (
            grid_score > float(score_thr)
            and consistency > float(consistency_thr)
            and members.size >= int(min_cells)
        )
        if accepted:
            candidate_cids.append(cid)

    # Merge accepted clusters whose average autocorrs correlate strongly.
    cand = list(candidate_cids)
    dsu = _DSU(len(cand))
    for i in range(len(cand)):
        vec_i = cluster_stats[cand[i]]["avg_autocorr"][keep]
        for j in range(i + 1, len(cand)):
            vec_j = cluster_stats[cand[j]]["avg_autocorr"][keep]
            if _safe_corr(vec_i, vec_j) > float(merge_corr_thr):
                dsu.union(i, j)

    root_to_members: dict[int, list[int]] = {}
    for pos, cid in enumerate(cand):
        root_to_members.setdefault(dsu.find(pos), []).append(cid)

    modules: list[dict[str, Any]] = []
    module_id = np.full((n_units,), -1, dtype=int)

    # Stable module ids: iterate merge roots in sorted order.
    for mid, root in enumerate(sorted(root_to_members)):
        member_cids = root_to_members[root]
        if member_cids:
            member_idxs = np.concatenate([clusters[c] for c in member_cids])
        else:
            member_idxs = np.array([], dtype=int)

        # Module-level stats: median across member clusters.
        gs = [cluster_stats[c]["grid_score"] for c in member_cids]
        cs = [cluster_stats[c]["consistency"] for c in member_cids]

        module_id[member_idxs] = mid

        modules.append(
            {
                "module_id": mid,
                "clusters": member_cids,
                "indices": member_idxs.astype(int),
                "size": int(member_idxs.size),
                "grid_score": float(np.median(gs)) if gs else float("nan"),
                "consistency": float(np.median(cs)) if cs else float("nan"),
            }
        )

    return {
        "cluster_id": cluster_id.astype(int),
        "module_id": module_id.astype(int),
        "n_units": int(n_units),
        "n_grid_cells": int((module_id != -1).sum()),
        "n_modules": int(len(modules)),
        "modules": modules,
        "params": {
            "center_bins": int(center_bins),
            "k": int(k),
            "resolution": float(resolution),
            "score_thr": float(score_thr),
            "consistency_thr": float(consistency_thr),
            "min_cells": int(min_cells),
            "merge_corr_thr": float(merge_corr_thr),
            "metric": str(metric),
        },
    }
@@ -0,0 +1,347 @@
1
+ """
2
+ Head Direction Cell Classification
3
+
4
+ Implementation of head direction cell identification based on Mean Vector Length (MVL).
5
+
6
+ Based on MATLAB code from the sweeps analysis pipeline.
7
+ """
8
+
9
+ from dataclasses import dataclass
10
+
11
+ import numpy as np
12
+
13
+ from ..utils.circular_stats import circ_mean, circ_r, circ_rtest
14
+
15
+
16
+ @dataclass
17
+ class HDCellResult:
18
+ """
19
+ Results from head direction cell classification.
20
+
21
+ Attributes
22
+ ----------
23
+ is_hd : bool
24
+ Whether the cell is classified as a head direction cell
25
+ mvl_hd : float
26
+ Mean Vector Length for head direction tuning
27
+ preferred_direction : float
28
+ Preferred head direction in radians
29
+ mvl_theta : float or None
30
+ Mean Vector Length for theta phase tuning (if provided)
31
+ tuning_curve : tuple
32
+ Tuple of (bin_centers, firing_rates)
33
+ rayleigh_p : float
34
+ P-value from Rayleigh test for non-uniformity
35
+ """
36
+
37
+ is_hd: bool
38
+ mvl_hd: float
39
+ preferred_direction: float
40
+ mvl_theta: float | None
41
+ tuning_curve: tuple[np.ndarray, np.ndarray]
42
+ rayleigh_p: float
43
+
44
+
45
class HeadDirectionAnalyzer:
    """
    Classify head direction (HD) cells from directional firing.

    An HD cell fires preferentially when the animal's head points toward a
    particular azimuth. Classification is driven by the Mean Vector Length
    (MVL) of the directional tuning curve, optionally combined with theta
    phase locking.

    Parameters
    ----------
    mvl_hd_threshold : float, optional
        MVL cutoff for the directional criterion. Default 0.4 (strict);
        0.2 is the common loose alternative.
    mvl_theta_threshold : float, optional
        MVL cutoff for theta phase modulation. Default 0.3.
    strict_mode : bool, optional
        When True, both the HD and theta criteria must pass (theta data
        permitting). Default True.
    n_bins : int, optional
        Number of directional bins for the tuning curve. Default 60
        (6° per bin).

    Examples
    --------
    >>> analyzer = HeadDirectionAnalyzer(mvl_hd_threshold=0.4, strict_mode=True)
    >>> result = analyzer.classify_hd_cell(spike_times, head_directions, time_stamps)
    >>> print(f"Is HD cell: {result.is_hd}")

    Notes
    -----
    Mirrors the MATLAB classification in fig2.m and plotSwsExample.m:
    strict = MVL_hd > 0.4 AND MVL_theta > 0.3; loose = MVL_hd > 0.2 AND
    MVL_theta > 0.3.

    References
    ----------
    Thresholds follow standard head-direction-cell conventions and the
    CircStat toolbox.
    """

    def __init__(
        self,
        mvl_hd_threshold: float = 0.4,
        mvl_theta_threshold: float = 0.3,
        strict_mode: bool = True,
        n_bins: int = 60,
    ):
        self.mvl_hd_threshold = mvl_hd_threshold
        self.mvl_theta_threshold = mvl_theta_threshold
        self.strict_mode = strict_mode
        self.n_bins = n_bins

    def classify_hd_cell(
        self,
        spike_times: np.ndarray,
        head_directions: np.ndarray,
        time_stamps: np.ndarray,
        theta_phases: np.ndarray | None = None,
    ) -> HDCellResult:
        """
        Classify one unit as a head direction cell.

        Parameters
        ----------
        spike_times : np.ndarray
            Spike times in seconds
        head_directions : np.ndarray
            Head direction per sample (radians), aligned with time_stamps
        time_stamps : np.ndarray
            Sample times in seconds
        theta_phases : np.ndarray, optional
            Theta phase per sample (radians). When None, the theta
            criterion is skipped.

        Returns
        -------
        result : HDCellResult
            MVLs, preferred direction, Rayleigh p-value, and tuning curve
        """
        centers, rates, _occupancy = self.compute_tuning_curve(
            spike_times, head_directions, time_stamps
        )

        # Directional statistics are computed on the rate-weighted bins.
        mvl_hd = self.compute_mvl(centers, weights=rates)
        preferred = circ_mean(centers, w=rates)
        rayleigh_p = circ_rtest(centers, w=rates)

        # Theta locking is measured on the phases interpolated at spike times.
        mvl_theta = None
        if theta_phases is not None:
            spike_phase = np.interp(spike_times, time_stamps, theta_phases)
            mvl_theta = self.compute_mvl(spike_phase)

        is_hd = mvl_hd > self.mvl_hd_threshold
        if self.strict_mode and mvl_theta is not None:
            # Strict mode additionally demands theta modulation; without
            # theta data only the directional criterion applies.
            is_hd = is_hd and (mvl_theta > self.mvl_theta_threshold)

        return HDCellResult(
            is_hd=is_hd,
            mvl_hd=mvl_hd,
            preferred_direction=preferred,
            mvl_theta=mvl_theta,
            tuning_curve=(centers, rates),
            rayleigh_p=rayleigh_p,
        )

    def compute_tuning_curve(
        self,
        spike_times: np.ndarray,
        head_directions: np.ndarray,
        time_stamps: np.ndarray,
        n_bins: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute the directional tuning curve.

        Parameters
        ----------
        spike_times : np.ndarray
            Spike times in seconds
        head_directions : np.ndarray
            Head direction per sample (radians)
        time_stamps : np.ndarray
            Sample times in seconds
        n_bins : int, optional
            Bin count; defaults to self.n_bins.

        Returns
        -------
        bin_centers : np.ndarray
            Directional bin centers (radians)
        firing_rates : np.ndarray
            Firing rate per bin (Hz)
        occupancy : np.ndarray
            Time spent per bin (seconds)
        """
        bins = self.n_bins if n_bins is None else n_bins

        edges = np.linspace(-np.pi, np.pi, bins + 1)
        centers = (edges[:-1] + edges[1:]) / 2

        # Occupancy: samples per bin scaled by the median sampling interval.
        dt = np.median(np.diff(time_stamps))
        hd_idx = np.digitize(head_directions, edges) - 1  # 0-indexed
        hd_idx = np.clip(hd_idx, 0, bins - 1)  # fold boundary samples in
        occupancy = np.bincount(hd_idx, minlength=bins) * dt

        # Spike counts: head direction interpolated at each spike time.
        spike_hd = np.interp(spike_times, time_stamps, head_directions)
        spk_idx = np.clip(np.digitize(spike_hd, edges) - 1, 0, bins - 1)
        counts = np.bincount(spk_idx, minlength=bins)

        # Rate = spikes / occupancy, leaving unvisited bins at zero.
        rates = np.zeros(bins)
        visited = occupancy > 0
        rates[visited] = counts[visited] / occupancy[visited]

        return centers, rates, occupancy

    def compute_mvl(self, angles: np.ndarray, weights: np.ndarray | None = None) -> float:
        """
        Compute the Mean Vector Length (MVL).

        Ranges from 0 (uniform circular distribution) to 1 (fully
        concentrated). Delegates to circ_r from the circular-statistics
        utilities.

        Parameters
        ----------
        angles : np.ndarray
            Angles in radians
        weights : np.ndarray, optional
            Per-angle weights (e.g. firing rates); uniform when None.

        Returns
        -------
        mvl : float
            Mean vector length
        """
        return circ_r(angles, w=weights)
281
+
282
+
283
if __name__ == "__main__":
    print("Testing HeadDirectionAnalyzer...")

    # --- Synthetic head direction cell --------------------------------------
    print("\nSimulating head direction cell...")
    sample_times = np.linspace(0, 100, 10000)  # 100 s sampled at 100 Hz
    step = sample_times[1] - sample_times[0]

    # Random-walk heading, wrapped onto [-pi, pi]
    mean_ang_vel = 0.5  # rad/s average
    headings = np.cumsum(np.random.randn(len(sample_times)) * mean_ang_vel * step)
    headings = np.arctan2(np.sin(headings), np.cos(headings))

    # Von Mises tuning around 0.5 rad (~28 deg)
    preferred_dir = 0.5
    tuning_width = 0.5  # radians

    from scipy.stats import vonmises

    rate_profile = vonmises.pdf(headings - preferred_dir, kappa=1 / tuning_width**2)
    rate_profile = rate_profile / rate_profile.max() * 0.1  # cap at 10% per sample

    # Bernoulli thinning approximates Poisson spiking
    spikes = sample_times[np.random.rand(len(sample_times)) < rate_profile]

    print(f"Generated {len(spikes)} spikes")
    print(f"Mean firing rate: {len(spikes) / sample_times[-1]:.2f} Hz")

    print("\nClassifying cell...")
    analyzer = HeadDirectionAnalyzer(mvl_hd_threshold=0.4, strict_mode=False)
    result = analyzer.classify_hd_cell(spikes, headings, sample_times)

    print("\nResults:")
    print(f" Is HD cell: {result.is_hd}")
    print(f" MVL: {result.mvl_hd:.3f}")
    print(f" Preferred direction: {np.rad2deg(result.preferred_direction):.1f}°")
    print(f" True preferred direction: {np.rad2deg(preferred_dir):.1f}°")
    print(f" Rayleigh test p-value: {result.rayleigh_p:.6f}")

    # Sanity-check the tuning curve peak against the true preferred direction
    centers, rates = result.tuning_curve
    peak = np.argmax(rates)
    print(f" Peak firing rate: {rates[peak]:.2f} Hz")
    print(f" Peak at direction: {np.rad2deg(centers[peak]):.1f}°")

    # --- Control: untuned Poisson cell --------------------------------------
    print("\n\nSimulating non-directional cell...")
    mean_rate = 5  # Hz
    n_ctrl = int(mean_rate * sample_times[-1])
    ctrl_spikes = np.sort(np.random.uniform(0, sample_times[-1], n_ctrl))

    ctrl_result = analyzer.classify_hd_cell(ctrl_spikes, headings, sample_times)

    print("Results for non-directional cell:")
    print(f" Is HD cell: {ctrl_result.is_hd}")
    print(f" MVL: {ctrl_result.mvl_hd:.3f}")
    print(f" Rayleigh test p-value: {ctrl_result.rayleigh_p:.3f}")

    print("\nHeadDirectionAnalyzer test completed!")