imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
| 
         @@ -0,0 +1,340 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from abc import ABC
         
     | 
| 
      
 2 
     | 
    
         
            +
            from abc import abstractmethod
         
     | 
| 
      
 3 
     | 
    
         
            +
            from abc import abstractstaticmethod
         
     | 
| 
      
 4 
     | 
    
         
            +
            from functools import partial
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import Optional, TypeVar
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 8 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 9 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 10 
     | 
    
         
            +
            from ring import algebra
         
     | 
| 
      
 11 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 12 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 13 
     | 
    
         
            +
            from tree_utils import PyTree
         
     | 
| 
      
 14 
     | 
    
         
            +
            from tree_utils import tree_batch
         
     | 
| 
      
 15 
     | 
    
         
            +
            from vispy import scene
         
     | 
| 
      
 16 
     | 
    
         
            +
            from vispy.scene import MatrixTransform
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            from . import vispy_visuals
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
            # Placeholder type variables for the backend-specific camera and visual objects.
            Camera, Visual = TypeVar("Camera"), TypeVar("Visual")
            # Pose representations produced by `_compute_transform_per_visual` (1) and by
            # `_postprocess_transforms` (2); both are arbitrary pytrees.
            VisualPosOri1 = VisualPosOri2 = PyTree
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
             
     | 
