openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,605 @@
1
+ """Animated scene elements for viser visualization.
2
+
3
+ Each function in this module adds an animated element to a viser scene and
4
+ returns a tuple of ``(handle, update_callback)``. The update callback has
5
+ signature ``update_callback(frame_idx: int) -> None`` and updates the visual
6
+ to reflect the state at that frame index.
7
+
8
+ Collect these callbacks and pass them to ``add_animation_controls()`` to
9
+ wire up playback with GUI controls (play/pause, scrubber, speed, etc.).
10
+
11
+ Example::
12
+
13
+ _, update_trail = add_animated_trail(server, positions, colors)
14
+ _, update_marker = add_position_marker(server, positions)
15
+ _, update_thrust = add_thrust_vector(server, positions, thrust, attitude)
16
+
17
+ add_animation_controls(server, time_array, [update_trail, update_marker, update_thrust])
18
+ """
19
+
20
+ import threading
21
+ import time
22
+ from typing import Callable
23
+
24
+ import numpy as np
25
+ import viser
26
+
27
+ # Type alias for update callbacks: fn(frame_idx: int) -> None
28
+ UpdateCallback = Callable[[int], None]
29
+
30
+
31
def add_animated_trail(
    server: viser.ViserServer,
    pos: np.ndarray,
    colors: np.ndarray,
    point_size: float = 0.15,
) -> tuple[viser.PointCloudHandle, UpdateCallback]:
    """Create a point-cloud trail that reveals more points as frames advance.

    The point cloud is registered at the scene path ``/trail`` and initially
    shows only the first sample.

    Args:
        server: ViserServer instance
        pos: Position array of shape (N, 3)
        colors: RGB color array of shape (N, 3)
        point_size: Size of trail points

    Returns:
        Tuple of (handle, update_callback). Calling ``update_callback(k)``
        shows samples ``0..k`` inclusive.
    """
    trail = server.scene.add_point_cloud(
        "/trail",
        points=pos[:1],
        colors=colors[:1],
        point_size=point_size,
    )

    def update(frame_idx: int) -> None:
        # The slice end is exclusive, so +1 keeps the current frame visible.
        visible = slice(None, frame_idx + 1)
        trail.points = pos[visible]
        trail.colors = colors[visible]

    return trail, update
61
+
62
+
63
def add_position_marker(
    server: viser.ViserServer,
    pos: np.ndarray,
    radius: float = 0.5,
    color: tuple[int, int, int] = (100, 200, 255),
) -> tuple[viser.IcosphereHandle, UpdateCallback]:
    """Place a sphere that follows the trajectory position frame by frame.

    The sphere is registered at the scene path ``/current_pos`` and starts at
    the first trajectory sample.

    Args:
        server: ViserServer instance
        pos: Position array of shape (N, 3)
        radius: Marker radius
        color: RGB color tuple

    Returns:
        Tuple of (handle, update_callback)
    """
    marker = server.scene.add_icosphere(
        "/current_pos",
        radius=radius,
        color=color,
        position=pos[0],
    )

    def update(frame_idx: int) -> None:
        # Move the sphere to the sample belonging to the requested frame.
        marker.position = pos[frame_idx]

    return marker, update
