canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37):
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,6 @@
1
+ """Legacy data-analysis modules (deprecated)."""
2
+
3
+ __all__ = [
4
+ "cann1d",
5
+ "cann2d",
6
+ ]
@@ -9,7 +9,7 @@ from scipy.optimize import linear_sum_assignment
9
9
  from scipy.special import i0
10
10
  from tqdm import tqdm
11
11
 
12
- from ..visualization.core.jupyter_utils import (
12
+ from ...visualization.core.jupyter_utils import (
13
13
  display_animation_in_jupyter,
14
14
  is_jupyter_environment,
15
15
  )
@@ -45,7 +45,7 @@ except ImportError:
45
45
  from canns.data.loaders import load_roi_data
46
46
 
47
47
  # Import PlotConfig for unified plotting
48
- from ..visualization import PlotConfig
48
+ from ...visualization import PlotConfig
49
49
 
50
50
 
51
51
  # ==================== Configuration Classes ====================
@@ -29,8 +29,8 @@ from sklearn import preprocessing
29
29
  from tqdm import tqdm
30
30
 
31
31
  # Import PlotConfig for unified plotting
32
- from ..visualization import PlotConfig
33
- from ..visualization.core.jupyter_utils import (
32
+ from ...visualization import PlotConfig
33
+ from ...visualization.core.jupyter_utils import (
34
34
  display_animation_in_jupyter,
35
35
  is_jupyter_environment,
36
36
  )
@@ -2168,7 +2168,7 @@ def plot_2d_bump_on_manifold(
2168
2168
  """
2169
2169
  import matplotlib.animation as animation
2170
2170
 
2171
- from ..visualization.core.jupyter_utils import (
2171
+ from ...visualization.core.jupyter_utils import (
2172
2172
  display_animation_in_jupyter,
2173
2173
  is_jupyter_environment,
2174
2174
  )
@@ -99,7 +99,7 @@ def select_animation_backend(
99
99
  return BackendSelection(
100
100
  backend="imageio",
101
101
  supports_parallel=True,
102
- reason=f"User explicitly requested imageio backend",
102
+ reason="User explicitly requested imageio backend",
103
103
  warnings=[],
104
104
  )
105
105
 
@@ -268,6 +268,83 @@ class PlotConfigs:
268
268
  defaults.update(kwargs)
269
269
  return PlotConfig.for_animation(time_steps, **defaults)
270
270
 
271
+ @staticmethod
272
+ def cohomap(**kwargs: Any) -> PlotConfig:
273
+ defaults: dict[str, Any] = {
274
+ "title": "CohoMap",
275
+ "xlabel": "",
276
+ "ylabel": "",
277
+ "figsize": (10, 4),
278
+ }
279
+ defaults.update(kwargs)
280
+ return PlotConfig.for_static_plot(**defaults)
281
+
282
+ @staticmethod
283
+ def cohospace_trajectory(**kwargs: Any) -> PlotConfig:
284
+ defaults: dict[str, Any] = {
285
+ "title": "CohoSpace trajectory",
286
+ "xlabel": "theta1 (deg)",
287
+ "ylabel": "theta2 (deg)",
288
+ "figsize": (6, 6),
289
+ }
290
+ defaults.update(kwargs)
291
+ return PlotConfig.for_static_plot(**defaults)
292
+
293
+ @staticmethod
294
+ def cohospace_neuron(**kwargs: Any) -> PlotConfig:
295
+ defaults: dict[str, Any] = {
296
+ "title": "Neuron activity on coho-space",
297
+ "xlabel": "Theta 1 (deg)",
298
+ "ylabel": "Theta 2 (deg)",
299
+ "figsize": (6, 6),
300
+ }
301
+ defaults.update(kwargs)
302
+ return PlotConfig.for_static_plot(**defaults)
303
+
304
+ @staticmethod
305
+ def cohospace_population(**kwargs: Any) -> PlotConfig:
306
+ defaults: dict[str, Any] = {
307
+ "title": "Population activity on coho-space",
308
+ "xlabel": "Theta 1 (deg)",
309
+ "ylabel": "Theta 2 (deg)",
310
+ "figsize": (6, 6),
311
+ }
312
+ defaults.update(kwargs)
313
+ return PlotConfig.for_static_plot(**defaults)
314
+
315
+ @staticmethod
316
+ def fr_heatmap(**kwargs: Any) -> PlotConfig:
317
+ defaults: dict[str, Any] = {
318
+ "title": "Firing Rate Heatmap",
319
+ "xlabel": "Time",
320
+ "ylabel": "Neuron",
321
+ "figsize": (10, 5),
322
+ "clabel": "Value",
323
+ }
324
+ defaults.update(kwargs)
325
+ return PlotConfig.for_static_plot(**defaults)
326
+
327
+ @staticmethod
328
+ def frm(**kwargs: Any) -> PlotConfig:
329
+ defaults: dict[str, Any] = {
330
+ "title": "Firing Rate Map",
331
+ "xlabel": "X bin",
332
+ "ylabel": "Y bin",
333
+ "figsize": (6, 5),
334
+ "clabel": "Rate",
335
+ }
336
+ defaults.update(kwargs)
337
+ return PlotConfig.for_static_plot(**defaults)
338
+
339
+ @staticmethod
340
+ def path_compare(**kwargs: Any) -> PlotConfig:
341
+ defaults: dict[str, Any] = {
342
+ "title": "Path Compare",
343
+ "figsize": (12, 5),
344
+ }
345
+ defaults.update(kwargs)
346
+ return PlotConfig.for_static_plot(**defaults)
347
+
271
348
  @staticmethod
272
349
  def raster_plot(mode: str = "block", **kwargs: Any) -> PlotConfig:
273
350
  defaults: dict[str, Any] = {
@@ -325,10 +325,9 @@ def render_animation_parallel(
325
325
  ... else:
326
326
  ... print("imageio not available")
327
327
  """
328
- import os
329
328
  import multiprocessing as mp
330
- import platform
331
- from concurrent.futures import ProcessPoolExecutor
329
+ import os
330
+
332
331
  from tqdm import tqdm
333
332
 
334
333
  # Detect file format
@@ -487,7 +486,12 @@ def _render_mp4_parallel(
487
486
  if IMAGEIO_AVAILABLE:
488
487
  # Try imageio first (simpler, more reliable if ffmpeg plugin available)
489
488
  try:
490
- writer_kwargs = {"fps": fps, "codec": "libx264", "pixelformat": "yuv420p", "bitrate": "5000k"}
489
+ writer_kwargs = {
490
+ "fps": fps,
491
+ "codec": "libx264",
492
+ "pixelformat": "yuv420p",
493
+ "bitrate": "5000k",
494
+ }
491
495
  with imageio.get_writer(save_path, **writer_kwargs) as writer:
492
496
  for frame in frames:
493
497
  # Ensure RGB format
@@ -512,14 +516,14 @@ def _render_mp4_parallel(
512
516
  h, w = frames[0].shape[:2]
513
517
  fig = plt.figure(figsize=(w / 100, h / 100), dpi=100, frameon=False)
514
518
  ax = fig.add_axes([0, 0, 1, 1])
515
- ax.axis('off')
519
+ ax.axis("off")
516
520
 
517
521
  writer = FFMpegWriter(fps=fps, codec="h264", bitrate=5000)
518
522
  with writer.saving(fig, save_path, dpi=100):
519
523
  for frame in frames:
520
524
  ax.clear()
521
525
  ax.imshow(frame)
522
- ax.axis('off')
526
+ ax.axis("off")
523
527
  writer.grab_frame()
524
528
 
525
529
  plt.close(fig)
@@ -45,9 +45,10 @@ def _render_single_energy_1d_frame(
45
45
  options: _Energy1DRenderOptions,
46
46
  ) -> np.ndarray:
47
47
  """Render a single frame for 1D energy landscape animation (module-level for pickling)."""
48
+ from io import BytesIO
49
+
48
50
  import matplotlib.pyplot as plt
49
51
  import numpy as np
50
- from io import BytesIO
51
52
 
52
53
  fig, ax = plt.subplots(figsize=options.figsize)
53
54
  sim_index = options.sim_indices_to_render[frame_index]
@@ -121,9 +122,10 @@ def _render_single_energy_2d_frame(
121
122
  options: _Energy2DRenderOptions,
122
123
  ) -> np.ndarray:
123
124
  """Render a single frame for 2D energy landscape animation (module-level for pickling)."""
125
+ from io import BytesIO
126
+
124
127
  import matplotlib.pyplot as plt
125
128
  import numpy as np
126
- from io import BytesIO
127
129
 
128
130
  fig, ax = plt.subplots(figsize=options.figsize)
129
131
  sim_index = options.sim_indices_to_render[frame_index]
@@ -457,8 +459,12 @@ def energy_landscape_1d_animation(
457
459
 
458
460
  if backend == "imageio":
459
461
  # Use imageio backend with parallel rendering
460
- workers = render_workers if render_workers is not None else get_optimal_worker_count()
461
- ctx, start_method = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
462
+ workers = (
463
+ render_workers if render_workers is not None else get_optimal_worker_count()
464
+ )
465
+ ctx, start_method = get_multiprocessing_context(
466
+ prefer_fork=(render_start_method == "fork")
467
+ )
462
468
 
463
469
  # Create render options
464
470
  render_options = _Energy1DRenderOptions(
@@ -481,9 +487,10 @@ def energy_landscape_1d_animation(
481
487
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
482
488
 
483
489
  try:
484
- import imageio
485
490
  from functools import partial
486
491
 
492
+ import imageio
493
+
487
494
  # Create partial function with options
488
495
  render_func = partial(_render_single_energy_1d_frame, options=render_options)
489
496
 
@@ -519,6 +526,7 @@ def energy_landscape_1d_animation(
519
526
 
520
527
  except Exception as e:
521
528
  import warnings
529
+
522
530
  warnings.warn(
523
531
  f"imageio rendering failed: {e}. Falling back to matplotlib.",
524
532
  RuntimeWarning,
@@ -853,8 +861,12 @@ def energy_landscape_2d_animation(
853
861
 
854
862
  if backend == "imageio":
855
863
  # Use imageio backend with parallel rendering
856
- workers = render_workers if render_workers is not None else get_optimal_worker_count()
857
- ctx, start_method = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
864
+ workers = (
865
+ render_workers if render_workers is not None else get_optimal_worker_count()
866
+ )
867
+ ctx, start_method = get_multiprocessing_context(
868
+ prefer_fork=(render_start_method == "fork")
869
+ )
858
870
 
859
871
  # Create render options
860
872
  render_options = _Energy2DRenderOptions(
@@ -877,9 +889,10 @@ def energy_landscape_2d_animation(
877
889
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
878
890
 
879
891
  try:
880
- import imageio
881
892
  from functools import partial
882
893
 
894
+ import imageio
895
+
883
896
  # Create partial function with options
884
897
  render_func = partial(_render_single_energy_2d_frame, options=render_options)
885
898
 
@@ -915,6 +928,7 @@ def energy_landscape_2d_animation(
915
928
 
916
929
  except Exception as e:
917
930
  import warnings
931
+
918
932
  warnings.warn(
919
933
  f"imageio rendering failed: {e}. Falling back to matplotlib.",
920
934
  RuntimeWarning,
@@ -49,9 +49,10 @@ def _render_single_grid_tracking_frame(
49
49
  options: _GridCellTrackingRenderOptions,
50
50
  ) -> np.ndarray:
51
51
  """Render a single frame for grid cell tracking animation (module-level for pickling)."""
52
+ from io import BytesIO
53
+
52
54
  import matplotlib.pyplot as plt
53
55
  import numpy as np
54
- from io import BytesIO
55
56
 
56
57
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=options.figsize)
57
58
  sim_idx = options.sim_indices_to_render[frame_index]
@@ -84,12 +85,19 @@ def _render_single_grid_tracking_frame(
84
85
 
85
86
  # Panel 3: Rate map with position
86
87
  im = ax3.imshow(
87
- options.rate_map.T, origin="lower", cmap="hot",
88
- extent=[0, options.env_size, 0, options.env_size], aspect="auto"
88
+ options.rate_map.T,
89
+ origin="lower",
90
+ cmap="hot",
91
+ extent=[0, options.env_size, 0, options.env_size],
92
+ aspect="auto",
89
93
  )
90
94
  ax3.plot(
91
- [options.position[sim_idx, 0]], [options.position[sim_idx, 1]], "c*",
92
- markersize=15, markeredgecolor="white", markeredgewidth=1.5
95
+ [options.position[sim_idx, 0]],
96
+ [options.position[sim_idx, 1]],
97
+ "c*",
98
+ markersize=15,
99
+ markeredgecolor="white",
100
+ markeredgewidth=1.5,
93
101
  )
94
102
  ax3.set_xlabel("X Position (m)", fontsize=10)
95
103
  ax3.set_ylabel("Y Position (m)", fontsize=10)
@@ -97,7 +105,9 @@ def _render_single_grid_tracking_frame(
97
105
  plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
98
106
 
99
107
  # Overall title with time
100
- fig.suptitle(f"{options.title} | Time: {current_time_s:.2f} s", fontsize=13, fontweight="bold")
108
+ fig.suptitle(
109
+ f"{options.title} | Time: {current_time_s:.2f} s", fontsize=13, fontweight="bold"
110
+ )
101
111
 
102
112
  fig.tight_layout()
103
113
 
@@ -792,8 +802,12 @@ def create_grid_cell_tracking_animation(
792
802
 
793
803
  if backend == "imageio":
794
804
  # Use imageio backend with parallel rendering
795
- workers = render_workers if render_workers is not None else get_optimal_worker_count()
796
- ctx, start_method = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
805
+ workers = (
806
+ render_workers if render_workers is not None else get_optimal_worker_count()
807
+ )
808
+ ctx, start_method = get_multiprocessing_context(
809
+ prefer_fork=(render_start_method == "fork")
810
+ )
797
811
 
798
812
  # Create render options
799
813
  render_options = _GridCellTrackingRenderOptions(
@@ -813,11 +827,14 @@ def create_grid_cell_tracking_animation(
813
827
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
814
828
 
815
829
  try:
816
- import imageio
817
830
  from functools import partial
818
831
 
832
+ import imageio
833
+
819
834
  # Create partial function with options
820
- render_func = partial(_render_single_grid_tracking_frame, options=render_options)
835
+ render_func = partial(
836
+ _render_single_grid_tracking_frame, options=render_options
837
+ )
821
838
 
822
839
  with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
823
840
  if workers > 1 and ctx is not None:
@@ -851,6 +868,7 @@ def create_grid_cell_tracking_animation(
851
868
 
852
869
  except Exception as e:
853
870
  import warnings
871
+
854
872
  warnings.warn(
855
873
  f"imageio rendering failed: {e}. Falling back to matplotlib.",
856
874
  RuntimeWarning,
@@ -880,7 +898,9 @@ def create_grid_cell_tracking_animation(
880
898
  pbar.update(1)
881
899
 
882
900
  try:
883
- ani.save(config.save_path, writer=writer, progress_callback=progress_callback)
901
+ ani.save(
902
+ config.save_path, writer=writer, progress_callback=progress_callback
903
+ )
884
904
  print(f"Animation saved to: {config.save_path}")
885
905
  finally:
886
906
  pbar.close()
@@ -25,7 +25,6 @@ from tqdm import tqdm
25
25
 
26
26
  from .core.backend import (
27
27
  emit_backend_warnings,
28
- get_imageio_writer_kwargs,
29
28
  get_multiprocessing_context,
30
29
  get_optimal_worker_count,
31
30
  select_animation_backend,
@@ -851,9 +850,10 @@ def _render_single_place_cell_frame(
851
850
  options: _PlaceCellRenderOptions,
852
851
  ) -> np.ndarray:
853
852
  """Render a single frame for place cell animation (module-level for pickling)."""
853
+ from io import BytesIO
854
+
854
855
  import matplotlib.pyplot as plt
855
856
  import numpy as np
856
- from io import BytesIO
857
857
 
858
858
  fig, axes = plt.subplots(1, 2, figsize=options.figsize, width_ratios=[1, 1])
859
859
  ax_env, ax_activity = axes
@@ -1223,7 +1223,9 @@ def create_theta_sweep_place_cell_animation(
1223
1223
  if backend == "imageio":
1224
1224
  # Use imageio backend with parallel rendering
1225
1225
  workers = render_workers if render_workers is not None else get_optimal_worker_count()
1226
- ctx, start_method = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
1226
+ ctx, start_method = get_multiprocessing_context(
1227
+ prefer_fork=(render_start_method == "fork")
1228
+ )
1227
1229
 
1228
1230
  _emit_info(
1229
1231
  f"Parallel rendering enabled: {workers} workers (start_method={start_method})"
@@ -1232,7 +1234,11 @@ def create_theta_sweep_place_cell_animation(
1232
1234
  # Prepare environment data
1233
1235
  env = navigation_task.env
1234
1236
  boundary_array = np.array(env.boundary) if env.boundary is not None else None
1235
- walls_arrays = [np.array(wall) for wall in env.walls] if env.walls is not None and len(env.walls) > 0 else None
1237
+ walls_arrays = (
1238
+ [np.array(wall) for wall in env.walls]
1239
+ if env.walls is not None and len(env.walls) > 0
1240
+ else None
1241
+ )
1236
1242
 
1237
1243
  # Get accessible indices
1238
1244
  accessible_indices = pc_network.geodesic_result.accessible_indices
@@ -1257,11 +1263,14 @@ def create_theta_sweep_place_cell_animation(
1257
1263
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
1258
1264
 
1259
1265
  try:
1260
- import imageio
1261
1266
  from functools import partial
1262
1267
 
1268
+ import imageio
1269
+
1263
1270
  # Create partial function with data and options
1264
- render_func = partial(_render_single_place_cell_frame, data=data, options=render_options)
1271
+ render_func = partial(
1272
+ _render_single_place_cell_frame, data=data, options=render_options
1273
+ )
1265
1274
 
1266
1275
  with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
1267
1276
  if workers > 1 and ctx is not None:
@@ -7,15 +7,11 @@ the underlying implementations.
7
7
  """
8
8
 
9
9
  from ._base import Pipeline
10
- from .theta_sweep import (
11
- ThetaSweepPipeline,
12
- batch_process_trajectories,
13
- load_trajectory_from_csv,
14
- )
10
+ from .asa import ASAApp
11
+ from .asa import main as asa_main
15
12
 
16
13
  __all__ = [
17
14
  "Pipeline",
18
- "ThetaSweepPipeline",
19
- "load_trajectory_from_csv",
20
- "batch_process_trajectories",
15
+ "ASAApp",
16
+ "asa_main",
21
17
  ]
@@ -0,0 +1,21 @@
1
+ """ASA TUI - Terminal User Interface for ASA Analysis.
2
+
3
+ This module provides a Textual-based TUI for running ASA (Attractor State Analysis)
4
+ with 7 analysis modules: TDA, CohoMap, PathCompare, CohoSpace, FR, FRM, and GridScore.
5
+ """
6
+
7
+ import os
8
+
9
+ __all__ = ["ASAApp", "main"]
10
+
11
+
12
+ def main():
13
+ """Entry point for canns-tui command."""
14
+ os.environ.setdefault("MPLBACKEND", "Agg")
15
+ from .app import ASAApp
16
+
17
+ app = ASAApp()
18
+ app.run()
19
+
20
+
21
+ from .app import ASAApp
@@ -0,0 +1,11 @@
1
+ """Main entry point for running ASA TUI as a module."""
2
+
3
+ import os
4
+
5
+ os.environ.setdefault("MPLBACKEND", "Agg")
6
+
7
+ from .app import ASAApp
8
+
9
+ if __name__ == "__main__":
10
+ app = ASAApp()
11
+ app.run()