imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,340 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from abc import abstractmethod
|
3
|
+
from abc import abstractstaticmethod
|
4
|
+
from functools import partial
|
5
|
+
from typing import Optional, TypeVar
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import numpy as np
|
10
|
+
from ring import algebra
|
11
|
+
from ring import base
|
12
|
+
from ring import maths
|
13
|
+
from tree_utils import PyTree
|
14
|
+
from tree_utils import tree_batch
|
15
|
+
from vispy import scene
|
16
|
+
from vispy.scene import MatrixTransform
|
17
|
+
|
18
|
+
from . import vispy_visuals
|
19
|
+
|
20
|
+
# Backend-specific camera object (for VispyScene: a vispy camera instance).
Camera = TypeVar("Camera")
# Backend-specific handle of a single drawn geometry.
Visual = TypeVar("Visual")
# Per-visual transform before (1) and after (2) `_postprocess_transforms`.
VisualPosOri1 = VisualPosOri2 = PyTree
|
24
|
+
|
25
|
+
|
26
|
+
class Scene(ABC):
    """Backend-agnostic scene of geometry visuals that can be rendered.

    Subclasses provide the camera accessors, ``_render``, the ``_add_*``
    factories for the geometry types they support, ``_remove_visual`` and the
    transform hooks used by ``update``.

    Example:
        >> scene = SomeSceneSubclass()
        >> scene.init(sys.geoms)
        >> for t in range(xs.shape[0]):
        >>     scene.update(xs[t])
        >>     image = scene.render()
    """

    # Add one XYZ-axis visual per link that carries geometries.
    _xyz: bool = True
    # Add one XYZ-axis visual at the root frame (never moved by `update`).
    _xyz_root: bool = True
    # If False, `base.XYZ` geoms ignore their own geom transform.
    _xyz_transform1: bool = True
    # Replaced by a fresh per-instance list in `init`; never mutated in place.
    visuals: list[Visual] = []
    # Root-frame axis visual created by `init` (tracked so it can be removed).
    _root_xyz_visual: Optional[Visual] = None

    @abstractmethod
    def _get_camera(self) -> Camera:
        """Return the currently active camera."""

    @abstractmethod
    def _set_camera(self, camera: Camera) -> None:
        """Make `camera` the active camera for subsequent renders."""

    @abstractmethod
    def _render(self) -> jax.Array:
        """Render the scene with the active camera."""

    def enable_xyz(self, enable_root: bool = True) -> None:
        self._xyz = True
        if enable_root:
            self._xyz_root = True

    def disable_xyz(self, disable_root: bool = True) -> None:
        self._xyz = False
        if disable_root:
            self._xyz_root = False

    def enable_xyz_transform1(self):
        self._xyz_transform1 = True

    def disable_xyz_transform1(self):
        self._xyz_transform1 = False

    # Backward-compatible alias for the original, misspelled method name.
    disable_xyz_tranform1 = disable_xyz_transform1

    def render(
        self, camera: Optional[Camera | list[Camera]] = None
    ) -> jax.Array | list[jax.Array]:
        """Render the scene.

        Args:
            camera: A camera, a list of cameras (one image each), or None to
                use the currently active camera.

        Returns: RGBA Array of Shape = (M, N, 4), or a list of those when a
            list of cameras is given. With a list, the last camera stays active.
        """
        if camera is None:
            camera = self._get_camera()

        if not isinstance(camera, list):
            self._set_camera(camera)
            return self._render()

        images = []
        for cam in camera:
            self._set_camera(cam)
            images.append(self._render())
        return images

    def _add_box(self, box: base.Box) -> Visual:
        raise NotImplementedError

    def _add_sphere(self, sphere: base.Sphere) -> Visual:
        raise NotImplementedError

    def _add_cylinder(self, cyl: base.Cylinder) -> Visual:
        raise NotImplementedError

    def _add_capsule(self, cap: base.Capsule) -> Visual:
        raise NotImplementedError

    def _add_xyz(self) -> Visual:
        raise NotImplementedError

    @abstractmethod
    def _remove_visual(self, visual: Visual) -> None:
        """Detach `visual` from the backend scene."""

    def _remove_all_visuals(self):
        """Remove every visual created by a previous `init`, incl. the root axis."""
        for visual in self.visuals:
            self._remove_visual(visual)
        self.visuals = []
        if self._root_xyz_visual is not None:
            self._remove_visual(self._root_xyz_visual)
            self._root_xyz_visual = None

    def init(self, geoms: list[base.Geometry]):
        """(Re-)create one visual per geometry, replacing previous visuals.

        Args:
            geoms: Geometries to draw; each provides `link_idx` and `transform`.

        Raises:
            Exception: If a geometry type is not supported.
        """
        self._remove_all_visuals()

        # Materialize once so iterators/generators are not consumed twice.
        self.geoms = list(geoms)
        self._fresh_init = True

        geom_link_idx = []
        geom_transform = []
        self.visuals = []
        for geom in list(self.geoms):
            geom_link_idx.append(geom.link_idx)
            geom_transform.append(geom.transform)
            if isinstance(geom, base.Box):
                visual = self._add_box(geom)
            elif isinstance(geom, base.Sphere):
                visual = self._add_sphere(geom)
            elif isinstance(geom, base.Cylinder):
                visual = self._add_cylinder(geom)
            elif isinstance(geom, base.Capsule):
                visual = self._add_capsule(geom)
            elif isinstance(geom, base.XYZ):
                visual = self._add_xyz()
                if not self._xyz_transform1:
                    # Ignore the geom's own transform for this axis visual.
                    geom_transform[-1] = base.Transform.zero()
            else:
                raise Exception(f"Unknown geom type: {type(geom)}")
            self.visuals.append(visual)

        if self._xyz:
            for unique_link_idx in np.unique(np.array(geom_link_idx)):
                geom_link_idx.append(unique_link_idx)
                geom_transform.append(base.Transform.zero())
                self.visuals.append(self._add_xyz())
                # otherwise the .update function won't iterate
                # over all visuals since it uses a zip(...)
                self.geoms.append(None)

        if self._xyz_root:
            # Root-frame axis: kept out of `self.visuals` so `update` never
            # moves it, but tracked so a later `init` removes it.
            self._root_xyz_visual = self._add_xyz()

        self.geom_link_idx = tree_batch(geom_link_idx, backend="jax")
        self.geom_transform = tree_batch(geom_transform, backend="jax")

    @staticmethod
    @abstractmethod
    def _compute_transform_per_visual(
        x_links: base.Transform,
        x_link_to_geom: base.Transform,
        geom_link_idx: int,
    ) -> VisualPosOri1:
        "This can easily account for possible convention differences"

    @staticmethod
    @abstractmethod
    def _postprocess_transforms(transform: VisualPosOri1) -> VisualPosOri2:
        """Convert the batched transforms into the backend's format."""

    @abstractmethod
    def _init_visual(
        self, visual: Visual, transform: VisualPosOri2, geom: None | base.Geometry
    ):
        """Apply `transform` to `visual` for the first time after `init`."""

    def _update_visual(
        self, visual: Visual, transform: VisualPosOri2, geom: None | base.Geometry
    ):
        """Apply `transform` to an already initialized `visual`."""
        self._init_visual(visual, transform, geom)

    def update(self, x: base.Transform):
        """Move all visuals according to the link transforms `x`.

        Args:
            x: (n_links,) Transforms.
        """
        # step 1: pre-compute all required transforms (jit + vmap over visuals)
        transform_per_visual = _compile_staticmethod(
            self._compute_transform_per_visual,
            x,
            self.geom_transform,
            self.geom_link_idx,
        )

        # step 2: postprocess all transforms once
        transform_per_visual = self._postprocess_transforms(transform_per_visual)

        # step 3: update visuals
        for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
            t = jax.tree_util.tree_map(lambda arr: arr[i], transform_per_visual)
            if self._fresh_init:
                self._init_visual(visual, t, geom)
            else:
                self._update_visual(visual, t, geom)

        # step 4: unset flag
        self._fresh_init = False
