canns 0.13.2__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/pipeline/__init__.py +13 -1
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/app.py +5 -7
- canns/pipeline/gallery/runner.py +16 -9
- canns/pipeline/gallery/state.py +0 -1
- {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/entry_points.txt +1 -0
- canns-0.13.2.dist-info/RECORD +0 -95
- {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.2.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _base_mask(shape: tuple[int, int], center_bins: int = 2) -> np.ndarray:
|
|
9
|
+
"""Create a shared mask: center disk + outside circle (corners)."""
|
|
10
|
+
h, w = shape
|
|
11
|
+
cy = (h - 1) / 2.0
|
|
12
|
+
cx = (w - 1) / 2.0
|
|
13
|
+
yy, xx = np.ogrid[:h, :w]
|
|
14
|
+
rr = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
|
|
15
|
+
|
|
16
|
+
# outer circle radius: half-min dimension
|
|
17
|
+
outer_r = min(cy, cx)
|
|
18
|
+
mask_outer = rr > outer_r
|
|
19
|
+
|
|
20
|
+
mask_center = rr <= float(center_bins)
|
|
21
|
+
return mask_outer | mask_center
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _safe_corr(a: np.ndarray, b: np.ndarray) -> float:
|
|
25
|
+
a = np.asarray(a).ravel()
|
|
26
|
+
b = np.asarray(b).ravel()
|
|
27
|
+
if a.size != b.size or a.size == 0:
|
|
28
|
+
return 0.0
|
|
29
|
+
a = a - a.mean()
|
|
30
|
+
b = b - b.mean()
|
|
31
|
+
da = np.sqrt((a * a).mean())
|
|
32
|
+
db = np.sqrt((b * b).mean())
|
|
33
|
+
if da <= 1e-12 or db <= 1e-12:
|
|
34
|
+
return 0.0
|
|
35
|
+
return float((a * b).mean() / (da * db))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _vectorize_autocorrs(
    autocorrs: np.ndarray, center_bins: int = 2
) -> tuple[np.ndarray, np.ndarray]:
    """Turn a stack of autocorrelograms into a feature matrix.

    Every cell contributes one row containing the values of the bins NOT
    excluded by the shared ``_base_mask``. NaN and +/-inf values are replaced
    with 0 first, so every row has the same length.

    Parameters
    ----------
    autocorrs : np.ndarray
        Stack of shape ``(N, H, W)``.
    center_bins : int
        Radius of the central disk forwarded to ``_base_mask``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(X, mask)`` where ``X`` is float32 of shape ``(N, n_features)`` and
        ``mask`` is the boolean ``(H, W)`` exclusion mask.

    Raises
    ------
    ValueError
        If ``autocorrs`` is not 3-D.
    """
    if autocorrs.ndim != 3:
        raise ValueError(f"autocorrs must be (N,H,W), got {autocorrs.shape}")
    _, height, width = autocorrs.shape
    shared_mask = _base_mask((height, width), center_bins=center_bins)

    cleaned = np.nan_to_num(autocorrs, nan=0.0, posinf=0.0, neginf=0.0)
    features = cleaned[:, ~shared_mask]  # (N, n_features)
    return features.astype(np.float32, copy=False), shared_mask
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _build_knn_graph(X: np.ndarray, k: int = 30, metric: str = "manhattan"):
    """Return an undirected igraph graph built from the kNN relation on ``X``.

    Parameters
    ----------
    X : np.ndarray
        Point cloud of shape ``(N, n_features)``; one vertex per row.
    k : int
        Neighbours per vertex, clipped to ``[1, N - 1]``.
    metric : str
        Distance metric passed to ``sklearn.neighbors.NearestNeighbors``.

    Returns
    -------
    igraph.Graph
        Graph with ``N`` vertices. Each vertex is linked to its ``k`` nearest
        other vertices. A pair picked from both ends is stored as one edge.
        The ``"weight"`` edge attribute holds ``1 / (distance + 1e-6)``.
        With ``N < 2`` there is nothing to link, so the graph has no edges and
        no ``"weight"`` attribute.

    Raises
    ------
    ImportError
        If scikit-learn or python-igraph is not installed.
    """
    try:
        from sklearn.neighbors import NearestNeighbors
    except Exception as e:
        raise ImportError(f"scikit-learn is required for kNN graph: {e}") from e

    try:
        import igraph as ig
    except Exception as e:
        raise ImportError(f"python-igraph is required for Leiden clustering: {e}") from e

    N = X.shape[0]
    if N < 2:
        # Asking sklearn for k_eff + 1 >= 2 neighbours of fewer than 2 samples
        # would raise; a graph of isolated vertices is the meaningful answer.
        return ig.Graph(n=N, edges=[], directed=False)

    k_eff = min(max(int(k), 1), N - 1)

    # Query one extra neighbour because the point itself is normally returned.
    nbrs = NearestNeighbors(n_neighbors=k_eff + 1, metric=metric)
    nbrs.fit(X)
    dist, ind = nbrs.kneighbors(X, return_distance=True)

    edges: list[tuple[int, int]] = []
    weights: list[float] = []
    eps = 1e-6  # keeps weights finite for zero-distance (duplicate) points
    seen: set[tuple[int, int]] = set()
    for i in range(N):
        # The self-match is usually column 0, but with distance ties (duplicate
        # rows) it is not guaranteed to be. Skip it wherever it appears and keep
        # the first k_eff other points, rather than blindly dropping column 0.
        n_taken = 0
        for j_raw, d_raw in zip(ind[i], dist[i]):
            j = int(j_raw)
            if j == i:
                continue
            if n_taken == k_eff:
                break
            n_taken += 1
            a, b = (i, j) if i < j else (j, i)
            if (a, b) in seen:
                continue  # already added from the other endpoint
            seen.add((a, b))
            edges.append((a, b))
            weights.append(1.0 / (float(d_raw) + eps))

    g = ig.Graph(n=N, edges=edges, directed=False)
    g.es["weight"] = weights
    return g
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _leiden_membership(graph, resolution: float = 1.0) -> np.ndarray:
    """Partition ``graph`` with Leiden and return one community label per vertex.

    Uses ``leidenalg.RBConfigurationVertexPartition`` with the given resolution.
    The ``"weight"`` edge attribute is used as edge weights when the graph
    defines it; otherwise the graph is treated as unweighted.

    NOTE(review): no seed is passed to ``find_partition``, so labels may differ
    between runs; confirm whether reproducibility is required.

    Raises
    ------
    ImportError
        If leidenalg is not installed.
    """
    try:
        import leidenalg
    except Exception as exc:
        raise ImportError(f"leidenalg is required for Leiden clustering: {exc}") from exc

    edge_weights = None
    if "weight" in graph.es.attributes():
        edge_weights = graph.es["weight"]

    partition = leidenalg.find_partition(
        graph,
        leidenalg.RBConfigurationVertexPartition,
        weights=edge_weights,
        resolution_parameter=float(resolution),
    )
    labels = partition.membership
    return np.asarray(labels, dtype=int)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class _DSU:
|
|
111
|
+
def __init__(self, n: int):
|
|
112
|
+
self.p = list(range(n))
|
|
113
|
+
self.r = [0] * n
|
|
114
|
+
|
|
115
|
+
def find(self, a: int) -> int:
|
|
116
|
+
while self.p[a] != a:
|
|
117
|
+
self.p[a] = self.p[self.p[a]]
|
|
118
|
+
a = self.p[a]
|
|
119
|
+
return a
|
|
120
|
+
|
|
121
|
+
def union(self, a: int, b: int):
|
|
122
|
+
ra, rb = self.find(a), self.find(b)
|
|
123
|
+
if ra == rb:
|
|
124
|
+
return
|
|
125
|
+
if self.r[ra] < self.r[rb]:
|
|
126
|
+
ra, rb = rb, ra
|
|
127
|
+
self.p[rb] = ra
|
|
128
|
+
if self.r[ra] == self.r[rb]:
|
|
129
|
+
self.r[ra] += 1
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def identify_grid_modules_and_stats(
    autocorrs: np.ndarray,
    *,
    gridness_analyzer,
    center_bins: int = 2,
    k: int = 30,
    resolution: float = 1.0,
    score_thr: float = 0.3,
    consistency_thr: float = 0.5,
    min_cells: int = 10,
    merge_corr_thr: float = 0.7,
    metric: str = "manhattan",
) -> dict[str, Any]:
    """Identify grid modules with Leiden clustering on autocorrelogram point cloud.

    The function runs in four steps.

    1. Flatten each cell's autocorrelogram outside a shared mask (central disk
       plus corners beyond the inscribed circle) into a feature vector.
    2. Build a kNN graph on those vectors and partition it with Leiden.
    3. For every Leiden cluster, compute two statistics:
       - a grid score of the cluster's median autocorrelogram, via
         ``gridness_analyzer``;
       - a consistency, the median correlation between each member's
         autocorrelogram and the cluster mean.
    4. Accept clusters that pass the thresholds, then merge accepted clusters
       whose mean autocorrelograms correlate strongly into modules.

    Parameters
    ----------
    autocorrs : array_like
        Array of shape (N, H, W), one autocorrelogram per cell. Any array-like
        (e.g. a list of equally-shaped 2-D arrays) is accepted. NaN/inf values
        are treated as 0.
    gridness_analyzer :
        An instance that provides compute_gridness_score(autocorr)->GridnessResult.
        The score is read from the result's ``score`` attribute, falling back
        to ``grid_score``, then NaN.
    center_bins : int
        Radius (in bins) to mask around center peak.
    k : int
        Neighbors for kNN graph.
    resolution : float
        Leiden resolution parameter.
    score_thr, consistency_thr, min_cells
        A cluster is accepted when ``grid_score > score_thr``,
        ``consistency > consistency_thr`` and ``size >= min_cells``.
    merge_corr_thr : float
        Accepted clusters whose mean autocorrelograms (unmasked bins)
        correlate above this value are merged, transitively, into one module.
    metric : str
        Distance metric for the kNN graph.

    Returns
    -------
    dict with keys:
        cluster_id : np.ndarray (N,)
            Leiden cluster label of every cell.
        module_id : np.ndarray (N,)
            Module index of every cell, -1 for cells outside accepted modules.
        n_units : int
            N.
        n_grid_cells : int
            Number of cells with ``module_id != -1``.
        n_modules : int
            Number of modules.
        modules : list of dict
            Per module: ``module_id``, ``clusters`` (member cluster ids),
            ``indices`` (cell indices), ``size``, and ``grid_score`` and
            ``consistency`` (medians over the member clusters).
        params : dict
            The parameter values used.

    Raises
    ------
    ValueError
        If ``autocorrs`` is not 3-D.
    ImportError
        If scikit-learn, python-igraph or leidenalg is not installed.
    """
    autocorrs = np.asarray(autocorrs)
    if autocorrs.ndim != 3:
        raise ValueError(f"autocorrs must be (N,H,W). Got {autocorrs.shape}")
    N = autocorrs.shape[0]

    # 1-2) Embed cells as vectors of unmasked bins and cluster the point cloud.
    X, base_mask = _vectorize_autocorrs(autocorrs, center_bins=int(center_bins))
    g = _build_knn_graph(X, k=int(k), metric=str(metric))
    cluster_id = _leiden_membership(g, resolution=float(resolution))

    # Member cell indices for every Leiden label (non-empty by construction).
    clusters: dict[int, np.ndarray] = {
        int(cid): np.flatnonzero(cluster_id == cid) for cid in np.unique(cluster_id)
    }

    keep = ~base_mask  # bins compared across cells (same mask for all)
    ac = np.nan_to_num(autocorrs, nan=0.0, posinf=0.0, neginf=0.0)

    # 3) Per-cluster statistics and acceptance test.
    cluster_stats: dict[int, dict[str, Any]] = {}
    candidate_cids: list[int] = []
    for cid, idxs in clusters.items():
        members = ac[idxs]
        avg = members.mean(axis=0)
        # Gridness is scored on the median map; the mean map is used for
        # consistency and for merging.
        med = np.median(members, axis=0)
        gr = gridness_analyzer.compute_gridness_score(med)
        grid_score = float(getattr(gr, "score", getattr(gr, "grid_score", np.nan)))

        flat_avg = avg[keep]
        cors = [_safe_corr(flat_avg, ac[i][keep]) for i in idxs]
        consistency = float(np.median(cors)) if cors else 0.0

        cluster_stats[cid] = {
            "cid": cid,
            "size": int(idxs.size),
            "grid_score": grid_score,
            "consistency": consistency,
            "gridness_result": gr,
            "avg_autocorr": avg,
        }

        # NaN grid scores fail the ">" test and are therefore rejected.
        if (
            (grid_score > float(score_thr))
            and (consistency > float(consistency_thr))
            and (idxs.size >= int(min_cells))
        ):
            candidate_cids.append(cid)

    # 4) Merge accepted clusters whose mean autocorrelograms are similar.
    cand = list(candidate_cids)
    flat_means = [cluster_stats[c]["avg_autocorr"][keep] for c in cand]
    dsu = _DSU(len(cand))
    for i in range(len(cand)):
        for j in range(i + 1, len(cand)):
            if _safe_corr(flat_means[i], flat_means[j]) > float(merge_corr_thr):
                dsu.union(i, j)

    root_to_members: dict[int, list[int]] = {}
    for idx, cid in enumerate(cand):
        root_to_members.setdefault(dsu.find(idx), []).append(cid)

    # Module ids follow the sorted DSU root indices, so numbering is
    # deterministic for a given clustering.
    modules: list[dict[str, Any]] = []
    module_id = np.full((N,), -1, dtype=int)
    for mid, r in enumerate(sorted(root_to_members)):
        member_cids = root_to_members[r]  # never empty: built by append above
        member_idxs = np.concatenate([clusters[c] for c in member_cids])

        # Module-level stats are medians over the member clusters.
        grid_score = float(np.median([cluster_stats[c]["grid_score"] for c in member_cids]))
        consistency = float(np.median([cluster_stats[c]["consistency"] for c in member_cids]))

        module_id[member_idxs] = mid
        modules.append(
            {
                "module_id": mid,
                "clusters": member_cids,
                "indices": member_idxs.astype(int),
                "size": int(member_idxs.size),
                "grid_score": grid_score,
                "consistency": consistency,
            }
        )

    n_grid_cells = int((module_id != -1).sum())
    return {
        "cluster_id": cluster_id.astype(int),
        "module_id": module_id.astype(int),
        "n_units": int(N),
        "n_grid_cells": n_grid_cells,
        "n_modules": int(len(modules)),
        "modules": modules,
        "params": {
            "center_bins": int(center_bins),
            "k": int(k),
            "resolution": float(resolution),
            "score_thr": float(score_thr),
            "consistency_thr": float(consistency_thr),
            "min_cells": int(min_cells),
            "merge_corr_thr": float(merge_corr_thr),
            "metric": str(metric),
        },
    }
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Head Direction Cell Classification
|
|
3
|
+
|
|
4
|
+
Implementation of head direction cell identification based on Mean Vector Length (MVL).
|
|
5
|
+
|
|
6
|
+
Based on MATLAB code from the sweeps analysis pipeline.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from ..utils.circular_stats import circ_mean, circ_r, circ_rtest
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class HDCellResult:
|
|
18
|
+
"""
|
|
19
|
+
Results from head direction cell classification.
|
|
20
|
+
|
|
21
|
+
Attributes
|
|
22
|
+
----------
|
|
23
|
+
is_hd : bool
|
|
24
|
+
Whether the cell is classified as a head direction cell
|
|
25
|
+
mvl_hd : float
|
|
26
|
+
Mean Vector Length for head direction tuning
|
|
27
|
+
preferred_direction : float
|
|
28
|
+
Preferred head direction in radians
|
|
29
|
+
mvl_theta : float or None
|
|
30
|
+
Mean Vector Length for theta phase tuning (if provided)
|
|
31
|
+
tuning_curve : tuple
|
|
32
|
+
Tuple of (bin_centers, firing_rates)
|
|
33
|
+
rayleigh_p : float
|
|
34
|
+
P-value from Rayleigh test for non-uniformity
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
is_hd: bool
|
|
38
|
+
mvl_hd: float
|
|
39
|
+
preferred_direction: float
|
|
40
|
+
mvl_theta: float | None
|
|
41
|
+
tuning_curve: tuple[np.ndarray, np.ndarray]
|
|
42
|
+
rayleigh_p: float
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class HeadDirectionAnalyzer:
|
|
46
|
+
"""
|
|
47
|
+
Analyzer for classifying head direction cells based on directional tuning.
|
|
48
|
+
|
|
49
|
+
Head direction cells fire when the animal's head points in a specific direction.
|
|
50
|
+
Classification is based on the strength of directional tuning measured by
|
|
51
|
+
Mean Vector Length (MVL).
|
|
52
|
+
|
|
53
|
+
Parameters
|
|
54
|
+
----------
|
|
55
|
+
mvl_hd_threshold : float, optional
|
|
56
|
+
MVL threshold for head direction. Default is 0.4 (strict).
|
|
57
|
+
Use 0.2 for looser threshold.
|
|
58
|
+
mvl_theta_threshold : float, optional
|
|
59
|
+
MVL threshold for theta phase modulation. Default is 0.3.
|
|
60
|
+
strict_mode : bool, optional
|
|
61
|
+
If True, requires both HD and theta criteria. Default is True.
|
|
62
|
+
n_bins : int, optional
|
|
63
|
+
Number of directional bins for tuning curve. Default is 60 (6° bins).
|
|
64
|
+
|
|
65
|
+
Examples
|
|
66
|
+
--------
|
|
67
|
+
>>> analyzer = HeadDirectionAnalyzer(mvl_hd_threshold=0.4, strict_mode=True)
|
|
68
|
+
>>> result = analyzer.classify_hd_cell(spike_times, head_directions, time_stamps)
|
|
69
|
+
>>> print(f"Is HD cell: {result.is_hd}")
|
|
70
|
+
>>> print(f"MVL: {result.mvl_hd:.3f}")
|
|
71
|
+
>>> print(f"Preferred direction: {np.rad2deg(result.preferred_direction):.1f}°")
|
|
72
|
+
|
|
73
|
+
Notes
|
|
74
|
+
-----
|
|
75
|
+
Based on MATLAB classification from fig2.m and plotSwsExample.m:
|
|
76
|
+
- Strict: MVL_hd > 0.4 AND MVL_theta > 0.3
|
|
77
|
+
- Loose: MVL_hd > 0.2 AND MVL_theta > 0.3
|
|
78
|
+
|
|
79
|
+
References
|
|
80
|
+
----------
|
|
81
|
+
Classification thresholds follow standard conventions in head direction
|
|
82
|
+
cell literature and the CircStat toolbox.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
def __init__(
|
|
86
|
+
self,
|
|
87
|
+
mvl_hd_threshold: float = 0.4,
|
|
88
|
+
mvl_theta_threshold: float = 0.3,
|
|
89
|
+
strict_mode: bool = True,
|
|
90
|
+
n_bins: int = 60,
|
|
91
|
+
):
|
|
92
|
+
self.mvl_hd_threshold = mvl_hd_threshold
|
|
93
|
+
self.mvl_theta_threshold = mvl_theta_threshold
|
|
94
|
+
self.strict_mode = strict_mode
|
|
95
|
+
self.n_bins = n_bins
|
|
96
|
+
|
|
97
|
+
def classify_hd_cell(
|
|
98
|
+
self,
|
|
99
|
+
spike_times: np.ndarray,
|
|
100
|
+
head_directions: np.ndarray,
|
|
101
|
+
time_stamps: np.ndarray,
|
|
102
|
+
theta_phases: np.ndarray | None = None,
|
|
103
|
+
) -> HDCellResult:
|
|
104
|
+
"""
|
|
105
|
+
Classify a cell as head direction cell based on MVL thresholds.
|
|
106
|
+
|
|
107
|
+
Parameters
|
|
108
|
+
----------
|
|
109
|
+
spike_times : np.ndarray
|
|
110
|
+
Spike times in seconds
|
|
111
|
+
head_directions : np.ndarray
|
|
112
|
+
Head direction at each time point (radians)
|
|
113
|
+
time_stamps : np.ndarray
|
|
114
|
+
Time stamps corresponding to head_directions (seconds)
|
|
115
|
+
theta_phases : np.ndarray, optional
|
|
116
|
+
Theta phase at each time point (radians). If None, theta
|
|
117
|
+
criterion is not checked.
|
|
118
|
+
|
|
119
|
+
Returns
|
|
120
|
+
-------
|
|
121
|
+
result : HDCellResult
|
|
122
|
+
Classification result with MVL, preferred direction, and tuning curve
|
|
123
|
+
|
|
124
|
+
Examples
|
|
125
|
+
--------
|
|
126
|
+
>>> # Simulate a head direction cell
|
|
127
|
+
>>> time_stamps = np.linspace(0, 100, 10000)
|
|
128
|
+
>>> head_directions = np.linspace(0, 20*np.pi, 10000) % (2*np.pi) - np.pi
|
|
129
|
+
>>> preferred_dir = 0.5
|
|
130
|
+
>>> spike_times = time_stamps[np.abs(head_directions - preferred_dir) < 0.3]
|
|
131
|
+
>>> result = analyzer.classify_hd_cell(spike_times, head_directions, time_stamps)
|
|
132
|
+
"""
|
|
133
|
+
# Compute directional tuning curve
|
|
134
|
+
bin_centers, firing_rates, occupancy = self.compute_tuning_curve(
|
|
135
|
+
spike_times, head_directions, time_stamps
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
# Compute MVL for head direction
|
|
139
|
+
mvl_hd = self.compute_mvl(bin_centers, weights=firing_rates)
|
|
140
|
+
|
|
141
|
+
# Compute preferred direction
|
|
142
|
+
preferred_direction = circ_mean(bin_centers, w=firing_rates)
|
|
143
|
+
|
|
144
|
+
# Rayleigh test for non-uniformity
|
|
145
|
+
rayleigh_p = circ_rtest(bin_centers, w=firing_rates)
|
|
146
|
+
|
|
147
|
+
# Compute MVL for theta phase if provided
|
|
148
|
+
mvl_theta = None
|
|
149
|
+
if theta_phases is not None:
|
|
150
|
+
# Get theta phases at spike times
|
|
151
|
+
spike_theta = np.interp(spike_times, time_stamps, theta_phases)
|
|
152
|
+
mvl_theta = self.compute_mvl(spike_theta)
|
|
153
|
+
|
|
154
|
+
# Classification logic
|
|
155
|
+
is_hd = mvl_hd > self.mvl_hd_threshold
|
|
156
|
+
|
|
157
|
+
if self.strict_mode and theta_phases is not None:
|
|
158
|
+
# Strict mode: require both HD and theta criteria
|
|
159
|
+
is_hd = is_hd and (mvl_theta > self.mvl_theta_threshold)
|
|
160
|
+
elif self.strict_mode and theta_phases is None:
|
|
161
|
+
# If strict mode but no theta data, just use HD threshold
|
|
162
|
+
pass
|
|
163
|
+
|
|
164
|
+
# Create result
|
|
165
|
+
result = HDCellResult(
|
|
166
|
+
is_hd=is_hd,
|
|
167
|
+
mvl_hd=mvl_hd,
|
|
168
|
+
preferred_direction=preferred_direction,
|
|
169
|
+
mvl_theta=mvl_theta,
|
|
170
|
+
tuning_curve=(bin_centers, firing_rates),
|
|
171
|
+
rayleigh_p=rayleigh_p,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
return result
|
|
175
|
+
|
|
176
|
+
def compute_tuning_curve(
|
|
177
|
+
self,
|
|
178
|
+
spike_times: np.ndarray,
|
|
179
|
+
head_directions: np.ndarray,
|
|
180
|
+
time_stamps: np.ndarray,
|
|
181
|
+
n_bins: int | None = None,
|
|
182
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
183
|
+
"""
|
|
184
|
+
Compute directional tuning curve.
|
|
185
|
+
|
|
186
|
+
Parameters
|
|
187
|
+
----------
|
|
188
|
+
spike_times : np.ndarray
|
|
189
|
+
Spike times in seconds
|
|
190
|
+
head_directions : np.ndarray
|
|
191
|
+
Head direction at each time point (radians)
|
|
192
|
+
time_stamps : np.ndarray
|
|
193
|
+
Time stamps corresponding to head_directions (seconds)
|
|
194
|
+
n_bins : int, optional
|
|
195
|
+
Number of bins. If None, uses self.n_bins.
|
|
196
|
+
|
|
197
|
+
Returns
|
|
198
|
+
-------
|
|
199
|
+
bin_centers : np.ndarray
|
|
200
|
+
Center of each directional bin (radians)
|
|
201
|
+
firing_rates : np.ndarray
|
|
202
|
+
Firing rate in each bin (Hz)
|
|
203
|
+
occupancy : np.ndarray
|
|
204
|
+
Time spent in each bin (seconds)
|
|
205
|
+
|
|
206
|
+
Examples
|
|
207
|
+
--------
|
|
208
|
+
>>> bins, rates, occ = analyzer.compute_tuning_curve(
|
|
209
|
+
... spike_times, head_directions, time_stamps
|
|
210
|
+
... )
|
|
211
|
+
>>> # Plot polar tuning curve
|
|
212
|
+
>>> import matplotlib.pyplot as plt
|
|
213
|
+
>>> ax = plt.subplot(111, projection='polar')
|
|
214
|
+
>>> ax.plot(bins, rates)
|
|
215
|
+
"""
|
|
216
|
+
if n_bins is None:
|
|
217
|
+
n_bins = self.n_bins
|
|
218
|
+
|
|
219
|
+
# Define bin edges
|
|
220
|
+
bin_edges = np.linspace(-np.pi, np.pi, n_bins + 1)
|
|
221
|
+
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
|
|
222
|
+
|
|
223
|
+
# Compute occupancy (time spent in each bin)
|
|
224
|
+
dt = np.median(np.diff(time_stamps)) # Sampling interval
|
|
225
|
+
hd_binned = np.digitize(head_directions, bin_edges) - 1 # 0-indexed
|
|
226
|
+
hd_binned = np.clip(hd_binned, 0, n_bins - 1) # Handle edge cases
|
|
227
|
+
|
|
228
|
+
occupancy = np.bincount(hd_binned, minlength=n_bins) * dt
|
|
229
|
+
|
|
230
|
+
# Count spikes in each bin
|
|
231
|
+
# Interpolate HD at spike times
|
|
232
|
+
spike_hd = np.interp(spike_times, time_stamps, head_directions)
|
|
233
|
+
spike_bins = np.digitize(spike_hd, bin_edges) - 1
|
|
234
|
+
spike_bins = np.clip(spike_bins, 0, n_bins - 1)
|
|
235
|
+
|
|
236
|
+
spike_counts = np.bincount(spike_bins, minlength=n_bins)
|
|
237
|
+
|
|
238
|
+
# Compute firing rates
|
|
239
|
+
firing_rates = np.zeros(n_bins)
|
|
240
|
+
valid = occupancy > 0
|
|
241
|
+
firing_rates[valid] = spike_counts[valid] / occupancy[valid]
|
|
242
|
+
|
|
243
|
+
return bin_centers, firing_rates, occupancy
|
|
244
|
+
|
|
245
|
+
def compute_mvl(self, angles: np.ndarray, weights: np.ndarray | None = None) -> float:
|
|
246
|
+
"""
|
|
247
|
+
Compute Mean Vector Length (MVL).
|
|
248
|
+
|
|
249
|
+
The MVL is a measure of circular variance, ranging from 0 (uniform
|
|
250
|
+
distribution) to 1 (concentrated distribution).
|
|
251
|
+
|
|
252
|
+
Parameters
|
|
253
|
+
----------
|
|
254
|
+
angles : np.ndarray
|
|
255
|
+
Angles in radians
|
|
256
|
+
weights : np.ndarray, optional
|
|
257
|
+
Weights for each angle (e.g., firing rates). If None, uniform weights.
|
|
258
|
+
|
|
259
|
+
Returns
|
|
260
|
+
-------
|
|
261
|
+
mvl : float
|
|
262
|
+
Mean vector length
|
|
263
|
+
|
|
264
|
+
Examples
|
|
265
|
+
--------
|
|
266
|
+
>>> # Concentrated distribution
|
|
267
|
+
>>> angles = np.random.normal(0, 0.1, 100)
|
|
268
|
+
>>> mvl = analyzer.compute_mvl(angles)
|
|
269
|
+
>>> print(f"MVL: {mvl:.3f}") # Should be close to 1
|
|
270
|
+
|
|
271
|
+
>>> # Uniform distribution
|
|
272
|
+
>>> angles = np.random.uniform(-np.pi, np.pi, 100)
|
|
273
|
+
>>> mvl = analyzer.compute_mvl(angles)
|
|
274
|
+
>>> print(f"MVL: {mvl:.3f}") # Should be close to 0
|
|
275
|
+
|
|
276
|
+
Notes
|
|
277
|
+
-----
|
|
278
|
+
Uses the circ_r function from circular statistics utilities.
|
|
279
|
+
"""
|
|
280
|
+
return circ_r(angles, w=weights)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
if __name__ == "__main__":
    from scipy.stats import vonmises

    print("Testing HeadDirectionAnalyzer...")

    # Seeded generator so repeated runs of this demo print the same numbers.
    rng = np.random.default_rng(0)

    # Simulate a head direction cell
    print("\nSimulating head direction cell...")
    time_stamps = np.linspace(0, 100, 10000)  # 100 seconds, ~100 Hz sampling
    dt = time_stamps[1] - time_stamps[0]

    # Head direction follows a diffusive random walk: increments scale with
    # sqrt(dt), so the spread after 100 s is ~angular_diffusion * 10 rad and
    # every direction is visited. Full coverage matters: if only a narrow arc
    # were visited, even an untuned cell's tuning curve would concentrate on
    # that arc and produce a spuriously high MVL.
    angular_diffusion = 2.0  # rad / sqrt(s)
    steps = rng.standard_normal(len(time_stamps)) * angular_diffusion * np.sqrt(dt)
    head_directions = np.cumsum(steps)
    head_directions = np.arctan2(
        np.sin(head_directions), np.cos(head_directions)
    )  # Wrap to [-π, π]

    # Cell fires preferentially at 0.5 radians (~28°)
    preferred_dir = 0.5
    tuning_width = 0.5  # radians (~28° width)

    # Von Mises tuning, scaled so the peak spike probability per sample is 10%
    firing_prob = vonmises.pdf(head_directions - preferred_dir, kappa=1 / tuning_width**2)
    firing_prob = firing_prob / firing_prob.max() * 0.1

    # Bernoulli draw per sample (at most one spike per sample)
    spike_mask = rng.random(len(time_stamps)) < firing_prob
    spike_times = time_stamps[spike_mask]

    print(f"Generated {len(spike_times)} spikes")
    print(f"Mean firing rate: {len(spike_times) / time_stamps[-1]:.2f} Hz")

    # Classify
    print("\nClassifying cell...")
    analyzer = HeadDirectionAnalyzer(mvl_hd_threshold=0.4, strict_mode=False)
    result = analyzer.classify_hd_cell(spike_times, head_directions, time_stamps)

    print("\nResults:")
    print(f"  Is HD cell: {result.is_hd}")
    print(f"  MVL: {result.mvl_hd:.3f}")
    print(f"  Preferred direction: {np.rad2deg(result.preferred_direction):.1f}°")
    print(f"  True preferred direction: {np.rad2deg(preferred_dir):.1f}°")
    print(f"  Rayleigh test p-value: {result.rayleigh_p:.6f}")

    # Check tuning curve
    bin_centers, firing_rates = result.tuning_curve
    max_rate_idx = np.argmax(firing_rates)
    print(f"  Peak firing rate: {firing_rates[max_rate_idx]:.2f} Hz")
    print(f"  Peak at direction: {np.rad2deg(bin_centers[max_rate_idx]):.1f}°")

    # Test with non-directional cell
    print("\n\nSimulating non-directional cell...")
    # Uniformly scattered spike times: no dependence on head direction
    mean_rate = 5  # Hz
    n_spikes = int(mean_rate * time_stamps[-1])
    spike_times_random = np.sort(rng.uniform(0, time_stamps[-1], n_spikes))

    result_random = analyzer.classify_hd_cell(spike_times_random, head_directions, time_stamps)

    print("Results for non-directional cell:")
    print(f"  Is HD cell: {result_random.is_hd}")
    print(f"  MVL: {result_random.mvl_hd:.3f}")
    print(f"  Rayleigh test p-value: {result_random.rayleigh_p:.3f}")

    print("\nHeadDirectionAnalyzer test completed!")
|