parallel-matplotlib-animation 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ from .animator import Animator
@@ -0,0 +1,387 @@
1
+ import matplotlib
2
+
3
+ matplotlib.use("Agg")
4
+
5
+ import tempfile
6
+ import logging
7
+ import time
8
+ import multiprocessing as mp
9
+ from pathlib import Path
10
+ from abc import ABC, abstractmethod
11
+ from typing import Any
12
+
13
+ import matplotlib.pyplot as plt
14
+ import av
15
+ from PIL import Image
16
+ from tqdm import tqdm, trange
17
+
18
+
19
class Animator(ABC):
    """
    Base class for creating matplotlib animations with efficient parallel rendering.

    Subclasses implement ``setup()`` (build the figure and artists) and
    ``update()`` (modify them for one frame). ``make_video()`` renders every frame
    to a PNG file in a temporary directory -- serially in the calling process or
    in worker processes -- and then encodes the PNGs into a video with PyAV.
    """

    def __init__(self):
        """Initialize the animator."""
        self.logger = logging.getLogger(__name__)

    @abstractmethod
    def setup(self):
        """
        Set up the figure, axes, and any artists. Store whatever you need as instance
        attributes (self.ax, self.line, etc.); you will need them in update().

        This method MUST not accept any argument other than self (if you need to pass
        data to setup, store it as instance attributes in __init__).

        This method MUST return the plt.Figure object, and that alone.

        This method is called once before the animation loop (once per worker if
        parallelized), unless reuse_figure_object is False, in which case it is
        called once per frame.
        """
        pass

    @abstractmethod
    def update(self, frame_idx: int, params: Any):
        """
        Update the plot for the given frame.

        This is called for each frame in the animation. Modify any figure, axes, and
        artists that you created in setup() and stored as instance attributes.

        Args:
            frame_idx (int): Current frame index (0 to num_frames-1)
            params (Any): Parameters for this frame from param_by_frame[frame_idx]. Can
                be any object you want as long as it's pickleable (required to
                distribute jobs to worker processes).
        """
        pass

    def make_video(
        self,
        output_file: Path | str,
        param_by_frame: list[Any],
        fps: int,
        num_workers: int = -1,
        disable_progress_bar: bool | None = None,
        plotting_log_interval: int | None = None,
        saving_log_interval: int | None = None,
        savefig_params: dict[str, Any] | None = None,
        video_codec: str = "libx264",
        video_params: dict[str, Any] | None = None,
        reuse_figure_object: bool = True,
    ) -> None:
        """
        Render the animation to a video file.

        Args:
            output_file (Path | str): Path to output video file
            param_by_frame (list[Any]): List of parameters, one per frame
            fps (int): Frames per second for output video
            num_workers (int): Number of parallel workers. 1 for serial processing
                (in the main thread), -1 for all CPU cores, -2 for all but one CPU core,
                etc. 0 is invalid. In parallel mode at most one worker per frame
                is started.
            disable_progress_bar (bool | None): Same behavior as tqdm: if True, disable
                progress bar. If None, auto-detect based on whether output is a TTY.
            plotting_log_interval (int | None): Log progress every N frames in each
                worker (None = no interval logging).
            saving_log_interval (int | None): Log progress every N frames when merging
                frames into video (None = no interval logging).
            savefig_params (dict[str, Any] | None): Additional keyword arguments to
                pass to plt.Figure.savefig() when saving frames (default: None,
                meaning no extra arguments).
            video_codec (str): Codec to use for video encoding (default: "libx264").
            video_params (dict[str, Any] | None): Additional attributes to set on the
                video stream (default: None, meaning {"pix_fmt": "yuv420p"}).
            reuse_figure_object (bool): If False, the figure will be re-created for each
                frame (i.e. setup() called every frame). There is basically no reason to
                set this to False. Use only for testing and benchmarking.

        Raises:
            ValueError: if num_workers is 0.
            RuntimeError: if parallel rendering ends with frames missing (e.g. a
                worker raised), or if no frames were produced at all.
        """
        # Resolve defaults per call so no dict object is shared between calls.
        if savefig_params is None:
            savefig_params = {}
        if video_params is None:
            video_params = {"pix_fmt": "yuv420p"}

        # Try to convert param_by_frame to list
        try:
            params_list = list(param_by_frame)
        except Exception:
            self.logger.critical(
                "param_by_frame must be convertible to a list. Ensure it is a list-like object."
            )
            raise

        num_frames = len(params_list)

        # Determine number of workers. 0 would start no workers at all and the
        # parallel progress loop would then wait forever.
        if num_workers == 0:
            raise ValueError(
                "num_workers must be positive, or negative to count back from the "
                "number of CPU cores (-1 = all cores); got 0"
            )
        if num_workers == -1:
            num_workers = mp.cpu_count()
        elif num_workers < -1:
            num_workers = max(1, mp.cpu_count() + num_workers + 1)

        self.logger.info(f"Rendering {num_frames} frames at {fps} fps")
        with tempfile.TemporaryDirectory(prefix="animator_frames_") as frames_dir:
            self.logger.info(f"Using temporary directory: {frames_dir}")

            if num_workers == 1:
                self.logger.info("Running in serial mode")
                self._render_serial(
                    params_list,
                    frames_dir,
                    disable_progress_bar,
                    plotting_log_interval,
                    savefig_params,
                    reuse_figure_object,
                )
            else:
                # Each worker calls setup(); workers beyond the frame count would
                # only do that and exit, so don't start them.
                num_workers = max(1, min(num_workers, num_frames))
                self.logger.info(f"Running in parallel mode with {num_workers} workers")
                self._render_parallel(
                    params_list,
                    frames_dir,
                    num_workers,
                    disable_progress_bar,
                    plotting_log_interval,
                    savefig_params,
                    reuse_figure_object,
                )

            self.logger.info("Creating video with PyAV")
            _merge_frames_into_video(
                frames_dir,
                output_file,
                fps,
                video_codec,
                video_params,
                disable_progress_bar,
                self.logger,
                log_interval=saving_log_interval,
            )

        self.logger.info(f"Animation complete: {output_file}")

    def _setup_and_check(self) -> plt.Figure:
        """Call setup() and validate its return type.

        Raises:
            TypeError: if setup() did not return a matplotlib Figure.
        """
        fig = self.setup()
        if not isinstance(fig, plt.Figure):
            raise TypeError(
                f"`.setup()` must return a matplotlib Figure object, got {type(fig).__name__}"
            )
        return fig

    def _render_serial(
        self,
        params_list: list[Any],
        frames_dir: Path | str,
        disable_progress_bar: bool | None,
        log_interval: int | None,
        savefig_params: dict[str, Any],
        reuse_figure_object: bool,
    ) -> None:
        """Render frames one after another in the calling process.

        Frames are written as ``frame_{idx:09d}.png`` inside ``frames_dir``.
        """
        self.logger.info("Serial rendering")
        num_frames = len(params_list)
        frames_path = Path(frames_dir)

        # `fig` is None whenever no figure is open; plt.close(None) would close
        # the *current* figure instead, so the finally clause guards on it.
        fig = self._setup_and_check() if reuse_figure_object else None
        try:
            for frame_idx in trange(
                num_frames, desc="Rendering", disable=disable_progress_bar
            ):
                if not reuse_figure_object:
                    fig = self._setup_and_check()

                self.update(frame_idx, params_list[frame_idx])
                fig.canvas.draw()
                fig.savefig(frames_path / f"frame_{frame_idx:09d}.png", **savefig_params)

                if not reuse_figure_object:
                    plt.close(fig)
                    fig = None

                # Optional interval logging
                if log_interval and (frame_idx + 1) % log_interval == 0:
                    self.logger.info(f"Frame {frame_idx + 1}/{num_frames}")
        finally:
            # Closes the shared figure, or a per-frame figure left open by an error.
            if fig is not None:
                plt.close(fig)

    def _render_parallel(
        self,
        params_list: list[Any],
        frames_dir: Path | str,
        num_workers: int,
        disable_progress_bar: bool | None,
        log_interval: int | None,
        savefig_params: dict[str, Any],
        reuse_figure_object: bool,
    ) -> None:
        """Render frames in parallel using dynamic work distribution.

        One task per frame is placed on a shared queue, followed by one ``None``
        sentinel per worker. Workers increment a shared counter per finished frame,
        which drives the progress bar.

        Raises:
            RuntimeError: if every worker has exited but fewer than all frames were
                rendered (typically because setup()/update() raised in a worker).
        """
        num_frames = len(params_list)

        self.logger.info(f"Using dynamic work distribution with {num_workers} workers")

        # Create queues for task distribution and atomic counter for progress
        task_queue = mp.Queue()
        num_frames_completed = mp.Value("i", 0)  # atomic integer counter

        # Start worker processes
        workers = []
        for worker_id in range(num_workers):
            p = mp.Process(
                target=_worker_process,
                args=(
                    self,
                    worker_id,
                    task_queue,
                    num_frames_completed,
                    frames_dir,
                    log_interval,
                    savefig_params,
                    reuse_figure_object,
                ),
            )
            p.start()
            workers.append(p)

        # Populate task queue with individual frames (batch_size = 1)
        for frame_idx, params in enumerate(params_list):
            task_queue.put((frame_idx, params))

        # Send sentinel values to signal workers to exit
        for _ in range(num_workers):
            task_queue.put(None)

        # Monitor progress using the shared counter
        pbar = tqdm(total=num_frames, desc="Rendering", disable=disable_progress_bar)
        try:
            while num_frames_completed.value < num_frames:
                current_progress = num_frames_completed.value
                if current_progress > pbar.n:
                    pbar.update(current_progress - pbar.n)
                if not any(p.is_alive() for p in workers):
                    # A worker may have finished its last frame between the loop
                    # condition and the liveness check, so re-read the counter.
                    completed = num_frames_completed.value
                    if completed < num_frames:
                        exit_codes = [p.exitcode for p in workers]
                        raise RuntimeError(
                            f"Parallel rendering stopped early: {completed}/{num_frames} "
                            f"frames rendered and all workers exited "
                            f"(exit codes: {exit_codes}). See worker errors above."
                        )
                    break
                time.sleep(0.1)
            pbar.update(num_frames_completed.value - pbar.n)
        except BaseException:
            # Error or interrupt: stop remaining workers, and don't block interpreter
            # exit on queued tasks that nobody will consume any more.
            for p in workers:
                if p.is_alive():
                    p.terminate()
            for p in workers:
                p.join()
            task_queue.cancel_join_thread()
            raise
        finally:
            pbar.close()

        # Wait for all workers to finish
        for p in workers:
            p.join()

        self.logger.info("All workers completed")