| 
      
 26 
     | 
    
         
            +
            class Scene(ABC):
                """Abstract base class for a scene of geometries that can be rendered.

                Subclasses provide the camera accessors, the actual rendering, the
                construction of backend-specific visuals per geometry type, and the
                conversion of link transforms into the backend's pose convention.

                Example:
                    >> renderer = Renderer()
                    >> renderer.init(sys.geoms)
                    >> for t in range(xs.shape()):
                    >>   renderer.update(xs[t])
                    >>   image = renderer.render()
                """

                # draw a coordinate system (XYZ visual) at every link that has a geom
                _xyz: bool = True
                # draw a coordinate system at the root (earth) frame
                _xyz_root: bool = True
                # if False, `base.XYZ` geoms ignore their link-to-geom transform and are
                # drawn with `base.Transform.zero()` instead
                _xyz_transform1: bool = True
                # class-level default only, so `_remove_all_visuals` works before the
                # first `init`; `init` rebinds an instance list before appending to it
                visuals: list[Visual] = []
                # root-frame XYZ visual created by `init`; kept out of `visuals` because
                # it is never moved by `update`
                _root_xyz_visual: Optional[Visual] = None

                @abstractmethod
                def _get_camera(self) -> Camera:
                    """Return the camera that is currently used for rendering."""
                    pass

                @abstractmethod
                def _set_camera(self, camera: Camera) -> None:
                    """Make `camera` the camera used by subsequent `_render` calls."""
                    pass

                @abstractmethod
                def _render(self) -> jax.Array:
                    """Render the scene with the current camera and return the image."""
                    pass

                def enable_xyz(self, enable_root: bool = True) -> None:
                    """Show coordinate systems at the links.

                    Args:
                        enable_root: If True, also show the root coordinate system. If
                            False, the root setting is left unchanged.
                    """
                    self._xyz = True
                    if enable_root:
                        self._xyz_root = True

                def disable_xyz(self, disable_root: bool = True) -> None:
                    """Hide coordinate systems at the links.

                    Args:
                        disable_root: If True, also hide the root coordinate system. If
                            False, the root setting is left unchanged.
                    """
                    self._xyz = False
                    if disable_root:
                        self._xyz_root = False

                def enable_xyz_transform1(self):
                    """Draw `base.XYZ` geoms with their own link-to-geom transform."""
                    self._xyz_transform1 = True

                def disable_xyz_transform1(self):
                    """Draw `base.XYZ` geoms at the link origin, ignoring their transform."""
                    self._xyz_transform1 = False

                def disable_xyz_tranform1(self):
                    """Deprecated misspelled alias of `disable_xyz_transform1`."""
                    # kept so existing callers of the original (misspelled) name still work
                    self.disable_xyz_transform1()

                def render(
                    self, camera: Optional[Camera | list[Camera]] = None
                ) -> jax.Array | list[jax.Array]:
                    """Render the scene.

                    Args:
                        camera: Camera to render from, or a list of cameras to render one
                            image per camera. Defaults to the current camera. After
                            rendering a list, the last camera in it remains set.

                    Returns: RGBA Array of Shape = (M, N, 4), or a list of such arrays
                        (one per camera) if a list of cameras was given.
                    """
                    if camera is None:
                        camera = self._get_camera()

                    if not isinstance(camera, list):
                        self._set_camera(camera)
                        return self._render()

                    images = []
                    for cam in camera:
                        self._set_camera(cam)
                        images.append(self._render())
                    return images

                def _add_box(self, box: base.Box) -> Visual:
                    """Create the backend visual for a box geometry."""
                    raise NotImplementedError

                def _add_sphere(self, sphere: base.Sphere) -> Visual:
                    """Create the backend visual for a sphere geometry."""
                    raise NotImplementedError

                def _add_cylinder(self, cyl: base.Cylinder) -> Visual:
                    """Create the backend visual for a cylinder geometry."""
                    raise NotImplementedError

                def _add_capsule(self, cap: base.Capsule) -> Visual:
                    """Create the backend visual for a capsule geometry."""
                    raise NotImplementedError

                def _add_xyz(self) -> Visual:
                    """Create the backend visual for a coordinate system (XYZ axes)."""
                    raise NotImplementedError

                @abstractmethod
                def _remove_visual(self, visual: Visual) -> None:
                    """Remove a visual previously returned by one of the `_add_*` methods."""
                    pass

                def _remove_all_visuals(self):
                    """Remove every visual created by the last `init`, incl. the root frame."""
                    for visual in self.visuals:
                        self._remove_visual(visual)
                    # the root-frame visual is not in `self.visuals`, so it must be removed
                    # separately; otherwise each re-`init` would stack another root axis
                    if self._root_xyz_visual is not None:
                        self._remove_visual(self._root_xyz_visual)
                        self._root_xyz_visual = None

                def init(self, geoms: list[base.Geometry]):
                    """(Re-)build all visuals for `geoms`, discarding previous ones.

                    Args:
                        geoms: Geometries to render; each provides `link_idx` and
                            `transform` (link-to-geom).

                    Raises:
                        Exception: If a geometry type has no corresponding `_add_*` method.
                    """
                    self._remove_all_visuals()

                    # `update` zips `self.visuals` with `self.geoms`, so both lists (and
                    # the per-visual link indices/transforms) are kept index-aligned
                    self.geoms = list(geoms)
                    self._fresh_init = True

                    geom_link_idx = []
                    geom_transform = []
                    self.visuals = []
                    for geom in geoms:
                        geom_link_idx.append(geom.link_idx)
                        transform = geom.transform
                        if isinstance(geom, base.Box):
                            visual = self._add_box(geom)
                        elif isinstance(geom, base.Sphere):
                            visual = self._add_sphere(geom)
                        elif isinstance(geom, base.Cylinder):
                            visual = self._add_cylinder(geom)
                        elif isinstance(geom, base.Capsule):
                            visual = self._add_capsule(geom)
                        elif isinstance(geom, base.XYZ):
                            visual = self._add_xyz()
                            if not self._xyz_transform1:
                                transform = base.Transform.zero()
                        else:
                            raise Exception(f"Unknown geom type: {type(geom)}")
                        geom_transform.append(transform)
                        self.visuals.append(visual)

                    if self._xyz:
                        # one coordinate system per link that carries at least one geom
                        unique_link_indices = np.unique(np.array(geom_link_idx))
                        for unique_link_idx in unique_link_indices:
                            geom_link_idx.append(unique_link_idx)
                            geom_transform.append(base.Transform.zero())
                            self.visuals.append(self._add_xyz())
                            # otherwise the .update function won't iterate
                            # over all visuals since it uses a zip(...)
                            self.geoms.append(None)

                    if self._xyz_root:
                        # add one final for root frame; it is static, so it is not part
                        # of `self.visuals`, but keep the handle so it can be removed
                        self._root_xyz_visual = self._add_xyz()

                    self.geom_link_idx = tree_batch(geom_link_idx, backend="jax")
                    self.geom_transform = tree_batch(geom_transform, backend="jax")

                @staticmethod
                @abstractmethod
                def _compute_transform_per_visual(
                    x_links: base.Transform,
                    x_link_to_geom: base.Transform,
                    geom_link_idx: int,
                ) -> VisualPosOri1:
                    """Compute the pose of one visual from the link transforms.

                    Called under `jax.vmap` over the visuals, so it must be traceable.
                    This can easily account for possible convention differences.
                    """
                    pass

                @staticmethod
                @abstractmethod
                def _postprocess_transforms(transform: VisualPosOri1) -> VisualPosOri2:
                    """Convert the batched per-visual poses into the backend's format."""
                    pass

                @abstractmethod
                def _init_visual(
                    self, visual: Visual, transform: VisualPosOri2, geom: None | base.Geometry
                ):
                    """Apply `transform` to `visual` for the first time after `init`."""
                    pass

                def _update_visual(
                    self, visual: Visual, transform: VisualPosOri2, geom: None | base.Geometry
                ):
                    """Apply `transform` to `visual`; defaults to `_init_visual`."""
                    self._init_visual(visual, transform, geom)

                def update(self, x: base.Transform):
                    """Move all visuals to the poses given by the link transforms.

                    `x` are (n_links,) Transforms.
                    """
                    # step 1: pre-compute all required transforms
                    transform_per_visual = _compile_staticmethod(
                        self._compute_transform_per_visual,
                        x,
                        self.geom_transform,
                        self.geom_link_idx,
                    )

                    # step 2: postprocess all transforms once
                    transform_per_visual = self._postprocess_transforms(transform_per_visual)

                    # step 3: update visuals
                    for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
                        # `jax.tree_map` is deprecated/removed in recent JAX releases
                        t = jax.tree_util.tree_map(lambda arr: arr[i], transform_per_visual)
                        if self._fresh_init:
                            self._init_visual(visual, t, geom)
                        else:
                            self._update_visual(visual, t, geom)

                    # step 4: unset flag, so later calls go through `_update_visual`
                    self._fresh_init = False
         
     | 