|
203
|
+
|
204
|
+
|
205
|
+
@partial(jax.jit, static_argnums=0)
def _compile_staticmethod(static_method, x, geom_transform, geom_link_idx):
    """Jit-compiled vmap of `static_method` over visuals.

    `x` is broadcast to every call; `geom_transform` and `geom_link_idx` are
    mapped over their leading axis. `static_method` must be hashable since it
    is a static jit argument.
    """
    batched_fn = jax.vmap(static_method, in_axes=(None, 0, 0))
    result = batched_fn(x, geom_transform, geom_link_idx)
    return result
|
210
|
+
|
211
|
+
|
212
|
+
class VispyScene(Scene):
    """`Scene` implementation backed by a vispy `SceneCanvas`."""

    def __init__(
        self,
        show_cs=False,
        show_cs_root=True,
        width: int = 320,
        height: int = 240,
        camera: Optional[scene.cameras.BaseCamera] = None,
        **kwargs,
    ):
        """Scene which can be rendered.

        Args:
            show_cs (bool, optional): Show coordinate system of links.
                Defaults to False.
            show_cs_root (bool, optional): Show coordinate system of earth frame.
                Defaults to True.
            width (int, optional): Canvas width in pixels. Defaults to 320.
            height (int, optional): Canvas height in pixels. Defaults to 240.
            camera (scene.cameras.BaseCamera, optional): The camera. Defaults to
                a new scene.TurntableCamera(elevation=25, distance=4.0,
                azimuth=25) created per scene.
            **kwargs: Forwarded to `scene.SceneCanvas`; `keys` defaults to
                "interactive" and `show` defaults to True.

        Example:
            >> scene = VispyScene()
            >> scene.init(sys.geoms)
            >> scene.update(state.x)
            >> image = scene.render()
        """
        if camera is None:
            # Create the default camera per instance; a camera built once as a
            # default argument value would be shared by every VispyScene.
            camera = scene.TurntableCamera(elevation=25, distance=4.0, azimuth=25)

        kwargs.setdefault("keys", "interactive")
        kwargs.setdefault("show", True)
        self.canvas = scene.SceneCanvas(size=(width, height), **kwargs)
        self.view = self.canvas.central_widget.add_view()
        self._set_camera(camera)

        if show_cs:
            self.enable_xyz()
        else:
            self.disable_xyz(disable_root=False)
        # Honour `show_cs_root=False` even when `show_cs` is True.
        if not show_cs_root:
            self._xyz_root = False

    def _set_camera(self, camera: scene.cameras.BaseCamera) -> None:
        self.view.camera = camera

    def _get_camera(self) -> scene.cameras.BaseCamera:
        return self.view.camera

    def _render(self) -> np.ndarray:
        return self.canvas.render(alpha=True)

    def _add_box(self, box: base.Box) -> Visual:
        # BoxVisual(dim_x, dim_y, dim_z) builds its mesh directly in the geom
        # frame (the full geom transform is applied afterwards), so the
        # dimensions are passed in x, y, z order.
        return vispy_visuals.Box(
            box.dim_x,
            box.dim_y,
            box.dim_z,
            color=box.color,
            edge_color=box.edge_color,
            parent=self.view.scene,
        )

    def _add_sphere(self, sphere: base.Sphere) -> Visual:
        return vispy_visuals.Sphere(
            sphere.radius,
            color=sphere.color,
            edge_color=sphere.edge_color,
            parent=self.view.scene,
        )

    def _add_cylinder(self, cyl: base.Cylinder) -> Visual:
        return vispy_visuals.Cylinder(
            cyl.radius,
            cyl.length,
            color=cyl.color,
            edge_color=cyl.edge_color,
            parent=self.view.scene,
        )

    def _add_capsule(self, cap: base.Capsule) -> Visual:
        return vispy_visuals.Capsule(
            cap.radius,
            cap.length,
            color=cap.color,
            edge_color=cap.edge_color,
            parent=self.view.scene,
        )

    def _add_xyz(self) -> Visual:
        return scene.visuals.XYZAxis(parent=self.view.scene)

    def _remove_visual(self, visual: scene.visuals.VisualNode) -> None:
        # Detaching from the scene graph stops the node from being drawn.
        visual.parent = None

    @staticmethod
    def _compute_transform_per_visual(
        x_links: base.Transform,
        x_link_to_geom: base.Transform,
        geom_link_idx: int,
    ) -> jax.Array:
        """Return the 4x4 vispy matrix placing one visual in the world.

        A `geom_link_idx` of -1 attaches the visual to the root frame.
        """
        x = jax.lax.cond(
            geom_link_idx == -1,
            lambda: base.Transform.zero(),
            lambda: x_links[geom_link_idx],
        )
        x = algebra.transform_mul(x_link_to_geom, x)
        E = maths.quat_to_3x3(x.rot)
        M = jnp.eye(4)
        M = M.at[:3, :3].set(E)
        # Translation goes into the last row (row-vector convention, p @ M).
        T = jnp.eye(4)
        T = T.at[3, :3].set(x.pos)
        return M @ T

    @staticmethod
    def _postprocess_transforms(transform: jax.Array) -> np.ndarray:
        # vispy works on numpy arrays; convert the whole batch once.
        return np.asarray(transform)

    def _init_visual(
        self,
        visual: scene.visuals.VisualNode,
        transform: np.ndarray,
        geom: base.Geometry,
    ):
        visual.transform = MatrixTransform(transform)

    def _update_visual(
        self,
        visual: scene.visuals.VisualNode,
        transform: np.ndarray,
        geom: base.Geometry,
    ):
        # Reuse the MatrixTransform created in `_init_visual`.
        visual.transform.matrix = transform