91
+
92
+
93
def add_target_marker(
    server: viser.ViserServer,
    target_pos: np.ndarray,
    name: str = "target",
    radius: float = 0.8,
    color: tuple[int, int, int] = (255, 50, 50),
    show_trail: bool = True,
    trail_color: tuple[int, int, int] | None = None,
) -> tuple[viser.IcosphereHandle, UpdateCallback | None]:
    """Add a viewplanning target marker (static or moving).

    The marker is registered at ``/targets/{name}/marker``; for moving targets
    with ``show_trail`` enabled, the full trajectory is also drawn as a point
    cloud at ``/targets/{name}/trail``.

    Args:
        server: ViserServer instance
        target_pos: Target position - either shape (3,) for static or (N, 3) for
            moving. A single-row (1, 3) array is treated as a static target.
        name: Unique name for this target (used in scene path)
        radius: Marker radius
        color: RGB color tuple for marker
        show_trail: If True and target is moving, show trajectory trail
        trail_color: RGB color for trail (defaults to dimmed marker color)

    Returns:
        Tuple of (handle, update_callback). update_callback is None for static targets.
    """
    target_pos = np.asarray(target_pos)

    # A one-sample trajectory is a static target. Unwrap it to shape (3,) so
    # the marker position is not handed a (1, 3) array.
    if target_pos.ndim == 2 and target_pos.shape[0] == 1:
        target_pos = target_pos[0]

    # Check if static (single position) or moving (trajectory)
    is_moving = target_pos.ndim == 2 and target_pos.shape[0] > 1

    initial_pos = target_pos[0] if is_moving else target_pos

    # Add marker
    handle = server.scene.add_icosphere(
        f"/targets/{name}/marker",
        radius=radius,
        color=color,
        position=initial_pos,
    )

    if not is_moving:
        # Static target - no trail and no update needed
        return handle, None

    # For moving targets, optionally show the whole trajectory as a trail
    if show_trail:
        if trail_color is None:
            # Default trail color: marker color at half intensity
            trail_color = tuple(int(c * 0.5) for c in color)
        server.scene.add_point_cloud(
            f"/targets/{name}/trail",
            points=target_pos,
            colors=trail_color,
            point_size=0.1,
        )

    last_idx = len(target_pos) - 1

    def update(frame_idx: int) -> None:
        # Clamp to valid range: the target trajectory may be shorter than the
        # animation it is played alongside.
        handle.position = target_pos[min(frame_idx, last_idx)]

    return handle, update
152
+
153
+
154
def add_target_markers(
    server: viser.ViserServer,
    target_positions: list[np.ndarray],
    colors: list[tuple[int, int, int]] | None = None,
    radius: float = 0.8,
    show_trails: bool = True,
) -> list[tuple[viser.IcosphereHandle, UpdateCallback | None]]:
    """Add multiple viewplanning target markers.

    Each target ``i`` is added through ``add_target_marker`` under the name
    ``target_{i}``.

    Args:
        server: ViserServer instance
        target_positions: List of target positions, each either (3,) or (N, 3)
        colors: List of RGB colors, one per target. Defaults to distinct colors.
            If fewer colors than targets are given, the colors are reused
            cyclically so that every target is still drawn.
        radius: Marker radius
        show_trails: If True, show trails for moving targets

    Returns:
        List of (handle, update_callback) tuples, one per target
    """
    # Fall back to the built-in palette when no (or an empty) color list is given
    if colors is None or len(colors) == 0:
        colors = [
            (255, 50, 50),  # Red
            (50, 255, 50),  # Green
            (50, 50, 255),  # Blue
            (255, 255, 50),  # Yellow
            (255, 50, 255),  # Magenta
            (50, 255, 255),  # Cyan
        ]
    n_colors = len(colors)

    results = []
    for i, pos in enumerate(target_positions):
        # Index modulo the palette size instead of zipping, so a short color
        # list never silently drops targets.
        handle, update = add_target_marker(
            server,
            pos,
            name=f"target_{i}",
            radius=radius,
            color=colors[i % n_colors],
            show_trail=show_trails,
        )
        results.append((handle, update))

    return results