| 
      
 203 
     | 
    
         
            +
             
     | 
| 
      
 204 
     | 
    
         
            +
             
     | 
| 
      
 205 
     | 
    
         
            +
            @partial(jax.jit, static_argnums=0)
         
     | 
| 
      
 206 
     | 
    
         
            +
            def _compile_staticmethod(static_method, x, geom_transform, geom_link_idx):
         
     | 
| 
      
 207 
     | 
    
         
            +
                return jax.vmap(static_method, in_axes=(None, 0, 0))(
         
     | 
| 
      
 208 
     | 
    
         
            +
                    x, geom_transform, geom_link_idx
         
     | 
| 
      
 209 
     | 
    
         
            +
                )
         
     | 
| 
      
 210 
     | 
    
         
            +
             
     | 
| 
      
 211 
     | 
    
         
            +
             
     | 
| 
      
 212 
     | 
    
         
            +
            class VispyScene(Scene):
         
     | 
| 
      
 213 
     | 
    
         
            +
                def __init__(
                    self,
                    show_cs=False,
                    show_cs_root=True,
                    width: int = 320,
                    height: int = 240,
                    camera: Optional[scene.cameras.BaseCamera] = None,
                    **kwargs,
                ):
                    """Scene which can be rendered.

                    Args:
                        show_cs (bool, optional): Show coordinate system of links.
                            Defaults to False.
                        show_cs_root (bool, optional): Show coordinate system of earth frame.
                            Defaults to True.
                        width (int, optional): Canvas width in pixels. Defaults to 320.
                        height (int, optional): Canvas height in pixels. Defaults to 240.
                        camera (scene.cameras.BaseCamera, optional): The camera angle.
                            Defaults to a new scene.TurntableCamera(elevation=25,
                            distance=4.0, azimuth=25) created for this scene.
                        **kwargs: Forwarded to `scene.SceneCanvas`.

                    Example:
                        >> scene = VispyScene()
                        >> scene.init(sys.geoms)
                        >> scene.update(state.x)
                        >> image = scene.render()
                    """
                    if camera is None:
                        # build the default per instance: a camera constructed in the
                        # signature is created once at import time and shared by every
                        # VispyScene (mutable default argument)
                        camera = scene.TurntableCamera(elevation=25, distance=4.0, azimuth=25)

                    self.canvas = scene.SceneCanvas(
                        keys="interactive", size=(width, height), show=True, **kwargs
                    )
                    self.view = self.canvas.central_widget.add_view()
                    self._set_camera(camera)
                    if show_cs:
                        self.enable_xyz(enable_root=show_cs_root)
                    else:
                        self.disable_xyz(disable_root=not show_cs_root)
                    # `enable_xyz(enable_root=False)` leaves the root flag untouched, so
                    # pin it explicitly; previously show_cs=True always showed the root
                    self._xyz_root = show_cs_root
         
     | 
| 
      
 251 
     | 
    
         
            +
             
     | 