265
+
266
+
267
def _worker_process(
    animator: Animator,
    worker_id: int,
    task_queue: mp.Queue,
    progress_counter,
    frames_dir: Path | str,
    log_interval: int | None,
    savefig_params: dict[str, Any],
    reuse_figure_object: bool,
) -> None:
    """
    Worker process that renders frames.

    Each worker:
    1. Calls setup() once to initialize the figure (unless reuse_figure_object is False,
       in which case setup() is called for every frame)
    2. Repeatedly pulls individual (frame_idx, params) tasks from the task queue until
       it receives a None sentinel
    3. Renders each frame to ``frame_{idx:09d}.png`` in frames_dir
    4. Increments the shared progress counter under its lock

    Any exception is logged and re-raised, so the process exits with a non-zero code.
    """
    frames_path = Path(frames_dir)
    frames_processed = 0

    # `fig` is None whenever no figure is open. This covers a worker that only
    # ever receives the sentinel when reuse_figure_object is False. plt.close(None)
    # would close the *current* figure, so the finally clause guards on it.
    fig = None
    try:
        # Setup once per worker
        if reuse_figure_object:
            fig = animator._setup_and_check()

        # Process frames until we get a sentinel value (None)
        while True:
            task = task_queue.get()
            if task is None:
                break  # sentinel value - exit
            frame_idx, params = task

            # Render the frame
            if not reuse_figure_object:
                fig = animator._setup_and_check()
            animator.update(frame_idx, params)
            fig.canvas.draw()
            fig.savefig(frames_path / f"frame_{frame_idx:09d}.png", **savefig_params)
            frames_processed += 1
            if not reuse_figure_object:
                plt.close(fig)
                fig = None

            # Logging
            if log_interval and frames_processed % log_interval == 0:
                animator.logger.info(
                    f"Worker {worker_id}: processed {frames_processed} frames"
                )

            # Atomically increment progress counter
            with progress_counter.get_lock():
                progress_counter.value += 1
    except Exception:
        animator.logger.exception(
            f"Worker {worker_id}: failed after {frames_processed} frames"
        )
        raise
    finally:
        if fig is not None:
            plt.close(fig)

    animator.logger.info(f"Worker {worker_id}: completed {frames_processed} frames")