|
@@ -0,0 +1,290 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from ring.base import Color
|
3
|
+
from vispy.geometry.meshdata import MeshData
|
4
|
+
from vispy.scene.visuals import create_visual_node
|
5
|
+
from vispy.scene.visuals import Mesh
|
6
|
+
from vispy.visuals import CompoundVisual
|
7
|
+
from vispy.visuals import SphereVisual as _SphereVisual
|
8
|
+
from vispy.visuals import TubeVisual
|
9
|
+
|
10
|
+
# Mesh resolution: number of vertices per unit of length / circumference.
_vectices_per_unit_length = 10

# Fallback face (RGBA) and edge colors for visuals created without colors.
_default_color, _default_edge_color = (1, 0.8, 0.7, 1), "black"
|
15
|
+
|
16
|
+
|
17
|
+
class DoubleMeshVisual(CompoundVisual):
    """Compound visual drawing a filled triangle mesh plus its edges as lines.

    Args:
        verts: Vertex positions shared by faces and edges.
        edges: Vertex-index pairs drawn as lines.
        faces: Vertex-index triples drawn as filled triangles.
        color: Face color; ``None`` draws no faces (an empty ``Mesh``).
        edge_color: Edge color; ``None`` draws no edges (an empty ``Mesh``).

    If both colors are ``None``, the faces use ``_default_color``.
    """

    _faces: Mesh
    _edges: Mesh

    def __init__(
        self, verts, edges, faces, *, color: Color = None, edge_color: Color = None
    ):
        if color is None and edge_color is None:
            color = _default_color

        if color is not None:
            self._faces = Mesh(verts, faces, color=color, shading=None)
            # NOTE(review): not read anywhere in this module (shading is None);
            # kept in case external code relies on the attribute.
            self.light_dir = np.array([0, -1, 0])
        else:
            self._faces = Mesh()

        if edge_color is not None:
            self._edges = Mesh(verts, edges, color=edge_color, mode="lines")
        else:
            self._edges = Mesh()

        super().__init__([self._faces, self._edges])
        # Offset filled polygons in depth so the edge lines drawn on the same
        # surface are not hidden by the faces (avoids z-fighting).
        self._faces.set_gl_state(
            polygon_offset_fill=True, polygon_offset=(1, 1), depth_test=True
        )