| 
      
 252 
     | 
    
         
            +
                def _set_camera(self, camera: scene.cameras.BaseCamera) -> None:
                    """Attach `camera` to this scene's view so it is used for rendering."""
                    view = self.view
                    view.camera = camera
         
     | 
| 
      
 254 
     | 
    
         
            +
             
     | 
| 
      
 255 
     | 
    
         
            +
                def _get_camera(self) -> scene.cameras.BaseCamera:
                    """Return the camera currently attached to the scene's view."""
                    current_camera = self.view.camera
                    return current_camera
         
     | 
| 
      
 257 
     | 
    
         
            +
             
     | 
| 
      
 258 
     | 
    
         
            +
                def _render(self) -> np.ndarray:
                    """Render the canvas to an image array, keeping the alpha channel.

                    Returns:
                        The array produced by vispy's ``SceneCanvas.render(alpha=True)``,
                        which is a NumPy array (an RGBA image), not a ``jax.Array``.
                    """
                    return self.canvas.render(alpha=True)
         
     | 
| 
      
 260 
     | 
    
         
            +
             
     | 
| 
      
 261 
     | 
    
         
            +
                def _add_box(self, box: base.Box) -> Visual:
                    """Create a box visual for ``box`` and attach it to the scene.

                    ``vispy_visuals.Box`` takes its extents in ``(dim_x, dim_y, dim_z)``
                    order and places them directly on the x, y and z axes (via
                    ``box_mesh``). The dimensions are therefore forwarded in that same
                    order, so the rendered box matches the geometry's extents.
                    """
                    return vispy_visuals.Box(
                        box.dim_x,
                        box.dim_y,
                        box.dim_z,
                        color=box.color,
                        edge_color=box.edge_color,
                        parent=self.view.scene,
                    )
         
     | 
| 
      
 270 
     | 
    
         
            +
             
     | 
| 
      
 271 
     | 
    
         
            +
                def _add_sphere(self, sphere: base.Sphere) -> Visual:
                    """Create a sphere visual for ``sphere`` and attach it to the scene."""
                    appearance = dict(
                        color=sphere.color,
                        edge_color=sphere.edge_color,
                        parent=self.view.scene,
                    )
                    return vispy_visuals.Sphere(sphere.radius, **appearance)
         
     | 
| 
      
 278 
     | 
    
         
            +
             
     | 
| 
      
 279 
     | 
    
         
            +
                def _add_cylinder(self, cyl: base.Cylinder) -> Visual:
                    """Create a cylinder visual for ``cyl`` and attach it to the scene."""
                    appearance = dict(
                        color=cyl.color,
                        edge_color=cyl.edge_color,
                        parent=self.view.scene,
                    )
                    return vispy_visuals.Cylinder(cyl.radius, cyl.length, **appearance)
         
     | 
| 
      
 287 
     | 
    
         
            +
             
     | 
| 
      
 288 
     | 
    
         
            +
                def _add_capsule(self, cap: base.Capsule) -> Visual:
                    """Create a capsule visual for ``cap`` and attach it to the scene."""
                    appearance = dict(
                        color=cap.color,
                        edge_color=cap.edge_color,
                        parent=self.view.scene,
                    )
                    return vispy_visuals.Capsule(cap.radius, cap.length, **appearance)
         
     | 
| 
      
 296 
     | 
    
         
            +
             
     | 
| 
      
 297 
     | 
    
         
            +
                def _add_xyz(self) -> Visual:
                    """Add an XYZ axis indicator to the scene and return it."""
                    axes = scene.visuals.XYZAxis(parent=self.view.scene)
                    return axes
         
     | 
| 
      
 299 
     | 
    
         
            +
             
     | 
| 
      
 300 
     | 
    
         
            +
                def _remove_visual(self, visual: scene.visuals.VisualNode) -> None:
                    """Detach ``visual`` from the scene graph by clearing its parent."""
                    node = visual
                    node.parent = None
         
     | 
| 
      
 302 
     | 
    
         
            +
             
     | 