198
+
199
+
200
+ def _rotate_vector_by_quaternion(v: np.ndarray, q: np.ndarray) -> np.ndarray:
201
+ """Rotate vector v by quaternion q (wxyz format).
202
+
203
+ Args:
204
+ v: Vector of shape (3,)
205
+ q: Quaternion of shape (4,) in [w, x, y, z] format
206
+
207
+ Returns:
208
+ Rotated vector of shape (3,)
209
+ """
210
+ w, x, y, z = q
211
+ # Quaternion rotation: v' = q * v * q_conj
212
+ # Using the formula for rotating a vector by a quaternion
213
+ t = 2.0 * np.cross(np.array([x, y, z]), v)
214
+ return v + w * t + np.cross(np.array([x, y, z]), t)
215
+
216
+
217
def add_thrust_vector(
    server: viser.ViserServer,
    pos: np.ndarray,
    thrust: np.ndarray | None,
    attitude: np.ndarray | None = None,
    scale: float = 0.3,
    color: tuple[int, int, int] = (255, 100, 100),
    line_width: float = 4.0,
) -> tuple[viser.LineSegmentsHandle | None, UpdateCallback | None]:
    """Draw the thrust/force vector as a line segment that follows the vehicle.

    The segment is registered at the scene path ``/thrust_vector``. It starts
    at the current position and ends at ``position + scale * thrust``.

    Args:
        server: ViserServer instance
        pos: Position array of shape (N, 3)
        thrust: Thrust/force array of shape (N, 3), or None to skip
        attitude: Quaternion array of shape (N, 4) in [w, x, y, z] format.
            If provided, thrust is assumed to be in body frame and will be
            rotated to world frame using the attitude.
        scale: Scale factor for thrust vector length
        color: RGB color tuple
        line_width: Line width

    Returns:
        Tuple of (handle, update_callback), or (None, None) if thrust is None
    """
    if thrust is None:
        return None, None

    def segment_at(frame_idx: int) -> np.ndarray:
        """Build the (1, 2, 3) start/end point array for one frame."""
        force = thrust[frame_idx]
        if attitude is not None:
            # Body-frame thrust: rotate into the world frame first.
            force = _rotate_vector_by_quaternion(force, attitude[frame_idx])
        tip = pos[frame_idx] + force * scale
        return np.array([[pos[frame_idx], tip]])

    arrow = server.scene.add_line_segments(
        "/thrust_vector",
        points=segment_at(0),
        colors=color,
        line_width=line_width,
    )

    def update(frame_idx: int) -> None:
        arrow.points = segment_at(frame_idx)

    return arrow, update
267
+
268
+
269
def add_attitude_frame(
    server: viser.ViserServer,
    pos: np.ndarray,
    attitude: np.ndarray | None,
    axes_length: float = 2.0,
    axes_radius: float = 0.05,
) -> tuple[viser.FrameHandle | None, UpdateCallback | None]:
    """Show the body coordinate frame, moving and rotating with the trajectory.

    The frame is registered at the scene path ``/body_frame``. Quaternions are
    passed through to the frame handle unchanged, in [w, x, y, z] order.

    Args:
        server: ViserServer instance
        pos: Position array of shape (N, 3)
        attitude: Quaternion array of shape (N, 4) in [w, x, y, z] format, or None to skip
        axes_length: Length of the coordinate axes
        axes_radius: Radius of the axes cylinders

    Returns:
        Tuple of (handle, update_callback), or (None, None) if attitude is None
    """
    if attitude is None:
        return None, None

    frame = server.scene.add_frame(
        "/body_frame",
        wxyz=attitude[0],
        position=pos[0],
        axes_length=axes_length,
        axes_radius=axes_radius,
    )

    def update(frame_idx: int) -> None:
        # Orientation first, then position, for the requested frame.
        frame.wxyz = attitude[frame_idx]
        frame.position = pos[frame_idx]

    return frame, update