|
42
|
+
|
43
|
+
|
44
|
+
class SphereVisual(_SphereVisual):
    """Latitude-tessellated sphere whose resolution grows with its radius.

    Falls back to ``_default_color`` for the faces only when neither
    ``color`` nor ``edge_color`` is given.
    """

    def __init__(self, radius: float, color: Color = None, edge_color: Color = None):
        face_color = _default_color if (color is None and edge_color is None) else color
        r = float(radius)

        # Vertex counts follow half / full circumference, with lower bounds.
        rows = max(int(np.pi * r * _vectices_per_unit_length), 10)
        cols = max(int(2 * np.pi * r * _vectices_per_unit_length), 20)

        super().__init__(
            r,
            color=face_color,
            edge_color=edge_color,
            rows=rows,
            cols=cols,
            method="latitude",
            shading="smooth",
        )


Sphere = create_visual_node(SphereVisual)
|
66
|
+
|
67
|
+
|
68
|
+
def box_mesh(
    dim_x: float, dim_y: float, dim_z: float
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build an axis-aligned box mesh centred at the origin.

    Args:
        dim_x: Full edge length along x.
        dim_y: Full edge length along y.
        dim_z: Full edge length along z.

    Returns:
        ``(verts, edges, faces)``:

        - ``(8, 3)`` float32 corner positions. Bits 0/1/2 of the vertex index
          select the +x/+y/+z sign.
        - ``(12, 2)`` uint32 edge index pairs.
        - ``(12, 3)`` uint32 triangle indices (two per side).
    """
    verts = np.array(
        [
            (-dim_x, -dim_y, -dim_z),
            (dim_x, -dim_y, -dim_z),
            (-dim_x, dim_y, -dim_z),
            (dim_x, dim_y, -dim_z),
            (-dim_x, -dim_y, dim_z),
            (dim_x, -dim_y, dim_z),
            (-dim_x, dim_y, dim_z),
            (dim_x, dim_y, dim_z),
        ],
        dtype=np.float32,
    )
    # dims are full lengths, corners sit at +-dim/2
    verts /= 2

    edges = np.array(
        [
            (0, 1),
            (0, 2),
            (0, 4),
            (1, 3),
            (1, 5),
            (2, 3),
            (2, 6),
            (3, 7),
            (4, 5),
            (4, 6),
            (5, 7),
            (6, 7),
        ],
        dtype=np.uint32,
    )

    # Same index dtype as `edges` (and as the faces built by `capsule_mesh`).
    faces = np.array(
        [
            (0, 1, 2),
            (1, 2, 3),
            (0, 1, 4),
            (1, 4, 5),
            (0, 2, 4),
            (2, 4, 6),
            (1, 3, 5),
            (3, 5, 7),
            (2, 3, 6),
            (3, 6, 7),
            (4, 5, 6),
            (5, 6, 7),
        ],
        dtype=np.uint32,
    )

    return verts, edges, faces
|
123
|
+
|
124
|
+
|
125
|
+
class BoxVisual(DoubleMeshVisual):
    """Filled box with outlined edges, built from ``box_mesh``.

    Unlike the other visuals here, both colors always fall back to defaults
    (``_default_color`` / ``_default_edge_color``) when not given.
    """

    # NOTE: need a custom BoxVisual class, since vispy.scene.visuals.Box does not
    # support shading

    def __init__(
        self,
        dim_x: float,
        dim_y: float,
        dim_z: float,
        *,
        color: Color = None,
        edge_color: Color = None
    ):
        face_color = _default_color if color is None else color
        line_color = _default_edge_color if edge_color is None else edge_color

        self.dim_x, self.dim_y, self.dim_z = float(dim_x), float(dim_y), float(dim_z)

        mesh = box_mesh(self.dim_x, self.dim_y, self.dim_z)
        super().__init__(*mesh, color=face_color, edge_color=line_color)


Box = create_visual_node(BoxVisual)
|
158
|
+
|
159
|
+
|
160
|
+
class CylinderVisual(TubeVisual):
    """Open-ended cylinder along the x-axis, centred at the origin.

    Args:
        radius: Tube radius.
        length: Extent along x (from -length/2 to +length/2).
        color: Face color; defaults to ``_default_color``.
        edge_color: Accepted for interface compatibility with the other
            visuals, but not drawn (``TubeVisual`` renders no edge lines).
    """

    def __init__(
        self,
        radius: float,
        length: float,
        *,
        color: Color = None,
        edge_color: Color = None
    ):
        # TubeVisual cannot draw edges, so never leave the faces without a
        # color (passing only `edge_color` previously produced color=None).
        if color is None:
            color = _default_color

        radius = float(radius)
        length = float(length)

        num_length_points = 10 * max(int(length * _vectices_per_unit_length), 10)
        num_radial_points = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

        # Tube centre line: straight segment on the x-axis.
        points = np.zeros((num_length_points, 3))
        points[:, 0] = np.linspace(-length / 2, length / 2, num_length_points)

        self.radius = radius
        self.length = length

        super().__init__(
            points,
            radius,
            tube_points=num_radial_points,
            closed=True,
            color=color,
            shading="smooth",
        )


Cylinder = create_visual_node(CylinderVisual)
|
195
|
+
|
196
|
+
|
197
|
+
def capsule_mesh(
    radius: float, length: float, offset: bool = True
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Triangulate a capsule (cylinder with hemispherical caps) along the x-axis.

    Args:
        radius: Radius of the cylinder and of both caps.
        length: Total tip-to-tip length along x, caps included.
        offset: If True, rotate each vertex row by half a column relative to
            the previous row.

    Returns:
        ``(vertices, edges, faces)`` as returned by ``MeshData.get_vertices``,
        ``get_edges`` and ``get_faces``.

    Raises:
        ValueError: If ``length < 2 * radius``.
    """
    if length < 2 * radius:
        raise ValueError("length must be at least 2 * radius")

    # number of cap vertices in x direction
    num_sphere_rows = max(int(radius * _vectices_per_unit_length), 10)

    # length without caps
    cyl_length = length - 2 * radius
    # number of cylinder vertices in x direction
    num_cyl_rows = max(int(cyl_length * _vectices_per_unit_length), 10)

    num_total_rows = 2 * num_sphere_rows + num_cyl_rows

    # number of radial vertices
    num_cols = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

    verts = np.empty((num_total_rows, num_cols, 3), dtype=np.float32)

    # polar angle
    theta_top = np.linspace(0.0, np.pi / 2, num_sphere_rows)
    theta_bottom = np.linspace(np.pi / 2, np.pi, num_sphere_rows)

    # fill in x coordinate (rows ordered from +x tip to -x tip)
    verts[:num_sphere_rows, :, 0] = radius * np.cos(theta_top[:, None]) + cyl_length / 2

    verts[num_sphere_rows:-num_sphere_rows, :, 0] = np.linspace(
        -cyl_length / 2, cyl_length / 2, num_cyl_rows
    )[::-1, None]

    verts[-num_sphere_rows:, :, 0] = (
        radius * np.cos(theta_bottom[:, None]) - cyl_length / 2
    )

    # azimuth angle: num_cols evenly spaced angles in [0, 2*pi). The endpoint
    # 2*pi is excluded because the faces below wrap the last column back to
    # column 0; including it would duplicate column 0 (degenerate seam).
    phi = ((2 * np.pi / num_cols) * np.arange(num_cols))[None, :]

    if offset:
        # rotate each row by 1/2 column
        phi = phi + (np.pi / num_cols) * np.arange(num_total_rows)[:, None]

    # y and z coordinates
    verts[..., 1] = radius * np.cos(phi)
    verts[..., 2] = radius * np.sin(phi)

    # for caps: bend inwards to close
    verts[:num_sphere_rows, :, 1:3] *= np.sin(theta_top[:, None, None])
    verts[-num_sphere_rows:, :, 1:3] *= np.sin(theta_bottom[:, None, None])

    verts = verts.reshape(-1, 3)

    # compute faces: two triangles per quad between consecutive rows
    faces = np.empty(((num_total_rows - 1) * num_cols * 2, 3), dtype=np.uint32)

    rowtemplate1 = (
        (np.arange(num_cols).reshape(num_cols, 1) + np.array([[0, 1, 0]])) % num_cols
    ) + np.array([[num_cols, 0, 0]])

    rowtemplate2 = (
        (np.arange(num_cols).reshape(num_cols, 1) + np.array([[1, 1, 0]])) % num_cols
    ) + np.array([[num_cols, 0, num_cols]])

    for row in range(num_total_rows - 1):
        start = row * num_cols * 2

        faces[start : start + num_cols] = rowtemplate1 + row * num_cols
        faces[start + num_cols : start + (num_cols * 2)] = rowtemplate2 + row * num_cols

    mesh = MeshData(vertices=verts, faces=faces)

    return mesh.get_vertices(), mesh.get_edges(), mesh.get_faces()
|
268
|
+
|
269
|
+
|
270
|
+
class CapsuleVisual(DoubleMeshVisual):
    """Capsule (cylinder with hemispherical caps) built from ``capsule_mesh``.

    Color handling is inherited from ``DoubleMeshVisual``.
    """

    def __init__(
        self,
        radius: float,
        length: float,
        *,
        color: Color = None,
        edge_color: Color = None
    ):
        self.radius, self.length = float(radius), float(length)

        mesh = capsule_mesh(self.radius, self.length)
        super().__init__(*mesh, color=color, edge_color=edge_color)


Capsule = create_visual_node(CapsuleVisual)
|