321
+
322
+
323
+ def _merge_frames_into_video(
324
+ frames_dir: Path | str,
325
+ output_file: Path | str,
326
+ fps: int,
327
+ video_codec: str,
328
+ video_params: dict[str, Any],
329
+ disable_progress_bar: bool | None,
330
+ logger: logging.Logger,
331
+ log_interval: int | None,
332
+ ) -> None:
333
+ """Use PyAV to merge frames into video."""
334
+ # Gather frame files in sorted order
335
+ frame_files = sorted(Path(frames_dir).glob("frame_*.png"))
336
+
337
+ if not frame_files:
338
+ raise RuntimeError("No frames found in temporary directory")
339
+
340
+ # Open first image to determine size and ensure even dimensions
341
+ with Image.open(frame_files[0]) as first_img:
342
+ width, height = first_img.size
343
+ # Make width/height even (required by many codecs)
344
+ width = (width // 2) * 2
345
+ height = (height // 2) * 2
346
+
347
+ try:
348
+ Path(output_file).parent.mkdir(parents=True, exist_ok=True)
349
+ container = av.open(str(output_file), mode="w")
350
+ stream = container.add_stream(video_codec, rate=fps)
351
+ stream.width = width
352
+ stream.height = height
353
+ for key, value in video_params.items():
354
+ setattr(stream, key, value)
355
+
356
+ # Encode each frame
357
+ for i, frame_path in tqdm(
358
+ enumerate(frame_files),
359
+ total=len(frame_files),
360
+ desc="Merging frames",
361
+ disable=disable_progress_bar,
362
+ ):
363
+ img = Image.open(frame_path).convert("RGBA")
364
+ # Ensure image has the right size
365
+ if img.size != (width, height):
366
+ img = img.resize((width, height))
367
+
368
+ video_frame = av.VideoFrame.from_image(img)
369
+ for packet in stream.encode(video_frame):
370
+ container.mux(packet)
371
+
372
+ # Optional interval logging
373
+ if log_interval and (i + 1) % log_interval == 0:
374
+ logger.info(f"Frame written {i + 1}/{len(frame_files)} to video")
375
+
376
+ # Flush encoder
377
+ for packet in stream.encode(None):
378
+ container.mux(packet)
379
+
380
+ container.close()
381
+
382
+ size_mb = Path(output_file).stat().st_size / (1024 * 1024)
383
+ logger.info(f"Video created: {output_file} ({size_mb:.2f} MB)")
384
+
385
+ except Exception as e:
386
+ logger.critical(f"PyAV failed to create video: {e}")
387
+ raise e
File without changes
@@ -0,0 +1,144 @@
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ from parallel_animate import Animator
7
+
8
+
9
class MultiPanelAnimation(Animator):
    """
    Demo animation with five panels (line, bar, imshow, 3D line, scatter), all
    driven by the single ``params["phase"]`` value of each frame.
    """

    # Side length of the square imshow grid. The placeholder image created in
    # setup() and the pattern computed in update() must have the same shape.
    GRID_SIZE = 30

    def __init__(self):
        super().__init__()

    def setup(self):
        """Build the 1x5 figure, its artists, and frame-invariant arrays for update()."""
        # Create figure with 1 row, 5 columns of subplots
        self.fig, axes = plt.subplots(1, 5, figsize=(20, 4))
        self.axes = list(axes)

        # Replace 4th subplot with 3D version
        self.axes[3].remove()
        self.axes[3] = self.fig.add_subplot(1, 5, 4, projection="3d")

        # Artists, plus the precomputed arrays that update() reuses every frame
        self.artists = {}

        # Line chart - multiple overlapping waves
        x = np.linspace(0, 4 * np.pi, 100)
        (line,) = self.axes[0].plot(x, np.sin(x), "b-", linewidth=2)
        self.axes[0].set_xlim(0, 4 * np.pi)
        self.axes[0].set_ylim(-2, 2)
        self.axes[0].set_title("Line Chart")
        self.axes[0].grid(True, alpha=0.3)
        self.artists["line"] = line
        self.artists["line_x"] = x

        # Bar chart - dancing bars (random initial heights are overwritten in update())
        bars = self.axes[1].bar(range(10), np.random.rand(10), color="steelblue")
        self.axes[1].set_ylim(0, 2)
        self.axes[1].set_title("Bar Chart")
        self.axes[1].set_xlabel("Category")
        self.axes[1].set_ylabel("Value")
        self.artists["bars"] = bars

        # Imshow - wave interference pattern (random placeholder, replaced in update())
        data = np.random.rand(self.GRID_SIZE, self.GRID_SIZE)
        im = self.axes[2].imshow(data, cmap="viridis", vmin=0, vmax=2)
        self.axes[2].set_title("Wave Interference")
        self.axes[2].axis("off")
        self.fig.colorbar(im, ax=self.axes[2], fraction=0.046, pad=0.04)
        self.artists["imshow"] = im
        # The sampling grid never changes between frames, so build it once here.
        grid = np.linspace(-5, 5, self.GRID_SIZE)
        self.artists["imshow_grid"] = np.meshgrid(grid, grid)

        # 3D line - rotating helix
        t = np.linspace(0, 4 * np.pi, 100)
        (line3d,) = self.axes[3].plot(np.cos(t), np.sin(t), t, "b-", linewidth=2)
        self.axes[3].set_xlim(-2, 2)
        self.axes[3].set_ylim(-2, 2)
        self.axes[3].set_zlim(0, 4 * np.pi)
        self.axes[3].set_title("3D Helix")
        self.axes[3].set_xlabel("X")
        self.axes[3].set_ylabel("Y")
        self.axes[3].set_zlabel("Z")
        self.artists["line3d"] = line3d
        self.artists["line3d_t"] = t

        # Scatter - orbiting points with color gradient
        n_points = 50
        theta = np.linspace(0, 2 * np.pi, n_points)
        x_scatter = 0.5 + 0.3 * np.cos(theta)
        y_scatter = 0.5 + 0.3 * np.sin(theta)
        colors = np.linspace(0, 1, n_points)
        scatter = self.axes[4].scatter(
            x_scatter, y_scatter, c=colors, cmap="plasma", s=100, alpha=0.8
        )
        self.axes[4].set_xlim(0, 1)
        self.axes[4].set_ylim(0, 1)
        self.axes[4].set_title("Orbiting Points")
        self.axes[4].set_aspect("equal")
        self.artists["scatter"] = scatter
        self.artists["n_points"] = n_points
        # Base angles of the points; update() rotates them by the phase.
        self.artists["scatter_theta"] = theta

        self.fig.tight_layout()
        return self.fig

    def update(self, frame_idx, params):
        """Update every panel for ``params["phase"]`` (frame_idx is not used)."""
        phase = params["phase"]

        # Update line chart - multiple frequency sine wave
        x = self.artists["line_x"]
        y = np.sin(x + phase) + 0.3 * np.sin(3 * x - phase)
        self.artists["line"].set_ydata(y)

        # Update bar chart - pulsating bars
        heights = np.abs(np.sin(np.arange(10) * 0.5 + phase)) * 1.5 + 0.2
        for bar, h in zip(self.artists["bars"], heights):
            bar.set_height(h)

        # Update imshow - wave interference from two moving sources
        X, Y = self.artists["imshow_grid"]
        x1, y1 = 2 * np.cos(phase), 2 * np.sin(phase)
        x2, y2 = 2 * np.cos(phase + np.pi), 2 * np.sin(phase + np.pi)

        R1 = np.sqrt((X - x1) ** 2 + (Y - y1) ** 2)
        R2 = np.sqrt((X - x2) ** 2 + (Y - y2) ** 2)

        Z = np.sin(3 * R1 - 2 * phase) / (R1 + 0.5) + np.sin(3 * R2 - 2 * phase) / (
            R2 + 0.5
        )
        Z = Z + 1  # Shift to positive range

        self.artists["imshow"].set_data(Z)

        # Update 3D line - rotating and pulsating helix
        t = self.artists["line3d_t"]
        radius = 1 + 0.3 * np.sin(2 * phase)
        x_3d = radius * np.cos(t + phase)
        y_3d = radius * np.sin(t + phase)
        self.artists["line3d"].set_data_3d(x_3d, y_3d, t)

        # Update scatter - orbiting and pulsating points
        r = 0.25 + 0.15 * np.sin(2 * phase)
        theta = self.artists["scatter_theta"] + phase
        x_scatter = 0.5 + r * np.cos(theta)
        y_scatter = 0.5 + r * np.sin(theta)
        self.artists["scatter"].set_offsets(np.c_[x_scatter, y_scatter])
131
+
132
+
133
if __name__ == "__main__":
    # One parameter dict per frame; the phase sweeps a full period over the clip.
    frame_count = 90
    frame_params = [
        {"phase": 2 * np.pi * k / frame_count} for k in range(frame_count)
    ]

    destination = Path("example_output/multi_panel_animation.mp4")
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Render with 8 worker processes at 30 fps.
    MultiPanelAnimation().make_video(
        output_file=destination,
        param_by_frame=frame_params,
        fps=30,
        num_workers=8,
    )
@@ -0,0 +1,96 @@
1
+ import time
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+ from parallel_animate.examples.very_complex_animation import VeryComplexAnimation
9
+
10
+
11
def run_scaling_test(
    num_frames: int, num_workers_list: list[int], output_dir: Path
) -> dict[str, dict]:
    """
    Run a strong-scaling benchmark of VeryComplexAnimation.

    For every worker count in ``num_workers_list``, the same ``num_frames``-frame
    animation is rendered twice, with reuse_figure_object True and False. Each
    video is written to ``output_dir``.

    Returns:
        Mapping ``"{num_workers}workers_reusefig{reuse}"`` -> dict with keys
        ``"num_workers"``, ``"reuse_fig_obj"`` and ``"time_seconds"`` (elapsed
        time of ``make_video``, including animator construction).
    """
    # Generate parameters
    params = [{"phase": 2 * np.pi * i / num_frames} for i in range(num_frames)]

    # Run tests for all configurations
    results_all = {}
    for num_workers in num_workers_list:
        for reuse_figure_object in [True, False]:
            config_name = f"{num_workers}workers_reusefig{reuse_figure_object}"
            print(f"Running test: {config_name}...")
            output_file = output_dir / f"output_{config_name}.mp4"
            # perf_counter is monotonic and high-resolution, unlike time.time(),
            # which follows the (adjustable) wall clock.
            start_time = time.perf_counter()
            anim = VeryComplexAnimation()
            anim.make_video(
                output_file=output_file,
                param_by_frame=params,
                fps=30,
                num_workers=num_workers,
                disable_progress_bar=True,
                reuse_figure_object=reuse_figure_object,
            )
            elapsed_time = time.perf_counter() - start_time
            results_all[config_name] = {
                "num_workers": num_workers,
                "reuse_fig_obj": reuse_figure_object,
                "time_seconds": elapsed_time,
            }
            print(f"Test {config_name} completed in {elapsed_time:.2f} seconds")

    return results_all
44
+
45
+
46
def plot_scaling_results(results: dict[str, dict], output_path: Path) -> None:
    """
    Plot speedup vs. number of workers for both reuse_figure_object settings.

    Speedup is relative to the ``"1workers_reusefigFalse"`` run, so ``results``
    must contain that entry, plus a True and a False entry for every worker count
    present. Both axes are log2. A black line marks ideal linear scaling. The
    figure is saved to ``output_path`` and then closed.
    """
    # Sorted so each line is drawn left-to-right; set iteration order is arbitrary.
    num_workers_list = sorted({result["num_workers"] for result in results.values()})

    fig, ax = plt.subplots(figsize=(4, 4), tight_layout=True)
    baseline_time = results["1workers_reusefigFalse"]["time_seconds"]
    for reuse_figure_object in [True, False]:
        times = [
            results[f"{n}workers_reusefig{reuse_figure_object}"]["time_seconds"]
            for n in num_workers_list
        ]
        speedups = [baseline_time / t for t in times]
        ax.plot(
            num_workers_list,
            speedups,
            marker="o",
            label="with cache" if reuse_figure_object else "without cache",
        )
    ax.plot(
        [1, max(num_workers_list)],
        [1, max(num_workers_list)],
        color="black",
        label="ideal scaling",
        zorder=0,
    )
    ax.legend()
    ax.set_xscale("log", base=2)
    ax.set_yscale("log", base=2)
    ax.set_xlabel("# workers")
    ax.set_ylabel("speedup")
    ax.set_title("Strong scaling test")
    ax.set_aspect("equal")
    fig.savefig(output_path)
    plt.close(fig)
82
+
83
+
84
if __name__ == "__main__":
    out_dir = Path("example_output/scaling_test")
    out_dir.mkdir(parents=True, exist_ok=True)
    results_file = out_dir / "results.json"

    # Benchmark 320 frames at 1, 2, 4, 8 and 16 workers, then persist the timings.
    timings = run_scaling_test(320, [1, 2, 4, 8, 16], out_dir)
    with open(results_file, "w") as fh:
        json.dump(timings, fh, indent=4)

    # Plot from the timings as re-read from the saved JSON file.
    with open(results_file, "r") as fh:
        timings = json.load(fh)
    plot_scaling_results(timings, output_path=out_dir / "scaling_graph.png")
@@ -0,0 +1,37 @@
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ from parallel_animate import Animator
7
+
8
+
9
class WaveAnimation(Animator):
    """Minimal example animation: a cosine wave shifted by ``params["phase"]``."""

    def setup(self):
        """Create one axes holding a cosine line; keep the line and x-samples for update()."""
        self.x = np.linspace(0, 4 * np.pi, 200)
        fig, ax = plt.subplots()
        self.line = ax.plot(self.x, np.cos(self.x))[0]
        ax.set(
            xlim=(0, 4 * np.pi),
            ylim=(-1.5, 1.5),
            xlabel="x",
            ylabel="y",
            title="Cosine Wave",
        )
        return fig

    def update(self, frame_idx, params):
        """Set the line's y-values to cos(x + phase) for this frame."""
        shifted_x = self.x + params["phase"]
        self.line.set_ydata(np.cos(shifted_x))
24
+
25
+
26
if __name__ == "__main__":
    # One parameter dict per frame; the phase sweeps a full period over the clip.
    frame_total = 60
    frame_params = [
        {"phase": 2 * np.pi * idx / frame_total} for idx in range(frame_total)
    ]

    video_path = Path("example_output/simple_wave_animation.mp4")
    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Render with 4 worker processes at 30 fps.
    WaveAnimation().make_video(
        output_file=video_path,
        param_by_frame=frame_params,
        fps=30,
        num_workers=4,
    )
+ )