| 
      
 303 
     | 
    
         
            +
                @staticmethod
                def _compute_transform_per_visual(
                    x_links: base.Transform,
                    x_link_to_geom: base.Transform,
                    geom_link_idx: int,
                ) -> jax.Array:
                    """Build the 4x4 homogeneous matrix that places one geometry.

                    Args:
                        x_links: Transforms of all links, indexed by ``geom_link_idx``.
                        x_link_to_geom: Transform from the geometry's link to the geometry.
                        geom_link_idx: Index of the geometry's link. ``-1`` selects the
                            zero transform instead of an entry of ``x_links``.

                    Returns:
                        A 4x4 matrix holding the rotation matrix of the composed transform
                        in its upper-left 3x3 block and the composed translation in its
                        last row (row-vector layout), with ``[0, 0, 0, 1]`` as last column.
                    """
                    x = jax.lax.cond(
                        geom_link_idx == -1,
                        lambda: base.Transform.zero(),
                        lambda: x_links[geom_link_idx],
                    )
                    x = algebra.transform_mul(x_link_to_geom, x)
                    E = maths.quat_to_3x3(x.rot)
                    # The matrix equals `diag(E, 1) @ T`, where T is the identity with x.pos
                    # in its last row. It is written out directly instead: a float32 matmul
                    # may be evaluated at reduced precision on accelerators under JAX's
                    # default matmul precision, which would perturb rotation and translation.
                    M = jnp.eye(4)
                    M = M.at[:3, :3].set(E)
                    M = M.at[3, :3].set(x.pos)
                    return M
         
     | 
| 
      
 321 
     | 
    
         
            +
             
     | 
| 
      
 322 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 323 
     | 
    
         
            +
                def _postprocess_transforms(transform: jax.Array) -> np.ndarray:
         
     | 
| 
      
 324 
     | 
    
         
            +
                    return np.asarray(transform)
         
     | 
| 
      
 325 
     | 
    
         
            +
             
     | 
| 
      
 326 
     | 
    
         
            +
                def _init_visual(
                    self,
                    visual: scene.visuals.VisualNode,
                    transform: np.ndarray,
                    geom: base.Geometry,
                ):
                    """Give ``visual`` a new ``MatrixTransform`` built from ``transform``.

                    ``geom`` is accepted for interface compatibility but not used here.
                    """
                    matrix_transform = MatrixTransform(transform)
                    visual.transform = matrix_transform
         
     | 
| 
      
 333 
     | 
    
         
            +
             
     | 
| 
      
 334 
     | 
    
         
            +
                def _update_visual(
                    self,
                    visual: scene.visuals.VisualNode,
                    transform: np.ndarray,
                    geom: base.Geometry,
                ):
                    """Overwrite the matrix of ``visual``'s existing transform in place.

                    Presumably the transform is the ``MatrixTransform`` installed by
                    ``_init_visual``; ``geom`` is not used here.
                    """
                    current = visual.transform
                    current.matrix = transform
         
     | 
| 
         @@ -0,0 +1,290 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 2 
     | 
    
         
            +
            from ring.base import Color
         
     | 
| 
      
 3 
     | 
    
         
            +
            from vispy.geometry.meshdata import MeshData
         
     | 
| 
      
 4 
     | 
    
         
            +
            from vispy.scene.visuals import create_visual_node
         
     | 
| 
      
 5 
     | 
    
         
            +
            from vispy.scene.visuals import Mesh
         
     | 
| 
      
 6 
     | 
    
         
            +
            from vispy.visuals import CompoundVisual
         
     | 
| 
      
 7 
     | 
    
         
            +
            from vispy.visuals import SphereVisual as _SphereVisual
         
     | 
| 
      
 8 
     | 
    
         
            +
            from vispy.visuals import TubeVisual
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            # Vertex density per unit length, used to scale the tessellation of the
            # visuals in this module with their size.
            # NOTE: "vectices" is a misspelling of "vertices"; the name is kept because
            # the visuals in this module reference it.
            _vectices_per_unit_length = 10

            # Fallback face color (4-tuple, presumably RGBA) for visuals created without
            # a color.
            _default_color = (1, 0.8, 0.7, 1)
            # Fallback edge color, used by BoxVisual when no edge color is given.
            _default_edge_color = "black"
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            class DoubleMeshVisual(CompoundVisual):
                """Compound visual that draws a mesh's filled faces and its edges.

                Faces and edges are two separate meshes built from the same ``verts``. A
                part whose color is ``None`` is replaced by an empty ``Mesh``. If both
                colors are ``None``, the faces fall back to the module's default color.
                """

                _faces: Mesh  # filled triangles built from ``faces``
                _edges: Mesh  # line segments built from ``edges``

                def __init__(
                    self, verts, edges, faces, *, color: Color = None, edge_color: Color = None
                ):
                    if color is None and edge_color is None:
                        color = _default_color

                    if color is not None:
                        self._faces = Mesh(verts, faces, color=color, shading=None)
                        # NOTE(review): `light_dir` is not read anywhere in this module, and
                        # the faces use `shading=None`; confirm this attribute is needed.
                        self.light_dir = np.array([0, -1, 0])
                    else:
                        self._faces = Mesh()

                    if edge_color is not None:
                        self._edges = Mesh(verts, edges, color=edge_color, mode="lines")
                    else:
                        self._edges = Mesh()

                    super().__init__([self._faces, self._edges])
                    # Push the filled faces back in depth, presumably so the edge lines
                    # stay visible on top of them instead of z-fighting.
                    self._faces.set_gl_state(
                        polygon_offset_fill=True, polygon_offset=(1, 1), depth_test=True
                    )
         
     | 