305
+
306
+
307
+ def _generate_viewcone_vertices(
308
+ half_angle_x: float,
309
+ half_angle_y: float | None,
310
+ depth: float,
311
+ norm_type: float | str,
312
+ n_segments: int = 32,
313
+ ) -> np.ndarray:
314
+ """Generate viewcone vertices in sensor frame (apex at origin, pointing along +Z).
315
+
316
+ The base cross-section follows the p-norm unit ball boundary (superellipse):
317
+ ||[x/a, y/b]||_p = 1
318
+
319
+ Args:
320
+ half_angle_x: Half-angle in x direction (radians)
321
+ half_angle_y: Half-angle in y direction (radians). If None, uses half_angle_x.
322
+ depth: Depth/length of the cone
323
+ norm_type: p-norm value (1, 2, 3, ..., or "inf"/float("inf") for infinity norm)
324
+ n_segments: Number of segments around the boundary
325
+
326
+ Returns:
327
+ Vertices array of shape (N, 3) where first vertex is apex at origin
328
+ """
329
+ if half_angle_y is None:
330
+ half_angle_y = half_angle_x
331
+
332
+ # Compute base dimensions at the given depth
333
+ base_half_x = depth * np.tan(half_angle_x)
334
+ base_half_y = depth * np.tan(half_angle_y)
335
+
336
+ vertices = [[0.0, 0.0, 0.0]] # Apex at origin
337
+
338
+ # Handle inf norm
339
+ if norm_type == "inf" or norm_type == float("inf"):
340
+ p = 100.0 # Large p approximates inf-norm
341
+ else:
342
+ p = float(norm_type)
343
+
344
+ # Generate superellipse boundary points
345
+ # Parameterization: x = sign(cos(t)) * |cos(t)|^(2/p), y = sign(sin(t)) * |sin(t)|^(2/p)
346
+ for i in range(n_segments):
347
+ theta = 2 * np.pi * i / n_segments
348
+ cos_t = np.cos(theta)
349
+ sin_t = np.sin(theta)
350
+
351
+ # Superellipse parameterization
352
+ x = np.sign(cos_t) * (np.abs(cos_t) ** (2.0 / p)) * base_half_x
353
+ y = np.sign(sin_t) * (np.abs(sin_t) ** (2.0 / p)) * base_half_y
354
+ vertices.append([x, y, depth])
355
+
356
+ return np.array(vertices, dtype=np.float32)
357
+
358
+
359
+ def _generate_viewcone_faces(n_base_vertices: int) -> np.ndarray:
360
+ """Generate faces for a cone/pyramid mesh.
361
+
362
+ Args:
363
+ n_base_vertices: Number of vertices on the base (excluding apex)
364
+
365
+ Returns:
366
+ Faces array of shape (F, 3) with vertex indices
367
+ """
368
+ faces = []
369
+
370
+ # Side faces: triangles from apex (index 0) to each edge of base
371
+ # Winding: apex -> current -> next gives outward-facing normals (visible from outside)
372
+ for i in range(n_base_vertices):
373
+ current_i = i + 1
374
+ next_i = (i + 1) % n_base_vertices + 1
375
+ faces.append([0, current_i, next_i])
376
+
377
+ # Base cap: triangulate as a fan from first base vertex
378
+ # Winding for outward-facing normal (visible from outside/below the cone)
379
+ for i in range(2, n_base_vertices):
380
+ faces.append([1, i + 1, i])
381
+
382
+ return np.array(faces, dtype=np.int32)
383
+
384
+
385
+ def _quaternion_to_rotation_matrix(q: np.ndarray) -> np.ndarray:
386
+ """Convert quaternion (wxyz) to rotation matrix.
387
+
388
+ Args:
389
+ q: Quaternion [w, x, y, z]
390
+
391
+ Returns:
392
+ 3x3 rotation matrix
393
+ """
394
+ w, x, y, z = q
395
+ return np.array(
396
+ [
397
+ [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
398
+ [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
399
+ [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
400
+ ]
401
+ )
402
+
403
+
404
def add_viewcone(
    server: viser.ViserServer,
    pos: np.ndarray,
    attitude: np.ndarray | None,
    half_angle_x: float,
    half_angle_y: float | None = None,
    scale: float = 10.0,
    norm_type: float | str = 2,
    R_sb: np.ndarray | None = None,
    color: tuple[int, int, int] = (35, 138, 141),  # Viridis at t~0.33 (teal)
    opacity: float = 0.4,
    wireframe: bool = False,
    n_segments: int = 32,
) -> tuple[viser.MeshHandle | None, UpdateCallback | None]:
    """Add an animated viewcone mesh that matches p-norm constraints.

    The sensor is assumed to look along +Z in its own frame (boresight = [0,0,1]).
    The viewcone represents the constraint ||[x,y]||_p <= tan(alpha) * z.
    The mesh is registered at the scene path ``/viewcone_mesh``.

    Cross-section shapes by norm:
    - p=1: diamond
    - p=2: circle/ellipse
    - p>2: rounded square (superellipse)
    - p=inf: square/rectangle

    Args:
        server: ViserServer instance
        pos: Position array of shape (N, 3)
        attitude: Quaternion array of shape (N, 4) in [w, x, y, z] format, or None to skip
        half_angle_x: Half-angle of the cone in x direction (radians).
            For symmetric cones, this is pi/alpha_x where alpha_x is the constraint parameter.
        half_angle_y: Half-angle in y direction (radians). If None, uses half_angle_x.
            For asymmetric constraints, this is pi/alpha_y.
        scale: Depth/length of the cone visualization
        norm_type: p-norm value (1, 2, 3, ..., or "inf" for infinity norm)
        R_sb: Body-to-sensor rotation matrix (3x3). If None, sensor is aligned with body z-axis.
        color: RGB color tuple
        opacity: Mesh opacity (0-1), ignored if wireframe=True
        wireframe: If True, render as wireframe instead of solid
        n_segments: Number of segments for cone smoothness

    Returns:
        Tuple of (handle, update_callback), or (None, None) if attitude is None
    """
    if attitude is None:
        return None, None

    # Convert inputs to numpy arrays (handles JAX arrays)
    pos = np.asarray(pos, dtype=np.float64)
    attitude = np.asarray(attitude, dtype=np.float64)

    # Sensor-to-body rotation is the transpose of the body-to-sensor rotation;
    # with no R_sb the sensor frame coincides with the body frame.
    if R_sb is None:
        sensor_to_body = np.eye(3)
    else:
        R_sb = np.asarray(R_sb, dtype=np.float64)
        sensor_to_body = R_sb.T

    # Cone geometry in the sensor frame; row 0 is the apex, so the number of
    # base vertices is one less than the vertex count.
    cone_local = _generate_viewcone_vertices(
        half_angle_x, half_angle_y, scale, norm_type, n_segments
    )
    faces = _generate_viewcone_faces(len(cone_local) - 1)

    def transform_vertices(frame_idx: int) -> np.ndarray:
        """Transform cone vertices from sensor frame to world frame."""
        body_to_world = _quaternion_to_rotation_matrix(attitude[frame_idx])
        # Full transform: sensor -> body -> world, then translate to position
        sensor_to_world = body_to_world @ sensor_to_body
        placed = (sensor_to_world @ cone_local.T).T + pos[frame_idx]
        return placed.astype(np.float32)

    # Wireframe rendering is always fully opaque.
    mesh_opacity = 1.0 if wireframe else opacity
    mesh = server.scene.add_mesh_simple(
        "/viewcone_mesh",
        vertices=transform_vertices(0),
        faces=faces,
        color=color,
        wireframe=wireframe,
        opacity=mesh_opacity,
    )

    def update(frame_idx: int) -> None:
        mesh.vertices = transform_vertices(frame_idx)

    return mesh, update
495
+
496
+
497
+ # =============================================================================
498
+ # Animation Controls
499
+ # =============================================================================
500
+
501
+
502
def add_animation_controls(
    server: viser.ViserServer,
    traj_time: np.ndarray,
    update_callbacks: list[UpdateCallback],
    loop: bool = True,
    folder_name: str = "Animation",
) -> None:
    """Add animation GUI controls and start the animation loop.

    Creates play/pause button, reset button, time slider, speed slider, and loop checkbox.
    Runs animation in a background daemon thread.

    Args:
        server: ViserServer instance
        traj_time: Time array of shape (N,) with timestamps for each frame.
            Any array-like is accepted and flattened; it is searched with
            ``np.searchsorted``, so it is expected to be sorted ascending.
        update_callbacks: List of update functions to call each frame.
            ``None`` entries are ignored.
        loop: Whether to loop animation by default
        folder_name: Name for the GUI folder

    Raises:
        ValueError: If ``traj_time`` is empty.
    """
    # np.asarray accepts lists and other array-likes that lack .flatten()
    traj_time = np.asarray(traj_time).ravel()
    n_frames = len(traj_time)
    if n_frames == 0:
        raise ValueError("traj_time must contain at least one timestamp")
    t_start, t_end = float(traj_time[0]), float(traj_time[-1])
    duration = t_end - t_start

    # A single-instant trajectory has zero duration; fall back to a positive
    # step so the slider is never created with step == 0.
    slider_step = duration / 100 if duration > 0 else 1.0

    # Filter out None callbacks
    callbacks = [cb for cb in update_callbacks if cb is not None]

    def time_to_frame(t: float) -> int:
        """Convert simulation time to the index of the last frame at or before t."""
        return int(np.clip(np.searchsorted(traj_time, t, side="right") - 1, 0, n_frames - 1))

    def update_all(sim_t: float) -> None:
        """Update all visualization components."""
        idx = time_to_frame(sim_t)
        for callback in callbacks:
            callback(idx)

    # --- GUI Controls ---
    with server.gui.add_folder(folder_name):
        play_button = server.gui.add_button("Play")
        reset_button = server.gui.add_button("Reset")
        time_slider = server.gui.add_slider(
            "Time (s)",
            min=t_start,
            max=t_end,
            step=slider_step,
            initial_value=t_start,
        )
        speed_slider = server.gui.add_slider(
            "Speed",
            min=0.1,
            max=5.0,
            step=0.1,
            initial_value=1.0,
        )
        loop_checkbox = server.gui.add_checkbox("Loop", initial_value=loop)

    # Animation state shared between the GUI callbacks and the playback thread
    state = {"playing": False, "sim_time": t_start}

    @play_button.on_click
    def _(_) -> None:
        start_playing = not state["playing"]
        # With looping off, playback parks at t_end. Rewind before resuming,
        # otherwise Play would stop again on the very next tick. The rewind
        # happens before "playing" is set so the playback thread never sees
        # playing=True together with the stale end-of-trajectory time.
        if start_playing and state["sim_time"] >= t_end:
            state["sim_time"] = t_start
        state["playing"] = start_playing
        play_button.name = "Pause" if state["playing"] else "Play"

    @reset_button.on_click
    def _(_) -> None:
        state["sim_time"] = t_start
        time_slider.value = t_start
        update_all(t_start)

    @time_slider.on_update
    def _(_) -> None:
        # Only treat slider changes as scrubbing while paused; during playback
        # the slider value is driven by the animation loop itself.
        if not state["playing"]:
            state["sim_time"] = float(time_slider.value)
            update_all(state["sim_time"])

    def animation_loop() -> None:
        """Background thread for realtime animation playback."""
        last_time = time.time()
        while True:
            time.sleep(0.016)  # ~60 fps
            current_time = time.time()
            dt = current_time - last_time
            last_time = current_time

            if state["playing"]:
                # Advance simulation time (speed=1.0 is realtime)
                state["sim_time"] += dt * speed_slider.value

                if state["sim_time"] >= t_end:
                    if loop_checkbox.value:
                        state["sim_time"] = t_start
                    else:
                        # Hold the last frame and return to the paused state
                        state["sim_time"] = t_end
                        state["playing"] = False
                        play_button.name = "Play"

                time_slider.value = state["sim_time"]
                update_all(state["sim_time"])

    # Start animation thread (daemon: it must not keep the process alive)
    thread = threading.Thread(target=animation_loop, daemon=True)
    thread.start()