canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,573 +0,0 @@
1
- """
2
- Theta Sweep Pipeline for External Trajectory Analysis
3
-
4
- This module provides a high-level pipeline for experimental scientists to analyze
5
- their trajectory data using CANN theta sweep models without needing to understand
6
- the underlying implementation details.
7
- """
8
-
9
- from pathlib import Path
10
- from typing import Any
11
-
12
- import brainpy.math as bm
13
- import numpy as np
14
-
15
- from ..analyzer.visualization import PlotConfig
16
- from ..analyzer.visualization.theta_sweep_plots import (
17
- create_theta_sweep_grid_cell_animation,
18
- plot_population_activity_with_theta,
19
- )
20
- from ..models.basic.theta_sweep_model import (
21
- DirectionCellNetwork,
22
- GridCellNetwork,
23
- calculate_theta_modulation,
24
- )
25
- from ..task.open_loop_navigation import OpenLoopNavigationTask
26
- from ._base import Pipeline
27
-
28
-
29
class ThetaSweepPipeline(Pipeline):
    """
    High-level pipeline for theta sweep analysis of external trajectory data.

    The pipeline imports an external 2D trajectory into an
    ``OpenLoopNavigationTask`` and derives the theta sweep inputs (speed and
    angular-speed gains) from it. It then simulates a ``DirectionCellNetwork``
    and a ``GridCellNetwork`` with theta modulation. Finally, it optionally
    renders analysis plots and an animation.

    Example:
        ```python
        # Simple usage - just provide trajectory data
        pipeline = ThetaSweepPipeline(
            trajectory_data=positions,  # shape: (n_steps, 2)
            times=times,  # shape: (n_steps,)
        )

        results = pipeline.run(output_dir="my_results/")
        print(f"Animation saved to: {results['animation_path']}")
        ```
    """

    def __init__(
        self,
        trajectory_data: np.ndarray,
        times: np.ndarray | None = None,
        env_size: float = 2.0,
        dt: float = 0.001,
        direction_cell_params: dict[str, Any] | None = None,
        grid_cell_params: dict[str, Any] | None = None,
        theta_params: dict[str, Any] | None = None,
        spatial_nav_params: dict[str, Any] | None = None,
    ):
        """
        Initialize the theta sweep pipeline.

        Args:
            trajectory_data: Position coordinates with shape (n_steps, 2).
            times: Optional time array with shape (n_steps,). If None, uniform
                time steps of ``dt`` are used.
            env_size: Environment side length (square environment assumed).
            dt: Simulation time step.
            direction_cell_params: Overrides for DirectionCellNetwork defaults.
            grid_cell_params: Overrides for GridCellNetwork defaults.
            theta_params: Overrides for theta modulation defaults.
            spatial_nav_params: Overrides for OpenLoopNavigationTask defaults.

        Raises:
            ValueError: If the trajectory or times arrays are malformed.
        """
        super().__init__()
        # Copy inputs into ndarrays so later validation/indexing is uniform.
        self.trajectory_data = np.array(trajectory_data)
        self.times = np.array(times) if times is not None else None
        self.env_size = env_size
        self.dt = dt

        self._validate_trajectory_data()

        # Start from defaults, then layer any user-supplied overrides on top.
        self.direction_cell_params = self._get_default_direction_cell_params()
        if direction_cell_params:
            self.direction_cell_params.update(direction_cell_params)

        self.grid_cell_params = self._get_default_grid_cell_params()
        if grid_cell_params:
            self.grid_cell_params.update(grid_cell_params)

        self.theta_params = self._get_default_theta_params()
        if theta_params:
            self.theta_params.update(theta_params)

        # Depends on env_size/dt, so it must come after those are stored.
        self.spatial_nav_params = self._get_default_spatial_nav_params()
        if spatial_nav_params:
            self.spatial_nav_params.update(spatial_nav_params)

        # Components are built lazily by run().
        self.spatial_nav_task = None
        self.direction_network = None
        self.grid_network = None

    def _validate_trajectory_data(self):
        """
        Validate input trajectory data format and dimensions.

        Checks:
            - Trajectory is a 2D array (n_steps, n_dims)
            - Only 2D spatial trajectories (n_dims == 2)
            - At least 2 time steps
            - Times array (if provided) is not a scalar and matches the
              trajectory length

        Raises:
            ValueError: If validation fails.
        """
        if self.trajectory_data.ndim != 2:
            raise ValueError("trajectory_data must be a 2D array with shape (n_steps, n_dims)")

        n_steps, n_dims = self.trajectory_data.shape
        if n_dims != 2:
            raise ValueError("Currently only 2D trajectories are supported")

        if n_steps < 2:
            raise ValueError("trajectory_data must contain at least 2 time steps")

        if self.times is not None:
            # A 0-d array has no shape[0]; report it as a ValueError instead of
            # letting an IndexError escape.
            if self.times.ndim == 0 or self.times.shape[0] != n_steps:
                raise ValueError("times array length must match trajectory_data length")

    def _get_default_direction_cell_params(self) -> dict[str, Any]:
        """
        Get default parameters for DirectionCellNetwork initialization.

        Returns:
            dict: ``num`` = 100, ``adaptation_strength`` = 15,
            ``noise_strength`` = 0.0.
        """
        return {
            "num": 100,
            "adaptation_strength": 15,
            "noise_strength": 0.0,
        }

    def _get_default_grid_cell_params(self) -> dict[str, Any]:
        """
        Get default parameters for GridCellNetwork initialization.

        Returns:
            dict: ``num_gc_x`` = 100, ``adaptation_strength`` = 8,
            ``mapping_ratio`` = 5, ``noise_strength`` = 0.0.
        """
        return {
            "num_gc_x": 100,
            "adaptation_strength": 8,
            "mapping_ratio": 5,
            "noise_strength": 0.0,
        }

    def _get_default_theta_params(self) -> dict[str, Any]:
        """
        Get default parameters for theta oscillation modulation.

        Returns:
            dict: ``theta_strength_hd`` = 1.0, ``theta_strength_gc`` = 0.5,
            ``theta_cycle_len`` = 100.0.
        """
        return {
            "theta_strength_hd": 1.0,
            "theta_strength_gc": 0.5,
            "theta_cycle_len": 100.0,
        }

    def _get_default_spatial_nav_params(self) -> dict[str, Any]:
        """
        Get default parameters for OpenLoopNavigationTask initialization.

        Returns:
            dict: Square environment of side ``env_size``, step ``dt`` and the
            task's progress bar disabled.
        """
        return {
            "width": self.env_size,
            "height": self.env_size,
            "dt": self.dt,
            "progress_bar": False,
        }

    def _setup_open_loop_navigation_task(self):
        """
        Create the navigation task, import the trajectory and derive theta
        sweep inputs (via ``calculate_theta_sweep_data``).
        """
        # NOTE(review): with explicit times this spans n_steps - 1 intervals,
        # while the fallback spans n_steps * dt; kept as-is — confirm which
        # duration import_data expects.
        if self.times is not None:
            duration = self.times[-1] - self.times[0]
        else:
            duration = len(self.trajectory_data) * self.dt

        self.spatial_nav_task = OpenLoopNavigationTask(duration=duration, **self.spatial_nav_params)
        self.spatial_nav_task.import_data(
            position_data=self.trajectory_data, times=self.times, dt=self.dt
        )
        self.spatial_nav_task.calculate_theta_sweep_data()

    def _setup_neural_networks(self):
        """
        Instantiate the direction-cell and grid-cell networks.

        The grid network's ``num_dc`` is forced to the direction network's
        size so the two networks are wired consistently.
        """
        self.direction_network = DirectionCellNetwork(**self.direction_cell_params)

        grid_params = self.grid_cell_params.copy()
        grid_params["num_dc"] = self.direction_network.num
        self.grid_network = GridCellNetwork(**grid_params)

    def _run_simulation(self):
        """
        Run the compiled theta sweep simulation loop.

        At each step, theta modulation is computed from the speed gains. The
        direction network is driven by the head-direction angle, and the grid
        network by the position and the direction-cell activity.

        Results are stored in ``self.simulation_results`` (nothing is
        returned), with keys: ``internal_position``, ``internal_direction``,
        ``gc_activity``, ``gc_bump``, ``dc_activity``, ``theta_phase``,
        ``theta_modulation_hd``, ``theta_modulation_gc``, ``position``,
        ``direction``, ``linear_speed_gains``, ``ang_speed_gains`` and
        ``time_steps``.
        """
        # NOTE(review): this sets the process-wide BrainPy dt to 1.0 (one unit
        # per step) while self.dt is passed explicitly to the theta modulation;
        # confirm this split is intended.
        bm.set_dt(dt=1.0)

        snt_data = self.spatial_nav_task.data
        position = snt_data.position
        direction = snt_data.hd_angle
        linear_speed_gains = snt_data.linear_speed_gains
        ang_speed_gains = snt_data.ang_speed_gains

        def run_step(i, pos, hd_angle, linear_gain, ang_gain):
            """Advance both networks by one step and return the recorded values."""
            theta_phase, theta_modulation_hd, theta_modulation_gc = calculate_theta_modulation(
                time_step=i,
                linear_gain=linear_gain,
                ang_gain=ang_gain,
                theta_strength_hd=self.theta_params["theta_strength_hd"],
                theta_strength_gc=self.theta_params["theta_strength_gc"],
                theta_cycle_len=self.theta_params["theta_cycle_len"],
                dt=self.dt,
            )

            self.direction_network(hd_angle, theta_modulation_hd)
            dc_activity = self.direction_network.r.value

            # The grid network consumes the direction-cell output of this step.
            self.grid_network(pos, dc_activity, theta_modulation_gc)
            gc_activity = self.grid_network.r.value

            return (
                self.grid_network.center_position.value,
                self.direction_network.center.value,
                gc_activity,
                self.grid_network.gc_bump.value,
                dc_activity,
                theta_phase,
                theta_modulation_hd,
                theta_modulation_gc,
            )

        results = bm.for_loop(
            run_step,
            bm.arange(len(position)),
            position,
            direction,
            linear_speed_gains,
            ang_speed_gains,
            pbar=None,
        )

        (
            internal_position,
            internal_direction,
            gc_activity,
            gc_bump,
            dc_activity,
            theta_phase,
            theta_modulation_hd,
            theta_modulation_gc,
        ) = results

        self.simulation_results = {
            "internal_position": internal_position,
            "internal_direction": internal_direction,
            "gc_activity": gc_activity,
            "gc_bump": gc_bump,
            "dc_activity": dc_activity,
            "theta_phase": theta_phase,
            "theta_modulation_hd": theta_modulation_hd,
            "theta_modulation_gc": theta_modulation_gc,
            "position": position,
            "direction": direction,
            "linear_speed_gains": linear_speed_gains,
            "ang_speed_gains": ang_speed_gains,
            "time_steps": self.spatial_nav_task.run_steps,
        }

    def run(
        self,
        output_dir: str | Path = "theta_sweep_results",
        save_animation: bool = True,
        save_plots: bool = True,
        show_plots: bool = False,
        animation_fps: int = 10,
        animation_dpi: int = 120,
        verbose: bool = True,
    ) -> dict[str, Any]:
        """
        Run the complete theta sweep pipeline.

        Args:
            output_dir: Directory to save output files.
            save_animation: Whether to save the theta sweep animation.
            save_plots: Whether to write analysis plots to disk.
            show_plots: Whether to display plots interactively. Plots are only
                written to disk when ``save_plots`` is also True.
            animation_fps: Frame rate for the animation.
            animation_dpi: DPI for the animation output.
            verbose: Whether to print progress messages.

        Returns:
            The value of ``self.set_results(outputs)``. ``outputs`` holds the
            ``"data"`` simulation results plus the paths of any saved plots or
            animation.
        """
        self.reset()
        if verbose:
            print("🚀 Starting Theta Sweep Pipeline...")

        output_path = self.prepare_output_dir(output_dir)

        if verbose:
            print("📊 Setting up spatial navigation task...")
        self._setup_open_loop_navigation_task()

        if verbose:
            print("🧠 Setting up neural networks...")
        self._setup_neural_networks()

        if verbose:
            print("⚡ Running theta sweep simulation...")
        self._run_simulation()

        outputs = {"data": self.simulation_results}

        if save_plots or show_plots:
            outputs.update(
                self._generate_plots(output_path, show_plots, verbose, save_plots=save_plots)
            )

        if save_animation:
            outputs.update(
                self._generate_animation(output_path, animation_fps, animation_dpi, verbose)
            )

        if verbose:
            print("✅ Pipeline completed successfully!")
            print(f"📁 Results saved to: {output_path.absolute()}")

        return self.set_results(outputs)

    def _generate_plots(
        self,
        output_path: Path,
        show_plots: bool,
        verbose: bool,
        save_plots: bool = True,
    ) -> dict[str, str]:
        """
        Generate trajectory-analysis and population-activity plots.

        Args:
            output_path: Directory to save plots in.
            show_plots: Whether to display plots interactively.
            verbose: Whether to print progress messages.
            save_plots: Whether to write the plots to disk. When False, no
                ``save_path`` is passed to the plotting helpers and no paths
                are reported.

        Returns:
            dict: Mapping of plot names to file paths (only for saved plots).
        """
        plot_outputs: dict[str, str] = {}
        trajectory_path = output_path / "trajectory_analysis.png"
        population_path = output_path / "population_activity.png"

        if verbose:
            print("📈 Generating trajectory analysis...")
        # NOTE(review): relies on show_trajectory_analysis accepting
        # save_path=None to skip saving — confirm against the task API.
        self.spatial_nav_task.show_trajectory_analysis(
            save_path=str(trajectory_path) if save_plots else None,
            show=show_plots,
            smooth_window=50,
        )
        if save_plots:
            plot_outputs["trajectory_analysis"] = str(trajectory_path)

        if verbose:
            print("📊 Generating population activity plot...")
        config_pop = PlotConfig(
            title="Direction Cell Population Activity with Theta",
            xlabel="Time (s)",
            ylabel="Direction (°)",
            figsize=(10, 4),
            show=show_plots,
            save_path=str(population_path) if save_plots else None,
        )

        plot_population_activity_with_theta(
            time_steps=self.simulation_results["time_steps"] * self.dt,
            theta_phase=self.simulation_results["theta_phase"],
            net_activity=self.simulation_results["dc_activity"],
            direction=self.simulation_results["direction"],
            config=config_pop,
            add_lines=True,
            atol=5e-2,
        )
        if save_plots:
            plot_outputs["population_activity"] = str(population_path)

        return plot_outputs

    def _generate_animation(
        self, output_path: Path, fps: int, dpi: int, verbose: bool
    ) -> dict[str, str]:
        """
        Render the theta sweep animation of direction- and grid-cell activity.

        Args:
            output_path: Directory to save the animation in.
            fps: Frames per second for the animation.
            dpi: Resolution for the animation frames.
            verbose: Whether to print progress messages (also enables the
                renderer's progress bar).

        Returns:
            dict: ``{"animation_path": <path to the GIF>}``.
        """
        animation_path = output_path / "theta_sweep_animation.gif"

        config_animation = PlotConfig(
            figsize=(12, 3),
            fps=fps,
            save_path=str(animation_path),
            show=False,
        )

        if verbose:
            # Flush so the message appears before the renderer's progress bar.
            print("🎬 Creating theta sweep animation...", flush=True)

        create_theta_sweep_grid_cell_animation(
            position_data=self.simulation_results["position"],
            direction_data=self.simulation_results["direction"],
            dc_activity_data=self.simulation_results["dc_activity"],
            gc_activity_data=self.simulation_results["gc_activity"],
            gc_network=self.grid_network,
            env_size=self.env_size,
            mapping_ratio=self.grid_cell_params["mapping_ratio"],
            dt=self.dt,
            config=config_animation,
            n_step=10,
            show_progress_bar=verbose,
            render_backend="auto",
            output_dpi=dpi,
            render_worker_batch_size=2,
        )

        return {"animation_path": str(animation_path)}
