canns 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,12 @@ from .config import (
18
18
  )
19
19
  from .decode import decode_circular_coordinates, decode_circular_coordinates_multi
20
20
  from .embedding import embed_spike_trains
21
+ from .fly_roi import (
22
+ BumpFitsConfig,
23
+ CANN1DPlotConfig,
24
+ create_1d_bump_animation,
25
+ roi_bump_fits,
26
+ )
21
27
  from .fr import (
22
28
  FRMResult,
23
29
  compute_fr_heatmap_matrix,
@@ -60,6 +66,10 @@ __all__ = [
60
66
  "plot_cohomap_multi",
61
67
  "plot_3d_bump_on_torus",
62
68
  "plot_2d_bump_on_manifold",
69
+ "BumpFitsConfig",
70
+ "CANN1DPlotConfig",
71
+ "create_1d_bump_animation",
72
+ "roi_bump_fits",
63
73
  "compute_fr_heatmap_matrix",
64
74
  "save_fr_heatmap_png",
65
75
  "FRMResult",
@@ -33,7 +33,7 @@ def decode_circular_coordinates(
33
33
  real_of : bool
34
34
  Whether the experiment is open-field (controls box coordinate handling).
35
35
  save_path : str, optional
36
- Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
36
+ Path to save decoding results. If ``None``, results are not saved.
37
37
 
38
38
  Returns
39
39
  -------
@@ -174,13 +174,12 @@ def decode_circular_coordinates(
174
174
  "centsinall": centsinall,
175
175
  }
176
176
 
177
- # Save results
178
- if save_path is None:
179
- os.makedirs("Results", exist_ok=True)
180
- save_path = "Results/spikes_decoding.npz"
181
-
182
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
183
- np.savez_compressed(save_path, **results)
177
+ # Save results (only when requested)
178
+ if save_path is not None:
179
+ save_dir = os.path.dirname(save_path)
180
+ if save_dir:
181
+ os.makedirs(save_dir, exist_ok=True)
182
+ np.savez_compressed(save_path, **results)
184
183
 
185
184
  return results
186
185
 
@@ -264,13 +263,12 @@ def decode_circular_coordinates1(
264
263
  "centsinall": centsinall,
265
264
  }
266
265
 
267
- # Save results
268
- if save_path is None:
269
- os.makedirs("Results", exist_ok=True)
270
- save_path = "Results/spikes_decoding.npz"
271
-
272
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
273
- np.savez_compressed(save_path, **results)
266
+ # Save results (only when requested)
267
+ if save_path is not None:
268
+ save_dir = os.path.dirname(save_path)
269
+ if save_dir:
270
+ os.makedirs(save_dir, exist_ok=True)
271
+ np.savez_compressed(save_path, **results)
274
272
 
275
273
  return results
276
274
 
@@ -291,7 +289,7 @@ def decode_circular_coordinates_multi(
291
289
  spike_data : dict
292
290
  Spike data dictionary containing ``'spike'``, ``'t'`` and optionally ``'x'``/``'y'``.
293
291
  save_path : str, optional
294
- Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
292
+ Path to save decoding results. If ``None``, results are not saved.
295
293
  num_circ : int
296
294
  Number of H1 cocycles/circular coordinates to decode.
297
295
 
@@ -383,11 +381,10 @@ def decode_circular_coordinates_multi(
383
381
  "centsinall": centsinall,
384
382
  }
385
383
 
386
- if save_path is None:
387
- os.makedirs("Results", exist_ok=True)
388
- save_path = "Results/spikes_decoding.npz"
389
-
390
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
384
+ if save_path is not None:
385
+ save_dir = os.path.dirname(save_path)
386
+ if save_dir:
387
+ os.makedirs(save_dir, exist_ok=True)
391
388
  np.savez_compressed(save_path, **results)
392
389
  return results
393
390
 
@@ -4,7 +4,7 @@ from dataclasses import dataclass
4
4
 
5
5
  import numpy as np
6
6
  from matplotlib import pyplot as plt
7
- from matplotlib.animation import FuncAnimation, PillowWriter
7
+ from matplotlib.animation import FuncAnimation
8
8
  from scipy.optimize import linear_sum_assignment
9
9
  from scipy.special import i0
10
10
  from tqdm import tqdm
@@ -66,22 +66,6 @@ class BumpFitsConfig:
66
66
  random_seed: int | None = None
67
67
 
68
68
 
69
- @dataclass
70
- class AnimationConfig:
71
- """Configuration for 1D CANN bump animation."""
72
-
73
- show: bool = False
74
- max_height_value: float = 0.5
75
- max_width_range: int = 40
76
- npoints: int = 300
77
- nframes: int | None = None
78
- fps: int = 5
79
- bump_selection: str = "strongest"
80
- show_progress_bar: bool = True
81
- repeat: bool = False
82
- title: str = "1D CANN Bump Animation"
83
-
84
-
85
69
  @dataclass
86
70
  class CANN1DPlotConfig(PlotConfig):
87
71
  """Specialized PlotConfig for CANN1D visualizations."""
@@ -141,9 +125,9 @@ class AnimationError(CANN1DError):
141
125
  pass
142
126
 
143
127
 
144
- def bump_fits(data, config: BumpFitsConfig | None = None, save_path=None, **kwargs):
128
+ def roi_bump_fits(data, config: BumpFitsConfig | None = None, save_path=None, **kwargs):
145
129
  """
146
- Fit CANN1D bumps to data using MCMC optimization.
130
+ Fit CANN1D bumps to ROI data using MCMC optimization.
147
131
 
148
132
  Parameters:
149
133
  data : numpy.ndarray
@@ -318,10 +302,10 @@ def create_1d_bump_animation(
318
302
  Parameters:
319
303
  fits_data : numpy.ndarray
320
304
  Shape (n_fits, 4) array with columns [time, position, amplitude, kappa]
321
- config : AnimationConfig, optional
305
+ config : CANN1DPlotConfig, optional
322
306
  Configuration object with all animation parameters
323
307
  save_path : str, optional
324
- Output path for the generated GIF
308
+ Output path for the generated animation (e.g. .gif or .mp4)
325
309
  **kwargs : backward compatibility parameters
326
310
 
327
311
  Returns:
@@ -336,6 +320,8 @@ def create_1d_bump_animation(
336
320
  for key, value in kwargs.items():
337
321
  if hasattr(config, key):
338
322
  setattr(config, key, value)
323
+ if save_path is not None:
324
+ config.save_path = save_path
339
325
 
340
326
  try:
341
327
  # ==== Smoothing functions ====
@@ -520,7 +506,10 @@ def create_1d_bump_animation(
520
506
  fig, update, frames=nframes, init_func=init, blit=use_blitting, repeat=config.repeat
521
507
  )
522
508
 
523
- # Save animation with progress tracking
509
+ ani = None
510
+ progress_bar_enabled = getattr(config, "show_progress_bar", True)
511
+
512
+ # Save animation with unified backend selection
524
513
  if config.save_path:
525
514
  # Warn if both saving and showing (causes double rendering)
526
515
  if config.show and nframes > 50:
@@ -528,31 +517,95 @@ def create_1d_bump_animation(
528
517
 
529
518
  warn_double_rendering(nframes, config.save_path, stacklevel=2)
530
519
 
531
- if config.show_progress_bar:
532
- pbar = tqdm(total=nframes, desc=f"Saving animation to {config.save_path}")
520
+ from ...visualization.core import (
521
+ emit_backend_warnings,
522
+ get_imageio_writer_kwargs,
523
+ get_matplotlib_writer,
524
+ select_animation_backend,
525
+ )
533
526
 
534
- def progress_callback(current_frame, total_frames):
535
- pbar.update(1)
527
+ backend_selection = select_animation_backend(
528
+ save_path=config.save_path,
529
+ requested_backend=getattr(config, "render_backend", None),
530
+ check_imageio_plugins=True,
531
+ )
532
+ emit_backend_warnings(backend_selection.warnings, stacklevel=2)
533
+ backend = backend_selection.backend
536
534
 
535
+ if backend == "imageio":
537
536
  try:
538
- ani.save(
539
- config.save_path,
540
- writer=PillowWriter(fps=config.fps),
541
- progress_callback=progress_callback,
542
- )
543
- pbar.close()
544
- print(f"\nAnimation successfully saved to: {config.save_path}")
545
- except Exception as e:
546
- pbar.close()
547
- print(f"\nError saving animation: {e}")
548
- raise
549
- else:
550
- try:
551
- ani.save(config.save_path, writer=PillowWriter(fps=config.fps))
537
+ import imageio
538
+
539
+ writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
540
+ with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
541
+ frames_iter = range(nframes)
542
+ if progress_bar_enabled:
543
+ frames_iter = tqdm(
544
+ frames_iter,
545
+ desc=f"Rendering {config.save_path}",
546
+ )
547
+
548
+ init()
549
+ for frame_idx in frames_iter:
550
+ update(frame_idx)
551
+ fig.canvas.draw()
552
+ frame = np.asarray(fig.canvas.buffer_rgba())
553
+ if frame.shape[-1] == 4:
554
+ frame = frame[:, :, :3]
555
+ writer.append_data(frame)
556
+
552
557
  print(f"Animation saved to: {config.save_path}")
553
558
  except Exception as e:
554
- print(f"Error saving animation: {e}")
555
- raise
559
+ import warnings
560
+
561
+ warnings.warn(
562
+ f"imageio rendering failed: {e}. Falling back to matplotlib.",
563
+ RuntimeWarning,
564
+ stacklevel=2,
565
+ )
566
+ backend = "matplotlib"
567
+
568
+ if backend == "matplotlib":
569
+ ani = FuncAnimation(
570
+ fig,
571
+ update,
572
+ frames=nframes,
573
+ init_func=init,
574
+ blit=use_blitting,
575
+ repeat=config.repeat,
576
+ )
577
+
578
+ writer = get_matplotlib_writer(config.save_path, fps=config.fps)
579
+
580
+ if progress_bar_enabled:
581
+ pbar = tqdm(total=nframes, desc=f"Saving to {config.save_path}")
582
+
583
+ def progress_callback(current_frame, total_frames):
584
+ pbar.update(1)
585
+
586
+ try:
587
+ ani.save(
588
+ config.save_path,
589
+ writer=writer,
590
+ progress_callback=progress_callback,
591
+ )
592
+ print(f"Animation saved to: {config.save_path}")
593
+ finally:
594
+ pbar.close()
595
+ else:
596
+ ani.save(config.save_path, writer=writer)
597
+ print(f"Animation saved to: {config.save_path}")
598
+
599
+ # Create animation object for showing (if not already created)
600
+ if config.show and ani is None:
601
+ ani = FuncAnimation(
602
+ fig,
603
+ update,
604
+ frames=nframes,
605
+ init_func=init,
606
+ blit=use_blitting,
607
+ repeat=config.repeat,
608
+ )
556
609
 
557
610
  if config.show:
558
611
  # Automatically detect Jupyter and display as HTML/JS
@@ -1043,7 +1096,7 @@ def _mcmc(
1043
1096
 
1044
1097
  if __name__ == "__main__":
1045
1098
  data = load_roi_data()
1046
- bumps, fits, nbump, centrbump = bump_fits(
1099
+ bumps, fits, nbump, centrbump = roi_bump_fits(
1047
1100
  data, n_steps=5000, n_roi=16, save_path=os.path.join(os.getcwd(), "test.npz")
1048
1101
  )
1049
1102
 
@@ -135,7 +135,8 @@ def save_fr_heatmap_png(
135
135
  Plot configuration. Use ``config.save_path`` to specify output file.
136
136
  **kwargs : Any
137
137
  Additional ``imshow`` keyword arguments. ``save_path`` may be provided here
138
- as a fallback if not set in ``config``.
138
+ as a fallback if not set in ``config``. If ``save_path`` is omitted, the
139
+ figure is only displayed when ``show=True``.
139
140
 
140
141
  Notes
141
142
  -----
@@ -172,11 +173,6 @@ def save_fr_heatmap_png(
172
173
  if not config.ylabel:
173
174
  config.ylabel = ylabel
174
175
 
175
- if config.save_path is None:
176
- raise ValueError(
177
- "save_path must be provided via config.save_path or as a keyword argument."
178
- )
179
-
180
176
  config.save_dpi = dpi
181
177
 
182
178
  M = np.asarray(M)
@@ -391,7 +387,8 @@ def plot_frm(
391
387
  Plot configuration. Use ``config.save_path`` to specify output file.
392
388
  **kwargs : Any
393
389
  Additional ``imshow`` keyword arguments. ``save_path`` may be provided here
394
- as a fallback if not set in ``config``.
390
+ as a fallback if not set in ``config``. If ``save_path`` is omitted, the
391
+ figure is only displayed when ``show=True``.
395
392
 
396
393
  Examples
397
394
  --------
@@ -423,11 +420,6 @@ def plot_frm(
423
420
  if not config.ylabel:
424
421
  config.ylabel = "Y bin"
425
422
 
426
- if config.save_path is None:
427
- raise ValueError(
428
- "save_path must be provided via config.save_path or as a keyword argument."
429
- )
430
-
431
423
  config.save_dpi = dpi
432
424
 
433
425
  frm = np.asarray(frm)
@@ -295,8 +295,19 @@ def plot_path_compare(
295
295
  ax0 = axes[0]
296
296
  ax0.set_title("Physical path (x,y)")
297
297
  ax0.set_aspect("equal", "box")
298
- ax0.axis("off")
299
298
  ax0.plot(x, y, lw=0.9, alpha=0.8)
299
+ # Keep a visible frame while hiding ticks for a clean path outline.
300
+ ax0.set_xticks([])
301
+ ax0.set_yticks([])
302
+ for spine in ax0.spines.values():
303
+ spine.set_visible(True)
304
+ # Add a small padding so the frame doesn't touch the path.
305
+ x_min, x_max = np.min(x), np.max(x)
306
+ y_min, y_max = np.min(y), np.max(y)
307
+ pad_x = (x_max - x_min) * 0.03 if x_max > x_min else 1.0
308
+ pad_y = (y_max - y_min) * 0.03 if y_max > y_min else 1.0
309
+ ax0.set_xlim(x_min - pad_x, x_max + pad_x)
310
+ ax0.set_ylim(y_min - pad_y, y_max + pad_y)
300
311
 
301
312
  ax1 = axes[1]
302
313
  ax1.set_title("Decoded coho path")
@@ -3,10 +3,10 @@
3
3
  This module provides reusable UI components for the ASA analysis interface.
4
4
  """
5
5
 
6
- from pathlib import Path
7
6
  import os
8
7
  import subprocess
9
8
  import sys
9
+ from pathlib import Path
10
10
 
11
11
  from rich.ansi import AnsiDecoder
12
12
  from rich.text import Text
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: canns
3
- Version: 0.13.0
3
+ Version: 0.13.1
4
4
  Summary: A Python Library for Continuous Attractor Neural Networks
5
5
  Project-URL: Repository, https://github.com/routhleck/canns
6
6
  Author-email: Sichao He <sichaohe@outlook.com>
@@ -3,19 +3,17 @@ canns/_version.py,sha256=zIvJPOGBFvo4VV6f586rlO_bvhuFp1fsxjf6xhsqkJY,1547
3
3
  canns/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  canns/analyzer/__init__.py,sha256=EQ02fYHkpMADp-ojpVCVtapuSPkl6j5WVfdPy0mOTs4,506
5
5
  canns/analyzer/data/__init__.py,sha256=5LJ8Q4SyDpi6CAzyW3lSpra3I0BgG20kmQMceONw7A4,158
6
- canns/analyzer/data/asa/__init__.py,sha256=MNsWjFFWFbnDbYPgvF43Kj4Tv8pPpqmdJnlNbK2_fbI,1670
6
+ canns/analyzer/data/asa/__init__.py,sha256=vfZCggZ_FvXwlYXLKErc_eltN6nOQiHWIivGYqpt8NM,1885
7
7
  canns/analyzer/data/asa/cohospace.py,sha256=n6DfPWUH_k67KrPbNMt1EUdxMuRFFmXCaqh4gcuXngM,27847
8
8
  canns/analyzer/data/asa/config.py,sha256=zVs-snLkD93hGc97GEX9CkcjTtUqORvMCs3khHhqd64,6288
9
- canns/analyzer/data/asa/decode.py,sha256=5haC6WclPIDQgPtnDfjKoWFbDwJVyaBzTEiRbdPnbyU,16579
9
+ canns/analyzer/data/asa/decode.py,sha256=NG8vVx2cPG7uSJDovnC2vzk0dsqU8oR4jaNPxxrvCc0,16501
10
10
  canns/analyzer/data/asa/embedding.py,sha256=rT3hOHk6BzHF1X5YznIFxE_dDveklbT0fx9GtA2w3z0,9550
11
11
  canns/analyzer/data/asa/filters.py,sha256=D-1mDVn4hBEAphKUgx1gQEUfgbghKcNQhZmr4xEExQA,7146
12
- canns/analyzer/data/asa/fr.py,sha256=VRvruZl5NBfK7QEsRLsqHSftelHFXghzJfBU1nov-YQ,13593
12
+ canns/analyzer/data/asa/fly_roi.py,sha256=_scBOd-4t9yv_1tHk7wbXJwPieU-L-QtFJY6fhHpxDI,38031
13
+ canns/analyzer/data/asa/fr.py,sha256=jt99H50e1RRAQgMIdkfK0rBbembZJEr9SMrxK-ZI_LA,13449
13
14
  canns/analyzer/data/asa/path.py,sha256=p3r8EGcJi8NewNFutr3kdO-ekdU_5Icpy6nAoELzSq4,12233
14
- canns/analyzer/data/asa/plotting.py,sha256=7EfVSUERju6lggXyZ8U_BkJznaoTgwY2w3aVfyoy5rs,40325
15
+ canns/analyzer/data/asa/plotting.py,sha256=xuKRuq12pITK0BOC4Bj3ZfJs-07DA2OUHCpJmtJLOnw,40852
15
16
  canns/analyzer/data/asa/tda.py,sha256=Hn110SBN4wKgHmcwPRkxHdbQ5DFW2xLdSv26J37_Zxc,30342
16
- canns/analyzer/data/legacy/__init__.py,sha256=JNFKxWBlT2p4Nf70Ddby67-4UYNaVS7lQFAy5wXEE-8,92
17
- canns/analyzer/data/legacy/cann1d.py,sha256=QN4vbYXrfcxAVbPhaYiXGoP5QzTc6hWJJGo5ETP_RSU,35793
18
- canns/analyzer/data/legacy/cann2d.py,sha256=RaibP0byPKK9qGefQme0-uZLwqTKOFGBL3VfmbRZOjk,86659
19
17
  canns/analyzer/metrics/__init__.py,sha256=WaTCJE2WhqPtZXYTtMilF_f0LZUfz2h9a25HzwmCwP8,132
20
18
  canns/analyzer/metrics/spatial_metrics.py,sha256=ZdS7tGH3lMgNlSYoHlH8IPsCw4XL0KtEnGALg_WHRX8,8965
21
19
  canns/analyzer/metrics/systematic_ratemap.py,sha256=MzXfa6_fGgrxD5xEd4hfrZR_fyUmhXQCxnPE3hUonE8,14004
@@ -65,7 +63,7 @@ canns/pipeline/asa/runner.py,sha256=PfHXlI-m3m-IYVFcFRhSODfPoRlCrloDOqEftvAfasg,
65
63
  canns/pipeline/asa/screens.py,sha256=DbqidxmoKe4KzSLuxuriVv1PIVFn5Z-PfScVfjrIiEA,5954
66
64
  canns/pipeline/asa/state.py,sha256=XukidfcFIOmm9ttT226FOTYR5hv2VAY8_DZt7V1Ml2g,6955
67
65
  canns/pipeline/asa/styles.tcss,sha256=eaXI3rQeWdBYmWdLJMMiSO6acHtreLRVKKoIHb2-dBk,3349
68
- canns/pipeline/asa/widgets.py,sha256=9pO6fJ13sQUm4vxETdSpYDXDA7z0mi3yFZO_RDakxpY,8087
66
+ canns/pipeline/asa/widgets.py,sha256=3vPGGQWP9V5FwuwqykCVp7dzAHdpcFkDqib0DtIw-lQ,8087
69
67
  canns/pipeline/gallery/__init__.py,sha256=PPOvxmTRzEnj33jHlsFlaWuEfhrcNe39pMPQkTisjlo,187
70
68
  canns/task/__init__.py,sha256=sfo8TBBVAqkx73Nu5lVv77UCwZjKqjt042ezW4Wv2Ec,350
71
69
  canns/task/_base.py,sha256=rdRy4mr6x53z4aJv04UJcpqCyv78AxB1k1DhxnriLpg,4401
@@ -84,8 +82,8 @@ canns/trainer/utils.py,sha256=ZdoLiRqFLfKXsWi0KX3wGUp0OqFikwiou8dPf3xvFhE,2847
84
82
  canns/typing/__init__.py,sha256=mXySdfmD8fA56WqZTb1Nj-ZovcejwLzNjuk6PRfTwmA,156
85
83
  canns/utils/__init__.py,sha256=OMyZ5jqZAIUS2Jr0qcnvvrx6YM-BZ1EJy5uZYeA3HC0,366
86
84
  canns/utils/benchmark.py,sha256=oJ7nvbvnQMh4_MZh7z160NPLp-197X0rEnmnLHYlev4,1361
87
- canns-0.13.0.dist-info/METADATA,sha256=E-6ivcs7Yrp4Hffey9nfIQttDe6dzWwAbg-sT09auAg,8827
88
- canns-0.13.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
89
- canns-0.13.0.dist-info/entry_points.txt,sha256=HDzYp0e9E1wfKYRkYtnUgVKK_u33_7eIn9exgm7t-wg,75
90
- canns-0.13.0.dist-info/licenses/LICENSE,sha256=u6NJ1N-QSnf5yTwSk5UvFAdU2yKD0jxG0Xa91n1cPO4,11306
91
- canns-0.13.0.dist-info/RECORD,,
85
+ canns-0.13.1.dist-info/METADATA,sha256=4HIOm94szk2GaS6A9Kbp9nZHXfoqVvNj_rOBdrExZWE,8827
86
+ canns-0.13.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
87
+ canns-0.13.1.dist-info/entry_points.txt,sha256=HDzYp0e9E1wfKYRkYtnUgVKK_u33_7eIn9exgm7t-wg,75
88
+ canns-0.13.1.dist-info/licenses/LICENSE,sha256=u6NJ1N-QSnf5yTwSk5UvFAdU2yKD0jxG0Xa91n1cPO4,11306
89
+ canns-0.13.1.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- """Legacy data-analysis modules (deprecated)."""
2
-
3
- __all__ = [
4
- "cann1d",
5
- "cann2d",
6
- ]