| 
      
 42 
     | 
    
         
            +
             
     | 
| 
      
 43 
     | 
    
         
            +
             
     | 
| 
      
 44 
     | 
    
         
            +
            class SphereVisual(_SphereVisual):
                """Latitude-tessellated sphere whose resolution grows with its radius."""

                def __init__(self, radius: float, color: Color = None, edge_color: Color = None):
                    # With neither a face nor an edge color, use the default face color.
                    if color is None and edge_color is None:
                        color = _default_color

                    r = float(radius)

                    # Rows cover half a great circle and columns the full circumference.
                    # Both are clamped to a minimum resolution.
                    rows = max(int(np.pi * r * _vectices_per_unit_length), 10)
                    cols = max(int(2 * np.pi * r * _vectices_per_unit_length), 20)

                    super().__init__(
                        r,
                        rows=rows,
                        cols=cols,
                        method="latitude",
                        color=color,
                        edge_color=edge_color,
                        shading="smooth",
                    )


            # Scene-graph node class wrapping SphereVisual (accepts `parent=`).
            Sphere = create_visual_node(SphereVisual)
         
     | 
| 
      
 66 
     | 
    
         
            +
             
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
      
 68 
     | 
    
         
            +
            def box_mesh(
                dim_x: float, dim_y: float, dim_z: float
            ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
                """Build the vertices, edges and triangles of an axis-aligned box.

                The box is centred at the origin, with side lengths ``dim_x``, ``dim_y``
                and ``dim_z`` along the x, y and z axes. Vertex ``i`` lies at the positive
                end of the x, y and z axis when bit 0, 1 and 2 of ``i`` is set.

                Returns:
                    verts: ``(8, 3)`` float32 corner positions.
                    edges: ``(12, 2)`` uint32 vertex-index pairs, one per box edge.
                    faces: ``(12, 3)`` uint32 vertex-index triples, two per box side.
                        Each is wound counter-clockwise when seen from outside the box,
                        so its right-hand normal points outwards.
                """
                verts = np.array(
                    [
                        (-dim_x, -dim_y, -dim_z),
                        (dim_x, -dim_y, -dim_z),
                        (-dim_x, dim_y, -dim_z),
                        (dim_x, dim_y, -dim_z),
                        (-dim_x, -dim_y, dim_z),
                        (dim_x, -dim_y, dim_z),
                        (-dim_x, dim_y, dim_z),
                        (dim_x, dim_y, dim_z),
                    ],
                    dtype=np.float32,
                )

                # The corners above sit at +-dim; halve them so the sides measure dim.
                verts /= 2

                edges = np.array(
                    [
                        (0, 1),
                        (0, 2),
                        (0, 4),
                        (1, 3),
                        (1, 5),
                        (2, 3),
                        (2, 6),
                        (3, 7),
                        (4, 5),
                        (4, 6),
                        (5, 7),
                        (6, 7),
                    ],
                    dtype=np.uint32,
                )

                # Consistent outward winding keeps per-face normals (used for shading or
                # back-face culling) pointing away from the box on every side.
                faces = np.array(
                    [
                        (0, 2, 1),  # -z side
                        (1, 2, 3),
                        (0, 1, 4),  # -y side
                        (1, 5, 4),
                        (0, 4, 2),  # -x side
                        (2, 4, 6),
                        (1, 3, 5),  # +x side
                        (3, 7, 5),
                        (2, 6, 3),  # +y side
                        (3, 6, 7),
                        (4, 5, 6),  # +z side
                        (5, 7, 6),
                    ],
                    dtype=np.uint32,
                )

                return verts, edges, faces
         
     | 
| 
      
 123 
     | 
    
         
            +
             
     | 
| 
      
 124 
     | 
    
         
            +
             
     | 
| 
      
 125 
     | 
    
         
            +
            class BoxVisual(DoubleMeshVisual):
                """Axis-aligned box centred at the origin, with filled faces and edges.

                A custom class is needed because ``vispy.scene.visuals.Box`` does not
                support shading. Unlike the other visuals here, face and edge colors
                default independently: a missing ``color`` uses the default face color,
                and a missing ``edge_color`` uses the default edge color.
                """

                def __init__(
                    self,
                    dim_x: float,
                    dim_y: float,
                    dim_z: float,
                    *,
                    color: Color = None,
                    edge_color: Color = None
                ):
                    color = _default_color if color is None else color
                    edge_color = _default_edge_color if edge_color is None else edge_color

                    self.dim_x, self.dim_y, self.dim_z = (
                        float(extent) for extent in (dim_x, dim_y, dim_z)
                    )

                    mesh_verts, mesh_edges, mesh_faces = box_mesh(
                        self.dim_x, self.dim_y, self.dim_z
                    )
                    super().__init__(
                        mesh_verts,
                        mesh_edges,
                        mesh_faces,
                        color=color,
                        edge_color=edge_color,
                    )


            # Scene-graph node class wrapping BoxVisual (accepts `parent=`).
            Box = create_visual_node(BoxVisual)
         
     | 
| 
      
 158 
     | 
    
         
            +
             
     | 
| 
      
 159 
     | 
    
         
            +
             
     | 
| 
      
 160 
     | 
    
         
            +
            class CylinderVisual(TubeVisual):
                """Cylinder centred at the origin with its axis along x, drawn as a tube."""

                def __init__(
                    self,
                    radius: float,
                    length: float,
                    *,
                    color: Color = None,
                    edge_color: Color = None
                ):
                    # `edge_color` only decides whether the default face color applies;
                    # it is not forwarded to the tube.
                    if color is None and edge_color is None:
                        color = _default_color

                    radius, length = float(radius), float(length)

                    n_axial = 10 * max(int(length * _vectices_per_unit_length), 10)
                    n_around = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

                    # Centre line of the tube: evenly spaced points on the x axis.
                    half_length = length / 2
                    path = np.zeros((n_axial, 3))
                    path[:, 0] = np.linspace(-half_length, half_length, n_axial)

                    self.radius = radius
                    self.length = length

                    # NOTE(review): vispy's TubeVisual documents `closed=True` as joining
                    # the last point back to the first (not capping the ends); confirm this
                    # is intended for a straight cylinder.
                    super().__init__(
                        path,
                        radius,
                        tube_points=n_around,
                        closed=True,
                        color=color,
                        shading="smooth",
                    )


            # Scene-graph node class wrapping CylinderVisual (accepts `parent=`).
            Cylinder = create_visual_node(CylinderVisual)
         
     | 
| 
      
 195 
     | 
    
         
            +
             
     | 
| 
      
 196 
     | 
    
         
            +
             
     | 
| 
      
 197 
     | 
    
         
            +
            def capsule_mesh(
                radius: float, length: float, offset: bool = True
            ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
                """Build vertex, edge and face arrays for a capsule along the x axis.

                The capsule is centred at the origin: a straight cylindrical section of
                length ``length - 2 * radius`` closed by two hemispherical caps of radius
                ``radius``, so its total extent along x is ``length``. Vertices are laid
                out as rings ordered from the +x pole to the -x pole.

                Args:
                    radius: Radius of the cylindrical section and of both caps.
                    length: Total length including both caps; must be >= 2 * radius.
                    offset: If True, rotate each ring by half a column relative to the
                        previous one (staggered triangulation).

                Returns:
                    ``(vertices, edges, faces)`` as produced by ``MeshData``'s
                    ``get_vertices()``, ``get_edges()`` and ``get_faces()``.

                Raises:
                    ValueError: If ``length < 2 * radius``.
                """
                if length < 2 * radius:
                    raise ValueError("length must be at least 2 * radius")

                # number of cap vertex rows in x direction (per cap)
                num_sphere_rows = max(int(radius * _vectices_per_unit_length), 10)

                # length without caps
                cyl_length = length - 2 * radius
                # number of cylinder vertex rows in x direction
                num_cyl_rows = max(int(cyl_length * _vectices_per_unit_length), 10)

                num_total_rows = 2 * num_sphere_rows + num_cyl_rows

                # number of vertices around each ring
                num_cols = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

                verts = np.empty((num_total_rows, num_cols, 3), dtype=np.float32)

                # polar angle of the cap rows: 0 at the +x pole, pi/2 where the caps
                # meet the cylinder, pi at the -x pole
                theta_top = np.linspace(0.0, np.pi / 2, num_sphere_rows)
                theta_bottom = np.linspace(np.pi / 2, np.pi, num_sphere_rows)

                # x coordinate, rows ordered from +length/2 down to -length/2
                verts[:num_sphere_rows, :, 0] = (
                    radius * np.cos(theta_top[:, None]) + cyl_length / 2
                )
                verts[num_sphere_rows:-num_sphere_rows, :, 0] = np.linspace(
                    -cyl_length / 2, cyl_length / 2, num_cyl_rows
                )[::-1, None]
                verts[-num_sphere_rows:, :, 0] = (
                    radius * np.cos(theta_bottom[:, None]) - cyl_length / 2
                )

                # azimuth angle. endpoint=False: the faces below wrap column
                # ``num_cols - 1`` back to column 0, so also emitting 2*pi would
                # duplicate the seam column (zero-area triangles) and make the column
                # spacing 2*pi/(num_cols - 1) instead of 2*pi/num_cols.
                phi = np.linspace(0.0, 2 * np.pi, num_cols, endpoint=False)[None, :]

                if offset:
                    # rotate each row by 1/2 column (column spacing is 2*pi/num_cols)
                    phi = phi + (np.pi / num_cols) * np.arange(num_total_rows)[:, None]

                # y and z coordinates on a circle of the full radius
                verts[..., 1] = radius * np.cos(phi)
                verts[..., 2] = radius * np.sin(phi)

                # for caps: shrink the rings towards the poles to close the ends
                verts[:num_sphere_rows, :, 1:3] *= np.sin(theta_top[:, None, None])
                verts[-num_sphere_rows:, :, 1:3] *= np.sin(theta_bottom[:, None, None])

                verts = verts.reshape(-1, 3)

                # compute faces: two triangles per quad between ring ``row`` and
                # ring ``row + 1``; ``% num_cols`` closes each ring around the seam
                faces = np.empty(((num_total_rows - 1) * num_cols * 2, 3), dtype=np.uint32)

                rowtemplate1 = (
                    (np.arange(num_cols).reshape(num_cols, 1) + np.array([[0, 1, 0]]))
                    % num_cols
                ) + np.array([[num_cols, 0, 0]])

                rowtemplate2 = (
                    (np.arange(num_cols).reshape(num_cols, 1) + np.array([[1, 1, 0]]))
                    % num_cols
                ) + np.array([[num_cols, 0, num_cols]])

                for row in range(num_total_rows - 1):
                    start = row * num_cols * 2
                    faces[start : start + num_cols] = rowtemplate1 + row * num_cols
                    faces[start + num_cols : start + (num_cols * 2)] = (
                        rowtemplate2 + row * num_cols
                    )

                mesh = MeshData(vertices=verts, faces=faces)

                return mesh.get_vertices(), mesh.get_edges(), mesh.get_faces()
         
     | 
| 
      
 268 
     | 
    
         
            +
             
     | 
| 
      
 269 
     | 
    
         
            +
             
     | 
| 
      
 270 
     | 
    
         
            +
            class CapsuleVisual(DoubleMeshVisual):
                """Capsule visual: the mesh from ``capsule_mesh`` drawn by ``DoubleMeshVisual``.

                The requested ``radius`` and ``length`` are kept on the instance as floats.
                """

                def __init__(
                    self,
                    radius: float,
                    length: float,
                    *,
                    color: Color = None,
                    edge_color: Color = None
                ):
                    # normalize both dimensions to plain floats (radius first, then length)
                    radius_f, length_f = float(radius), float(length)
                    self.radius, self.length = radius_f, length_f

                    # capsule_mesh raises ValueError when length < 2 * radius
                    vertices, edge_index, face_index = capsule_mesh(radius_f, length_f)
                    super().__init__(
                        vertices,
                        edge_index,
                        face_index,
                        color=color,
                        edge_color=edge_color,
                    )
         
     | 
| 
      
 288 
     | 
    
         
            +
             
     | 
| 
      
 289 
     | 
    
         
            +
             
     | 
| 
      
 290 
     | 
    
         
            +
            # Wrap CapsuleVisual with create_visual_node (presumably the vispy
            # scene-node class factory) so it can be used as a scene node.
            Capsule = create_visual_node(CapsuleVisual)
         
     |