canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/visualization/core/backend.py +1 -1
- canns/analyzer/visualization/core/config.py +77 -0
- canns/analyzer/visualization/core/rendering.py +10 -6
- canns/analyzer/visualization/energy_plots.py +22 -8
- canns/analyzer/visualization/spatial_plots.py +31 -11
- canns/analyzer/visualization/theta_sweep_plots.py +15 -6
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/open_loop_navigation.py +3 -1
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
canns/pipeline/theta_sweep.py
DELETED
|
@@ -1,573 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Theta Sweep Pipeline for External Trajectory Analysis
|
|
3
|
-
|
|
4
|
-
This module provides a high-level pipeline for experimental scientists to analyze
|
|
5
|
-
their trajectory data using CANN theta sweep models without needing to understand
|
|
6
|
-
the underlying implementation details.
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from typing import Any
|
|
11
|
-
|
|
12
|
-
import brainpy.math as bm
|
|
13
|
-
import numpy as np
|
|
14
|
-
|
|
15
|
-
from ..analyzer.visualization import PlotConfig
|
|
16
|
-
from ..analyzer.visualization.theta_sweep_plots import (
|
|
17
|
-
create_theta_sweep_grid_cell_animation,
|
|
18
|
-
plot_population_activity_with_theta,
|
|
19
|
-
)
|
|
20
|
-
from ..models.basic.theta_sweep_model import (
|
|
21
|
-
DirectionCellNetwork,
|
|
22
|
-
GridCellNetwork,
|
|
23
|
-
calculate_theta_modulation,
|
|
24
|
-
)
|
|
25
|
-
from ..task.open_loop_navigation import OpenLoopNavigationTask
|
|
26
|
-
from ._base import Pipeline
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class ThetaSweepPipeline(Pipeline):
    """
    High-level pipeline for theta sweep analysis of external trajectory data.

    This pipeline abstracts the complex workflow of running CANN theta sweep models
    on experimental trajectory data, making it accessible to researchers who want
    to analyze neural responses without diving into implementation details.

    Workflow (driven by :meth:`run`):
        1. Build an ``OpenLoopNavigationTask`` from the trajectory and derive the
           theta sweep inputs (head direction angle, linear/angular speed gains).
        2. Build a ``DirectionCellNetwork`` and a ``GridCellNetwork``.
        3. Step both networks over the trajectory with theta modulation.
        4. Optionally save/show analysis plots and save a theta sweep animation.

    Example:
        ```python
        # Simple usage - just provide trajectory data
        pipeline = ThetaSweepPipeline(
            trajectory_data=positions,  # shape: (n_steps, 2)
            times=times,  # shape: (n_steps,)
        )

        results = pipeline.run(output_dir="my_results/")
        print(f"Animation saved to: {results['animation_path']}")
        ```
    """

    def __init__(
        self,
        trajectory_data: np.ndarray,
        times: np.ndarray | None = None,
        env_size: float = 2.0,
        dt: float = 0.001,
        direction_cell_params: dict[str, Any] | None = None,
        grid_cell_params: dict[str, Any] | None = None,
        theta_params: dict[str, Any] | None = None,
        spatial_nav_params: dict[str, Any] | None = None,
    ):
        """
        Initialize the theta sweep pipeline.

        User-supplied parameter dicts are merged on top of the defaults returned by
        the ``_get_default_*`` helpers; the caller's dicts are never mutated.

        Args:
            trajectory_data: Position coordinates with shape (n_steps, 2) for 2D trajectories
            times: Optional time array with shape (n_steps,). If None, uniform time steps will be used
            env_size: Environment size (assumes square environment)
            dt: Simulation time step
            direction_cell_params: Parameters for DirectionCellNetwork. If None, uses defaults
            grid_cell_params: Parameters for GridCellNetwork. If None, uses defaults
            theta_params: Parameters for theta modulation. If None, uses defaults
            spatial_nav_params: Additional parameters for OpenLoopNavigationTask. If None, uses defaults

        Raises:
            ValueError: If ``trajectory_data`` / ``times`` fail validation.
        """
        super().__init__()
        # Store trajectory data (np.array copies, so later caller edits don't leak in)
        self.trajectory_data = np.array(trajectory_data)
        self.times = np.array(times) if times is not None else None
        self.env_size = env_size
        self.dt = dt

        # Validate trajectory data before anything is built from it
        self._validate_trajectory_data()

        # Set up parameters: defaults first, then caller overrides.
        # spatial-nav defaults read self.env_size / self.dt, so they must be set above.
        self.direction_cell_params = self._get_default_direction_cell_params()
        if direction_cell_params:
            self.direction_cell_params.update(direction_cell_params)

        self.grid_cell_params = self._get_default_grid_cell_params()
        if grid_cell_params:
            self.grid_cell_params.update(grid_cell_params)

        self.theta_params = self._get_default_theta_params()
        if theta_params:
            self.theta_params.update(theta_params)

        self.spatial_nav_params = self._get_default_spatial_nav_params()
        if spatial_nav_params:
            self.spatial_nav_params.update(spatial_nav_params)

        # Components are created lazily by run()
        self.spatial_nav_task = None
        self.direction_network = None
        self.grid_network = None
        # Filled in by _run_simulation(); None until run() has executed.
        self.simulation_results = None

    def _validate_trajectory_data(self):
        """
        Validate input trajectory data format and dimensions.

        Checks:
            - Trajectory is 2D array (n_steps, n_dims)
            - Only 2D spatial trajectories (n_dims=2)
            - At least 2 time steps
            - Times array, if provided, is 1D and matches trajectory length

        Raises:
            ValueError: If validation fails
        """
        if self.trajectory_data.ndim != 2:
            raise ValueError("trajectory_data must be a 2D array with shape (n_steps, n_dims)")

        n_steps, n_dims = self.trajectory_data.shape
        if n_dims != 2:
            raise ValueError("Currently only 2D trajectories are supported")

        if n_steps < 2:
            raise ValueError("trajectory_data must contain at least 2 time steps")

        if self.times is not None:
            # Check rank first: a scalar would otherwise fail on .shape[0] with an
            # IndexError instead of the documented ValueError.
            if self.times.ndim != 1:
                raise ValueError("times must be a 1D array with shape (n_steps,)")
            if self.times.shape[0] != n_steps:
                raise ValueError("times array length must match trajectory_data length")

    def _get_default_direction_cell_params(self) -> dict[str, Any]:
        """
        Get default parameters for DirectionCellNetwork initialization.

        Returns:
            dict: Default parameters including:
                - num: 100 neurons
                - adaptation_strength: 15 for SFA dynamics
                - noise_strength: 0.0 (no noise)
        """
        return {
            "num": 100,
            "adaptation_strength": 15,
            "noise_strength": 0.0,
        }

    def _get_default_grid_cell_params(self) -> dict[str, Any]:
        """
        Get default parameters for GridCellNetwork initialization.

        Returns:
            dict: Default parameters including:
                - num_gc_x: 100 neurons per dimension (100x100 grid)
                - adaptation_strength: 8 for SFA dynamics
                - mapping_ratio: 5 (controls grid spacing)
                - noise_strength: 0.0 (no noise)
        """
        return {
            "num_gc_x": 100,
            "adaptation_strength": 8,
            "mapping_ratio": 5,
            "noise_strength": 0.0,
        }

    def _get_default_theta_params(self) -> dict[str, Any]:
        """
        Get default parameters for theta oscillation modulation.

        Returns:
            dict: Default parameters including:
                - theta_strength_hd: 1.0 for direction cells
                - theta_strength_gc: 0.5 for grid cells
                - theta_cycle_len: 100.0 ms per cycle
        """
        return {
            "theta_strength_hd": 1.0,
            "theta_strength_gc": 0.5,
            "theta_cycle_len": 100.0,
        }

    def _get_default_spatial_nav_params(self) -> dict[str, Any]:
        """
        Get default parameters for OpenLoopNavigationTask initialization.

        Returns:
            dict: A square ``env_size`` x ``env_size`` arena, the pipeline ``dt``,
            and the progress bar disabled.
        """
        return {
            "width": self.env_size,
            "height": self.env_size,
            "dt": self.dt,
            "progress_bar": False,
        }

    def _setup_open_loop_navigation_task(self):
        """
        Set up and configure the spatial navigation task with trajectory data.

        Creates OpenLoopNavigationTask, imports external trajectory data,
        and calculates theta sweep parameters (velocity, angular speed, etc.).
        """
        # Calculate duration from trajectory data
        if self.times is not None:
            duration = self.times[-1] - self.times[0]
        else:
            duration = len(self.trajectory_data) * self.dt

        # Create spatial navigation task
        self.spatial_nav_task = OpenLoopNavigationTask(duration=duration, **self.spatial_nav_params)

        # Import external trajectory data
        self.spatial_nav_task.import_data(
            position_data=self.trajectory_data, times=self.times, dt=self.dt
        )

        # Calculate theta sweep data (head direction, speed gains, ...)
        self.spatial_nav_task.calculate_theta_sweep_data()

    def _setup_neural_networks(self):
        """
        Initialize and configure direction cell and grid cell networks.

        Creates a DirectionCellNetwork and a GridCellNetwork from the configured
        parameters. ``num_dc`` for the grid network is always forced to the
        direction network's size so the two networks stay compatible.
        """
        # Create direction cell network
        self.direction_network = DirectionCellNetwork(**self.direction_cell_params)

        # Create grid cell network (ensure consistency with direction network)
        grid_params = self.grid_cell_params.copy()
        grid_params["num_dc"] = self.direction_network.num
        self.grid_network = GridCellNetwork(**grid_params)

    def _run_simulation(self):
        """
        Run the main theta sweep simulation loop.

        Executes time-stepped simulation of direction and grid cell networks
        with theta modulation, one step per sample of the navigation task data.

        Results are stored in ``self.simulation_results`` (nothing is returned),
        a dict with keys:
            - internal_position: grid network ``center_position`` per step
            - internal_direction: direction network ``center`` per step
            - gc_activity: Grid cell firing rates over time
            - gc_bump: grid network ``gc_bump`` per step
            - dc_activity: Direction cell firing rates over time
            - theta_phase: Theta oscillation phase over time
            - theta_modulation_hd / theta_modulation_gc: theta modulation terms
            - position, direction, linear_speed_gains, ang_speed_gains: task inputs
            - time_steps: ``self.spatial_nav_task.run_steps``

        The per-step outputs are presumably stacked along a leading time axis by
        ``bm.for_loop`` — TODO confirm against BrainPy docs.
        """
        # Configure the global BrainPy integration step.
        # NOTE(review): fixed at 1.0, independent of self.dt (which is passed to
        # calculate_theta_modulation below); confirm this is intended.
        bm.set_dt(dt=1.0)

        # Extract data from spatial navigation task
        snt_data = self.spatial_nav_task.data
        position = snt_data.position
        direction = snt_data.hd_angle
        linear_speed_gains = snt_data.linear_speed_gains
        ang_speed_gains = snt_data.ang_speed_gains

        def run_step(i, pos, hd_angle, linear_gain, ang_gain):
            """Single simulation step."""
            theta_phase, theta_modulation_hd, theta_modulation_gc = calculate_theta_modulation(
                time_step=i,
                linear_gain=linear_gain,
                ang_gain=ang_gain,
                theta_strength_hd=self.theta_params["theta_strength_hd"],
                theta_strength_gc=self.theta_params["theta_strength_gc"],
                theta_cycle_len=self.theta_params["theta_cycle_len"],
                dt=self.dt,
            )

            # Update direction cell network
            self.direction_network(hd_angle, theta_modulation_hd)
            dc_activity = self.direction_network.r.value

            # Update grid cell network (driven by the direction cell output)
            self.grid_network(pos, dc_activity, theta_modulation_gc)
            gc_activity = self.grid_network.r.value

            return (
                self.grid_network.center_position.value,
                self.direction_network.center.value,
                gc_activity,
                self.grid_network.gc_bump.value,
                dc_activity,
                theta_phase,
                theta_modulation_hd,
                theta_modulation_gc,
            )

        # Run compiled simulation loop
        results = bm.for_loop(
            run_step,
            bm.arange(len(position)),
            position,
            direction,
            linear_speed_gains,
            ang_speed_gains,
            pbar=None,
        )

        # Unpack results (order must match run_step's return tuple)
        (
            internal_position,
            internal_direction,
            gc_activity,
            gc_bump,
            dc_activity,
            theta_phase,
            theta_modulation_hd,
            theta_modulation_gc,
        ) = results

        # Store simulation results
        self.simulation_results = {
            "internal_position": internal_position,
            "internal_direction": internal_direction,
            "gc_activity": gc_activity,
            "gc_bump": gc_bump,
            "dc_activity": dc_activity,
            "theta_phase": theta_phase,
            "theta_modulation_hd": theta_modulation_hd,
            "theta_modulation_gc": theta_modulation_gc,
            "position": position,
            "direction": direction,
            "linear_speed_gains": linear_speed_gains,
            "ang_speed_gains": ang_speed_gains,
            "time_steps": self.spatial_nav_task.run_steps,
        }

    def run(
        self,
        output_dir: str | Path = "theta_sweep_results",
        save_animation: bool = True,
        save_plots: bool = True,
        show_plots: bool = False,
        animation_fps: int = 10,
        animation_dpi: int = 120,
        verbose: bool = True,
    ) -> dict[str, Any]:
        """
        Run the complete theta sweep pipeline.

        Args:
            output_dir: Directory to save output files
            save_animation: Whether to save the theta sweep animation
            save_plots: Whether to save analysis plots to ``output_dir``
            show_plots: Whether to display plots interactively. With
                ``save_plots=False`` the plots are only shown, not written to disk.
            animation_fps: Frame rate for animation
            animation_dpi: DPI for animation output
            verbose: Whether to print progress messages

        Returns:
            Whatever ``self.set_results`` returns for the outputs dict. The outputs
            dict holds the simulation data under ``"data"`` plus the paths of any
            saved files (``"trajectory_analysis"``, ``"population_activity"``,
            ``"animation_path"``).
        """
        self.reset()
        if verbose:
            print("🚀 Starting Theta Sweep Pipeline...")

        # Create output directory
        output_path = self.prepare_output_dir(output_dir)

        # Setup pipeline components
        if verbose:
            print("📊 Setting up spatial navigation task...")
        self._setup_open_loop_navigation_task()

        if verbose:
            print("🧠 Setting up neural networks...")
        self._setup_neural_networks()

        if verbose:
            print("⚡ Running theta sweep simulation...")
        self._run_simulation()

        # Generate outputs
        outputs: dict[str, Any] = {"data": self.simulation_results}

        if save_plots or show_plots:
            # Pass save_plots through so show-only runs don't write files.
            outputs.update(
                self._generate_plots(output_path, show_plots, verbose, save_plots=save_plots)
            )

        if save_animation:
            outputs.update(
                self._generate_animation(output_path, animation_fps, animation_dpi, verbose)
            )

        if verbose:
            print("✅ Pipeline completed successfully!")
            print(f"📁 Results saved to: {output_path.absolute()}")

        return self.set_results(outputs)

    def _generate_plots(
        self,
        output_path: Path,
        show_plots: bool,
        verbose: bool,
        save_plots: bool = True,
    ) -> dict[str, str]:
        """
        Generate analysis plots for theta sweep results.

        Creates trajectory analysis and population activity visualizations.

        Args:
            output_path: Directory to save plots
            show_plots: Whether to display plots interactively
            verbose: Whether to print progress messages
            save_plots: Whether to write the plots to ``output_path``. When False,
                ``save_path=None`` is passed to the plotting helpers and no paths
                are returned.

        Returns:
            dict: Mapping of plot names to file paths (only for plots actually saved)
        """
        plot_outputs: dict[str, str] = {}

        # Trajectory analysis
        if verbose:
            print("📈 Generating trajectory analysis...")
        trajectory_path = output_path / "trajectory_analysis.png"
        self.spatial_nav_task.show_trajectory_analysis(
            save_path=str(trajectory_path) if save_plots else None,
            show=show_plots,
            smooth_window=50,
        )
        if save_plots:
            plot_outputs["trajectory_analysis"] = str(trajectory_path)

        # Population activity with theta
        if verbose:
            print("📊 Generating population activity plot...")
        population_path = output_path / "population_activity.png"
        config_pop = PlotConfig(
            title="Direction Cell Population Activity with Theta",
            xlabel="Time (s)",
            ylabel="Direction (°)",
            figsize=(10, 4),
            show=show_plots,
            save_path=str(population_path) if save_plots else None,
        )

        plot_population_activity_with_theta(
            # Convert step indices to seconds for the x axis
            time_steps=self.simulation_results["time_steps"] * self.dt,
            theta_phase=self.simulation_results["theta_phase"],
            net_activity=self.simulation_results["dc_activity"],
            direction=self.simulation_results["direction"],
            config=config_pop,
            add_lines=True,
            atol=5e-2,
        )
        if save_plots:
            plot_outputs["population_activity"] = str(population_path)

        return plot_outputs

    def _generate_animation(
        self, output_path: Path, fps: int, dpi: int, verbose: bool
    ) -> dict[str, str]:
        """
        Generate theta sweep animation showing neural dynamics over time.

        Creates animated visualization of direction and grid cell activity
        with theta phase modulation, saved as ``theta_sweep_animation.gif``.

        Args:
            output_path: Directory to save animation
            fps: Frames per second for animation
            dpi: Resolution for animation frames
            verbose: Whether to print progress messages (also enables the
                progress bar of the animation renderer)

        Returns:
            dict: Mapping containing the 'animation_path' key with the file path
        """
        animation_path = output_path / "theta_sweep_animation.gif"

        config_animation = PlotConfig(
            figsize=(12, 3),
            fps=fps,
            save_path=str(animation_path),
            show=False,
        )

        if verbose:
            # Flush so the message appears before the renderer's progress output.
            print("🎬 Creating theta sweep animation...", flush=True)

            # Brief pause to ensure message ordering
            import time

            time.sleep(0.01)

        create_theta_sweep_grid_cell_animation(
            position_data=self.simulation_results["position"],
            direction_data=self.simulation_results["direction"],
            dc_activity_data=self.simulation_results["dc_activity"],
            gc_activity_data=self.simulation_results["gc_activity"],
            gc_network=self.grid_network,
            env_size=self.env_size,
            mapping_ratio=self.grid_cell_params["mapping_ratio"],
            dt=self.dt,
            config=config_animation,
            n_step=10,
            show_progress_bar=verbose,
            render_backend="auto",
            output_dpi=dpi,
            render_worker_batch_size=2,
        )

        return {"animation_path": str(animation_path)}
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
# Convenience functions for common use cases
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
def load_trajectory_from_csv(
|
|
507
|
-
filepath: str | Path,
|
|
508
|
-
x_col: str = "x",
|
|
509
|
-
y_col: str = "y",
|
|
510
|
-
time_col: str | None = "time",
|
|
511
|
-
**kwargs,
|
|
512
|
-
) -> dict[str, Any]:
|
|
513
|
-
"""
|
|
514
|
-
Load trajectory data from CSV file and run theta sweep analysis.
|
|
515
|
-
|
|
516
|
-
Args:
|
|
517
|
-
filepath: Path to CSV file
|
|
518
|
-
x_col: Column name for x coordinates
|
|
519
|
-
y_col: Column name for y coordinates
|
|
520
|
-
time_col: Column name for time data (optional)
|
|
521
|
-
**kwargs: Additional parameters passed to ThetaSweepPipeline
|
|
522
|
-
|
|
523
|
-
Returns:
|
|
524
|
-
Dictionary containing analysis results and file paths
|
|
525
|
-
"""
|
|
526
|
-
import pandas as pd
|
|
527
|
-
|
|
528
|
-
df = pd.read_csv(filepath)
|
|
529
|
-
|
|
530
|
-
trajectory_data = df[[x_col, y_col]].values
|
|
531
|
-
times = df[time_col].values if time_col and time_col in df.columns else None
|
|
532
|
-
|
|
533
|
-
pipeline = ThetaSweepPipeline(trajectory_data, times, **kwargs)
|
|
534
|
-
return pipeline.run(verbose=True)
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
def batch_process_trajectories(
|
|
538
|
-
trajectory_list: list, output_base_dir: str = "batch_results", **kwargs
|
|
539
|
-
) -> dict[str, dict[str, Any]]:
|
|
540
|
-
"""
|
|
541
|
-
Process multiple trajectories in batch.
|
|
542
|
-
|
|
543
|
-
Args:
|
|
544
|
-
trajectory_list: List of (trajectory_data, times) tuples or trajectory_data arrays
|
|
545
|
-
output_base_dir: Base directory for batch results
|
|
546
|
-
**kwargs: Additional parameters passed to ThetaSweepPipeline
|
|
547
|
-
|
|
548
|
-
Returns:
|
|
549
|
-
Dictionary mapping trajectory indices to results
|
|
550
|
-
"""
|
|
551
|
-
batch_results = {}
|
|
552
|
-
|
|
553
|
-
for i, trajectory_input in enumerate(trajectory_list):
|
|
554
|
-
print(f"\n🔄 Processing trajectory {i + 1}/{len(trajectory_list)}...")
|
|
555
|
-
|
|
556
|
-
if isinstance(trajectory_input, tuple):
|
|
557
|
-
trajectory_data, times = trajectory_input
|
|
558
|
-
else:
|
|
559
|
-
trajectory_data, times = trajectory_input, None
|
|
560
|
-
|
|
561
|
-
output_dir = Path(output_base_dir) / f"trajectory_{i:03d}"
|
|
562
|
-
|
|
563
|
-
try:
|
|
564
|
-
pipeline = ThetaSweepPipeline(trajectory_data, times, **kwargs)
|
|
565
|
-
results = pipeline.run(output_dir=str(output_dir), verbose=False)
|
|
566
|
-
batch_results[f"trajectory_{i:03d}"] = results
|
|
567
|
-
print(f"✅ Trajectory {i + 1} completed successfully")
|
|
568
|
-
|
|
569
|
-
except Exception as e:
|
|
570
|
-
print(f"❌ Error processing trajectory {i + 1}: {e}")
|
|
571
|
-
batch_results[f"trajectory_{i:03d}"] = {"error": str(e)}
|
|
572
|
-
|
|
573
|
-
return batch_results
|
|
File without changes
|
|
File without changes
|