501
-
502
-
503
- # Convenience functions for common use cases
504
-
505
-
506
def load_trajectory_from_csv(
    filepath: str | Path,
    x_col: str = "x",
    y_col: str = "y",
    time_col: str | None = "time",
    **kwargs,
) -> dict[str, Any]:
    """
    Read a trajectory from a CSV file and run the theta sweep pipeline on it.

    Args:
        filepath: Path to the CSV file.
        x_col: Name of the x-coordinate column.
        y_col: Name of the y-coordinate column.
        time_col: Name of the time column. When None/empty, or when the column
            is not present in the file, uniform time steps are used.
        **kwargs: Extra keyword arguments forwarded to ThetaSweepPipeline.

    Returns:
        The result of ``ThetaSweepPipeline.run(verbose=True)``.
    """
    import pandas as pd

    frame = pd.read_csv(filepath)
    positions = frame[[x_col, y_col]].values

    # The time column is optional: fall back to uniform steps when it is
    # disabled or simply absent from the file.
    has_times = bool(time_col) and time_col in frame.columns
    timestamps = frame[time_col].values if has_times else None

    return ThetaSweepPipeline(positions, timestamps, **kwargs).run(verbose=True)
-
536
-
537
def batch_process_trajectories(
    trajectory_list: list, output_base_dir: str = "batch_results", **kwargs
) -> dict[str, dict[str, Any]]:
    """
    Run the theta sweep pipeline over several trajectories, one after another.

    Each entry is either a ``(trajectory_data, times)`` tuple or a bare
    trajectory array. Each entry's outputs go to
    ``<output_base_dir>/trajectory_<NNN>``.

    Args:
        trajectory_list: Trajectories to process.
        output_base_dir: Base directory for per-trajectory results.
        **kwargs: Extra keyword arguments forwarded to ThetaSweepPipeline.

    Returns:
        Mapping from ``"trajectory_<NNN>"`` to that run's results. A failed
        run maps to ``{"error": <message>}`` instead.
    """
    outcome: dict[str, dict[str, Any]] = {}
    total = len(trajectory_list)

    for idx, entry in enumerate(trajectory_list):
        ordinal = idx + 1
        label = f"trajectory_{idx:03d}"
        print(f"\n🔄 Processing trajectory {ordinal}/{total}...")

        # Tuples carry explicit times; anything else is positions only.
        data, stamps = entry if isinstance(entry, tuple) else (entry, None)

        target_dir = Path(output_base_dir) / label

        try:
            runner = ThetaSweepPipeline(data, stamps, **kwargs)
            outcome[label] = runner.run(output_dir=str(target_dir), verbose=False)
            print(f"✅ Trajectory {ordinal} completed successfully")
        except Exception as err:
            # Best-effort batch: record the failure and move on to the next one.
            print(f"❌ Error processing trajectory {ordinal}: {err}")
            outcome[label] = {"error": str(err)}

    return outcome
File without changes