canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/__init__.py +39 -3
- canns/analyzer/__init__.py +7 -6
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/metrics/spatial_metrics.py +70 -100
- canns/analyzer/metrics/systematic_ratemap.py +12 -17
- canns/analyzer/metrics/utils.py +28 -0
- canns/analyzer/model_specific/hopfield.py +19 -16
- canns/analyzer/slow_points/checkpoint.py +32 -9
- canns/analyzer/slow_points/finder.py +33 -6
- canns/analyzer/slow_points/fixed_points.py +12 -0
- canns/analyzer/slow_points/visualization.py +22 -10
- canns/analyzer/visualization/core/backend.py +15 -26
- canns/analyzer/visualization/core/config.py +120 -15
- canns/analyzer/visualization/core/jupyter_utils.py +34 -16
- canns/analyzer/visualization/core/rendering.py +42 -40
- canns/analyzer/visualization/core/writers.py +10 -20
- canns/analyzer/visualization/energy_plots.py +78 -28
- canns/analyzer/visualization/spatial_plots.py +81 -36
- canns/analyzer/visualization/spike_plots.py +27 -7
- canns/analyzer/visualization/theta_sweep_plots.py +159 -72
- canns/analyzer/visualization/tuning_plots.py +11 -3
- canns/data/__init__.py +7 -4
- canns/models/__init__.py +10 -0
- canns/models/basic/cann.py +102 -40
- canns/models/basic/grid_cell.py +9 -8
- canns/models/basic/hierarchical_model.py +57 -11
- canns/models/brain_inspired/hopfield.py +26 -14
- canns/models/brain_inspired/linear.py +15 -16
- canns/models/brain_inspired/spiking.py +23 -12
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/closed_loop_navigation.py +54 -13
- canns/task/open_loop_navigation.py +230 -147
- canns/task/tracking.py +156 -24
- canns/trainer/__init__.py +8 -5
- canns/utils/__init__.py +12 -4
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- canns-0.13.0.dist-info/RECORD +91 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- canns-0.12.6.dist-info/RECORD +0 -72
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -43,32 +43,21 @@ def select_animation_backend(
|
|
|
43
43
|
requested_backend: str | None = None,
|
|
44
44
|
check_imageio_plugins: bool = True,
|
|
45
45
|
) -> BackendSelection:
|
|
46
|
-
"""
|
|
47
|
-
Select the optimal animation rendering backend.
|
|
48
|
-
|
|
49
|
-
This function implements smart backend selection logic:
|
|
50
|
-
1. If user explicitly requests a backend, validate and use it
|
|
51
|
-
2. Otherwise, auto-select based on file format and available dependencies
|
|
52
|
-
3. For GIF: prefer imageio (parallel rendering)
|
|
53
|
-
4. For MP4: prefer imageio if plugins available, else matplotlib
|
|
54
|
-
5. Always fallback gracefully with helpful warnings
|
|
46
|
+
"""Select the optimal animation rendering backend.
|
|
55
47
|
|
|
56
48
|
Args:
|
|
57
|
-
save_path: Output file path (determines format)
|
|
58
|
-
requested_backend:
|
|
59
|
-
check_imageio_plugins: Whether to verify imageio can write the format
|
|
49
|
+
save_path: Output file path (determines format).
|
|
50
|
+
requested_backend: Backend preference ('imageio', 'matplotlib', 'auto', or None).
|
|
51
|
+
check_imageio_plugins: Whether to verify imageio can write the format.
|
|
60
52
|
|
|
61
53
|
Returns:
|
|
62
|
-
BackendSelection with backend choice and metadata
|
|
54
|
+
BackendSelection with backend choice and metadata.
|
|
63
55
|
|
|
64
|
-
|
|
56
|
+
Examples:
|
|
57
|
+
>>> from canns.analyzer.visualization.core.backend import select_animation_backend
|
|
65
58
|
>>> selection = select_animation_backend("output.mp4")
|
|
66
|
-
>>> print(
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
>>> selection = select_animation_backend("output.gif", "matplotlib")
|
|
70
|
-
>>> print(selection.warnings)
|
|
71
|
-
['Consider using imageio backend for faster GIF rendering']
|
|
59
|
+
>>> print(selection.backend in {"imageio", "matplotlib"})
|
|
60
|
+
True
|
|
72
61
|
"""
|
|
73
62
|
warnings_list = []
|
|
74
63
|
|
|
@@ -110,7 +99,7 @@ def select_animation_backend(
|
|
|
110
99
|
return BackendSelection(
|
|
111
100
|
backend="imageio",
|
|
112
101
|
supports_parallel=True,
|
|
113
|
-
reason=
|
|
102
|
+
reason="User explicitly requested imageio backend",
|
|
114
103
|
warnings=[],
|
|
115
104
|
)
|
|
116
105
|
|
|
@@ -277,7 +266,7 @@ def get_multiprocessing_context(prefer_fork: bool = False):
|
|
|
277
266
|
prefer_fork: Whether to prefer 'fork' over 'spawn' (Linux only)
|
|
278
267
|
|
|
279
268
|
Returns:
|
|
280
|
-
|
|
269
|
+
Tuple of (multiprocessing context, method name) or (None, None) if unavailable
|
|
281
270
|
"""
|
|
282
271
|
import multiprocessing as mp
|
|
283
272
|
|
|
@@ -293,16 +282,16 @@ def get_multiprocessing_context(prefer_fork: bool = False):
|
|
|
293
282
|
RuntimeWarning,
|
|
294
283
|
stacklevel=3,
|
|
295
284
|
)
|
|
296
|
-
return mp.get_context("spawn")
|
|
297
|
-
return mp.get_context("fork")
|
|
285
|
+
return mp.get_context("spawn"), "spawn"
|
|
286
|
+
return mp.get_context("fork"), "fork"
|
|
298
287
|
except (RuntimeError, ValueError):
|
|
299
288
|
pass
|
|
300
289
|
|
|
301
290
|
# Default to spawn (works everywhere)
|
|
302
291
|
try:
|
|
303
|
-
return mp.get_context("spawn")
|
|
292
|
+
return mp.get_context("spawn"), "spawn"
|
|
304
293
|
except (RuntimeError, ValueError):
|
|
305
|
-
return None
|
|
294
|
+
return None, None
|
|
306
295
|
|
|
307
296
|
|
|
308
297
|
def emit_backend_warnings(warnings_list: list[str], stacklevel: int = 2):
|
|
@@ -13,13 +13,19 @@ __all__ = ["PlotConfig", "PlotConfigs", "AnimationConfig", "finalize_figure"]
|
|
|
13
13
|
|
|
14
14
|
@dataclass
|
|
15
15
|
class PlotConfig:
|
|
16
|
-
"""Unified configuration class for
|
|
16
|
+
"""Unified configuration class for plotting helpers in ``canns.analyzer``.
|
|
17
17
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
18
|
+
Examples:
|
|
19
|
+
>>> import numpy as np
|
|
20
|
+
>>> from canns.analyzer.visualization import PlotConfig, energy_landscape_1d_static
|
|
21
|
+
>>>
|
|
22
|
+
>>> # Dummy input (matches test-style energy_landscape usage)
|
|
23
|
+
>>> x = np.linspace(0, 1, 5)
|
|
24
|
+
>>> data_sets = {"u": (x, np.sin(x))}
|
|
25
|
+
>>> config = PlotConfig(title="Demo", show=False)
|
|
26
|
+
>>> fig, ax = energy_landscape_1d_static(data_sets, config=config)
|
|
27
|
+
>>> print(fig is not None)
|
|
28
|
+
True
|
|
23
29
|
"""
|
|
24
30
|
|
|
25
31
|
title: str = ""
|
|
@@ -109,6 +115,21 @@ def finalize_figure(
|
|
|
109
115
|
rasterize_artists: Optional list of artists to rasterize before saving.
|
|
110
116
|
savefig_kwargs: Extra kwargs merged into ``savefig`` (wins over config).
|
|
111
117
|
always_close: If True, close the figure even when ``config.show`` is True.
|
|
118
|
+
|
|
119
|
+
Examples:
|
|
120
|
+
>>> import numpy as np
|
|
121
|
+
>>> from matplotlib import pyplot as plt
|
|
122
|
+
>>> from canns.analyzer.visualization import PlotConfig
|
|
123
|
+
>>> from canns.analyzer.visualization.core.config import finalize_figure
|
|
124
|
+
>>>
|
|
125
|
+
>>> x = np.linspace(0, 1, 5)
|
|
126
|
+
>>> y = np.sin(x)
|
|
127
|
+
>>> fig, ax = plt.subplots()
|
|
128
|
+
>>> _ = ax.plot(x, y)
|
|
129
|
+
>>> config = PlotConfig(title="Finalize Demo", show=False)
|
|
130
|
+
>>> finalized = finalize_figure(fig, config)
|
|
131
|
+
>>> print(finalized is not None)
|
|
132
|
+
True
|
|
112
133
|
"""
|
|
113
134
|
|
|
114
135
|
from matplotlib import pyplot as plt
|
|
@@ -156,14 +177,13 @@ class AnimationConfig:
|
|
|
156
177
|
more than this many frames
|
|
157
178
|
|
|
158
179
|
Example:
|
|
159
|
-
>>>
|
|
160
|
-
>>> config = AnimationConfig(fps=30, quality='high')
|
|
161
|
-
>>>
|
|
162
|
-
>>> # Fast draft mode for quick iteration
|
|
163
|
-
>>> draft_config = AnimationConfig(quality='draft') # Auto: 15 FPS, 0.5x resolution
|
|
180
|
+
>>> from canns.analyzer.visualization import AnimationConfig
|
|
164
181
|
>>>
|
|
165
|
-
>>> #
|
|
166
|
-
>>>
|
|
182
|
+
>>> # Dummy input representing total frames
|
|
183
|
+
>>> total_frames = 120
|
|
184
|
+
>>> config = AnimationConfig(fps=30, quality="high")
|
|
185
|
+
>>> print(config.fps, total_frames)
|
|
186
|
+
30 120
|
|
167
187
|
"""
|
|
168
188
|
|
|
169
189
|
fps: int = 30
|
|
@@ -186,8 +206,16 @@ class AnimationConfig:
|
|
|
186
206
|
class PlotConfigs:
|
|
187
207
|
"""Collection of commonly used plot configurations.
|
|
188
208
|
|
|
189
|
-
|
|
190
|
-
|
|
209
|
+
Examples:
|
|
210
|
+
>>> import numpy as np
|
|
211
|
+
>>> from canns.analyzer.visualization import PlotConfigs, energy_landscape_1d_static
|
|
212
|
+
>>>
|
|
213
|
+
>>> x = np.linspace(0, 1, 5)
|
|
214
|
+
>>> data_sets = {"u": (x, np.sin(x))}
|
|
215
|
+
>>> config = PlotConfigs.energy_landscape_1d_static(show=False)
|
|
216
|
+
>>> fig, ax = energy_landscape_1d_static(data_sets, config=config)
|
|
217
|
+
>>> print(fig is not None)
|
|
218
|
+
True
|
|
191
219
|
"""
|
|
192
220
|
|
|
193
221
|
@staticmethod
|
|
@@ -240,6 +268,83 @@ class PlotConfigs:
|
|
|
240
268
|
defaults.update(kwargs)
|
|
241
269
|
return PlotConfig.for_animation(time_steps, **defaults)
|
|
242
270
|
|
|
271
|
+
@staticmethod
|
|
272
|
+
def cohomap(**kwargs: Any) -> PlotConfig:
|
|
273
|
+
defaults: dict[str, Any] = {
|
|
274
|
+
"title": "CohoMap",
|
|
275
|
+
"xlabel": "",
|
|
276
|
+
"ylabel": "",
|
|
277
|
+
"figsize": (10, 4),
|
|
278
|
+
}
|
|
279
|
+
defaults.update(kwargs)
|
|
280
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
281
|
+
|
|
282
|
+
@staticmethod
|
|
283
|
+
def cohospace_trajectory(**kwargs: Any) -> PlotConfig:
|
|
284
|
+
defaults: dict[str, Any] = {
|
|
285
|
+
"title": "CohoSpace trajectory",
|
|
286
|
+
"xlabel": "theta1 (deg)",
|
|
287
|
+
"ylabel": "theta2 (deg)",
|
|
288
|
+
"figsize": (6, 6),
|
|
289
|
+
}
|
|
290
|
+
defaults.update(kwargs)
|
|
291
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
292
|
+
|
|
293
|
+
@staticmethod
|
|
294
|
+
def cohospace_neuron(**kwargs: Any) -> PlotConfig:
|
|
295
|
+
defaults: dict[str, Any] = {
|
|
296
|
+
"title": "Neuron activity on coho-space",
|
|
297
|
+
"xlabel": "Theta 1 (deg)",
|
|
298
|
+
"ylabel": "Theta 2 (deg)",
|
|
299
|
+
"figsize": (6, 6),
|
|
300
|
+
}
|
|
301
|
+
defaults.update(kwargs)
|
|
302
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
303
|
+
|
|
304
|
+
@staticmethod
|
|
305
|
+
def cohospace_population(**kwargs: Any) -> PlotConfig:
|
|
306
|
+
defaults: dict[str, Any] = {
|
|
307
|
+
"title": "Population activity on coho-space",
|
|
308
|
+
"xlabel": "Theta 1 (deg)",
|
|
309
|
+
"ylabel": "Theta 2 (deg)",
|
|
310
|
+
"figsize": (6, 6),
|
|
311
|
+
}
|
|
312
|
+
defaults.update(kwargs)
|
|
313
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
314
|
+
|
|
315
|
+
@staticmethod
|
|
316
|
+
def fr_heatmap(**kwargs: Any) -> PlotConfig:
|
|
317
|
+
defaults: dict[str, Any] = {
|
|
318
|
+
"title": "Firing Rate Heatmap",
|
|
319
|
+
"xlabel": "Time",
|
|
320
|
+
"ylabel": "Neuron",
|
|
321
|
+
"figsize": (10, 5),
|
|
322
|
+
"clabel": "Value",
|
|
323
|
+
}
|
|
324
|
+
defaults.update(kwargs)
|
|
325
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
326
|
+
|
|
327
|
+
@staticmethod
|
|
328
|
+
def frm(**kwargs: Any) -> PlotConfig:
|
|
329
|
+
defaults: dict[str, Any] = {
|
|
330
|
+
"title": "Firing Rate Map",
|
|
331
|
+
"xlabel": "X bin",
|
|
332
|
+
"ylabel": "Y bin",
|
|
333
|
+
"figsize": (6, 5),
|
|
334
|
+
"clabel": "Rate",
|
|
335
|
+
}
|
|
336
|
+
defaults.update(kwargs)
|
|
337
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
338
|
+
|
|
339
|
+
@staticmethod
|
|
340
|
+
def path_compare(**kwargs: Any) -> PlotConfig:
|
|
341
|
+
defaults: dict[str, Any] = {
|
|
342
|
+
"title": "Path Compare",
|
|
343
|
+
"figsize": (12, 5),
|
|
344
|
+
}
|
|
345
|
+
defaults.update(kwargs)
|
|
346
|
+
return PlotConfig.for_static_plot(**defaults)
|
|
347
|
+
|
|
243
348
|
@staticmethod
|
|
244
349
|
def raster_plot(mode: str = "block", **kwargs: Any) -> PlotConfig:
|
|
245
350
|
defaults: dict[str, Any] = {
|
|
@@ -4,11 +4,15 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def is_jupyter_environment() -> bool:
|
|
7
|
-
"""
|
|
8
|
-
Detect if code is running in a Jupyter notebook environment.
|
|
7
|
+
"""Detect if code is running in a Jupyter notebook environment.
|
|
9
8
|
|
|
10
9
|
Returns:
|
|
11
|
-
bool: True if running in Jupyter
|
|
10
|
+
bool: True if running in a Jupyter notebook, False otherwise.
|
|
11
|
+
|
|
12
|
+
Examples:
|
|
13
|
+
>>> from canns.analyzer.visualization.core.jupyter_utils import is_jupyter_environment
|
|
14
|
+
>>> print(is_jupyter_environment() in {True, False})
|
|
15
|
+
True
|
|
12
16
|
"""
|
|
13
17
|
try:
|
|
14
18
|
# Check if IPython is available and we're in a notebook
|
|
@@ -28,23 +32,37 @@ def is_jupyter_environment() -> bool:
|
|
|
28
32
|
|
|
29
33
|
|
|
30
34
|
def display_animation_in_jupyter(animation, format: str = "html5"):
|
|
31
|
-
"""
|
|
32
|
-
Display a matplotlib animation in Jupyter notebook.
|
|
33
|
-
|
|
34
|
-
Performance comparison (100 frames):
|
|
35
|
-
- html5 (default): 1.3s, 134 KB - Fast encoding, small size, smooth playback
|
|
36
|
-
- jshtml: 2.6s, 4837 KB - Slower, 36x larger, but works without FFmpeg
|
|
35
|
+
"""Display a matplotlib animation in a Jupyter notebook.
|
|
37
36
|
|
|
38
37
|
Args:
|
|
39
|
-
animation: matplotlib.animation.FuncAnimation
|
|
40
|
-
format: Display format -
|
|
38
|
+
animation: ``matplotlib.animation.FuncAnimation`` instance.
|
|
39
|
+
format: Display format - ``"html5"`` (default) or ``"jshtml"``.
|
|
41
40
|
|
|
42
41
|
Returns:
|
|
43
|
-
IPython.display.HTML object if successful, None
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
42
|
+
``IPython.display.HTML`` object if successful, otherwise ``None``.
|
|
43
|
+
|
|
44
|
+
Examples:
|
|
45
|
+
>>> import numpy as np
|
|
46
|
+
>>> from matplotlib import pyplot as plt
|
|
47
|
+
>>> from matplotlib.animation import FuncAnimation
|
|
48
|
+
>>> from canns.analyzer.visualization.core.jupyter_utils import (
|
|
49
|
+
... display_animation_in_jupyter,
|
|
50
|
+
... is_jupyter_environment,
|
|
51
|
+
... )
|
|
52
|
+
>>>
|
|
53
|
+
>>> x = np.linspace(0, 2 * np.pi, 50)
|
|
54
|
+
>>> fig, ax = plt.subplots()
|
|
55
|
+
>>> (line,) = ax.plot([], [])
|
|
56
|
+
>>>
|
|
57
|
+
>>> def update(i):
|
|
58
|
+
... line.set_data(x[: i + 1], np.sin(x[: i + 1]))
|
|
59
|
+
... return (line,)
|
|
60
|
+
>>>
|
|
61
|
+
>>> anim = FuncAnimation(fig, update, frames=5, interval=50, blit=True)
|
|
62
|
+
>>> if is_jupyter_environment():
|
|
63
|
+
... _ = display_animation_in_jupyter(anim, format="jshtml")
|
|
64
|
+
... print(anim is not None)
|
|
65
|
+
True
|
|
48
66
|
"""
|
|
49
67
|
try:
|
|
50
68
|
from IPython.display import HTML, display
|
|
@@ -286,51 +286,48 @@ def render_animation_parallel(
|
|
|
286
286
|
show_progress: bool = True,
|
|
287
287
|
file_format: str | None = None,
|
|
288
288
|
):
|
|
289
|
-
"""Universal parallel animation renderer for
|
|
290
|
-
|
|
291
|
-
This function provides a unified interface for parallel frame rendering that can be
|
|
292
|
-
used by ANY animation function in the codebase. It handles:
|
|
293
|
-
- Format detection (GIF vs MP4)
|
|
294
|
-
- Parallel vs sequential rendering
|
|
295
|
-
- Progress bars
|
|
296
|
-
- Optimal writer selection
|
|
289
|
+
"""Universal parallel animation renderer for analyzer animations.
|
|
297
290
|
|
|
298
291
|
Args:
|
|
299
292
|
render_frame_func: Callable that renders a single frame:
|
|
300
|
-
|
|
301
|
-
frame_data: Data needed by render_frame_func (
|
|
302
|
-
num_frames: Total number of frames to render
|
|
303
|
-
save_path: Output file path (extension determines format)
|
|
304
|
-
fps: Frames per second
|
|
305
|
-
num_workers: Number of parallel workers (None = auto-detect)
|
|
306
|
-
show_progress: Whether to show progress bar
|
|
307
|
-
file_format: Override file format detection ('gif', 'mp4', etc.)
|
|
293
|
+
``func(frame_idx, frame_data) -> np.ndarray (H, W, 3 or 4)``.
|
|
294
|
+
frame_data: Data needed by ``render_frame_func`` (passed to workers).
|
|
295
|
+
num_frames: Total number of frames to render.
|
|
296
|
+
save_path: Output file path (extension determines format).
|
|
297
|
+
fps: Frames per second.
|
|
298
|
+
num_workers: Number of parallel workers (None = auto-detect).
|
|
299
|
+
show_progress: Whether to show progress bar.
|
|
300
|
+
file_format: Override file format detection ('gif', 'mp4', etc.).
|
|
308
301
|
|
|
309
302
|
Returns:
|
|
310
|
-
None (saves animation to file)
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
>>>
|
|
314
|
-
|
|
315
|
-
|
|
303
|
+
None (saves animation to file).
|
|
304
|
+
|
|
305
|
+
Examples:
|
|
306
|
+
>>> import numpy as np
|
|
307
|
+
>>> import tempfile
|
|
308
|
+
>>> from pathlib import Path
|
|
309
|
+
>>> from canns.analyzer.visualization.core.rendering import render_animation_parallel
|
|
310
|
+
>>> from canns.analyzer.visualization.core import rendering
|
|
311
|
+
>>>
|
|
312
|
+
>>> def render_frame(idx, data):
|
|
313
|
+
... frame = data[idx]
|
|
314
|
+
... return frame # (H, W, 3)
|
|
316
315
|
>>>
|
|
317
|
-
>>>
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
...
|
|
321
|
-
...
|
|
322
|
-
...
|
|
323
|
-
... )
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
- Automatically falls back to sequential for short animations (<50 frames)
|
|
316
|
+
>>> frames = [np.zeros((10, 10, 3), dtype=np.uint8) for _ in range(2)]
|
|
317
|
+
>>> # Save a tiny animation if imageio is available
|
|
318
|
+
>>> if rendering.IMAGEIO_AVAILABLE:
|
|
319
|
+
... with tempfile.TemporaryDirectory() as tmpdir:
|
|
320
|
+
... save_path = Path(tmpdir) / "demo.gif"
|
|
321
|
+
... render_animation_parallel(
|
|
322
|
+
... render_frame, frames, num_frames=2, save_path=str(save_path), fps=2
|
|
323
|
+
... )
|
|
324
|
+
... print("saved")
|
|
325
|
+
... else:
|
|
326
|
+
... print("imageio not available")
|
|
329
327
|
"""
|
|
330
|
-
import os
|
|
331
328
|
import multiprocessing as mp
|
|
332
|
-
import
|
|
333
|
-
|
|
329
|
+
import os
|
|
330
|
+
|
|
334
331
|
from tqdm import tqdm
|
|
335
332
|
|
|
336
333
|
# Detect file format
|
|
@@ -489,7 +486,12 @@ def _render_mp4_parallel(
|
|
|
489
486
|
if IMAGEIO_AVAILABLE:
|
|
490
487
|
# Try imageio first (simpler, more reliable if ffmpeg plugin available)
|
|
491
488
|
try:
|
|
492
|
-
writer_kwargs = {
|
|
489
|
+
writer_kwargs = {
|
|
490
|
+
"fps": fps,
|
|
491
|
+
"codec": "libx264",
|
|
492
|
+
"pixelformat": "yuv420p",
|
|
493
|
+
"bitrate": "5000k",
|
|
494
|
+
}
|
|
493
495
|
with imageio.get_writer(save_path, **writer_kwargs) as writer:
|
|
494
496
|
for frame in frames:
|
|
495
497
|
# Ensure RGB format
|
|
@@ -514,14 +516,14 @@ def _render_mp4_parallel(
|
|
|
514
516
|
h, w = frames[0].shape[:2]
|
|
515
517
|
fig = plt.figure(figsize=(w / 100, h / 100), dpi=100, frameon=False)
|
|
516
518
|
ax = fig.add_axes([0, 0, 1, 1])
|
|
517
|
-
ax.axis(
|
|
519
|
+
ax.axis("off")
|
|
518
520
|
|
|
519
521
|
writer = FFMpegWriter(fps=fps, codec="h264", bitrate=5000)
|
|
520
522
|
with writer.saving(fig, save_path, dpi=100):
|
|
521
523
|
for frame in frames:
|
|
522
524
|
ax.clear()
|
|
523
525
|
ax.imshow(frame)
|
|
524
|
-
ax.axis(
|
|
526
|
+
ax.axis("off")
|
|
525
527
|
writer.grab_frame()
|
|
526
528
|
|
|
527
529
|
plt.close(fig)
|
|
@@ -439,31 +439,21 @@ def warn_gif_format(*, stacklevel: int = 2) -> None:
|
|
|
439
439
|
|
|
440
440
|
|
|
441
441
|
def get_matplotlib_writer(save_path: str, fps: int = 10, **kwargs):
|
|
442
|
-
"""
|
|
443
|
-
Create appropriate matplotlib animation writer based on file extension.
|
|
444
|
-
|
|
445
|
-
This function automatically selects the correct writer:
|
|
446
|
-
- .mp4 → FFMpegWriter (H.264 codec, high quality, fast encoding)
|
|
447
|
-
- .gif → PillowWriter (universal compatibility)
|
|
448
|
-
- others → FFMpegWriter (default)
|
|
442
|
+
"""Create a Matplotlib animation writer based on file extension.
|
|
449
443
|
|
|
450
444
|
Args:
|
|
451
|
-
save_path: Output file path (extension determines format)
|
|
452
|
-
fps: Frames per second
|
|
453
|
-
**kwargs: Additional arguments passed to the writer
|
|
454
|
-
For FFMpegWriter: codec, bitrate, extra_args
|
|
455
|
-
For PillowWriter: codec (ignored)
|
|
445
|
+
save_path: Output file path (extension determines format).
|
|
446
|
+
fps: Frames per second.
|
|
447
|
+
**kwargs: Additional arguments passed to the writer.
|
|
456
448
|
|
|
457
449
|
Returns:
|
|
458
|
-
Matplotlib animation writer instance
|
|
450
|
+
Matplotlib animation writer instance.
|
|
459
451
|
|
|
460
|
-
|
|
461
|
-
>>> from
|
|
462
|
-
>>> writer = get_matplotlib_writer(
|
|
463
|
-
>>>
|
|
464
|
-
|
|
465
|
-
>>> # With custom codec
|
|
466
|
-
>>> writer = get_matplotlib_writer('output.mp4', fps=30, bitrate=8000)
|
|
452
|
+
Examples:
|
|
453
|
+
>>> from canns.analyzer.visualization.core.writers import get_matplotlib_writer
|
|
454
|
+
>>> writer = get_matplotlib_writer("output.gif", fps=5)
|
|
455
|
+
>>> print(writer is not None)
|
|
456
|
+
True
|
|
467
457
|
"""
|
|
468
458
|
import os
|
|
469
459
|
|