imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,340 @@
1
+ from abc import ABC
2
+ from abc import abstractmethod
3
+ from abc import abstractstaticmethod
4
+ from functools import partial
5
+ from typing import Optional, TypeVar
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from ring import algebra
11
+ from ring import base
12
+ from ring import maths
13
+ from tree_utils import PyTree
14
+ from tree_utils import tree_batch
15
+ from vispy import scene
16
+ from vispy.scene import MatrixTransform
17
+
18
+ from . import vispy_visuals
19
+
20
# Placeholders for backend-specific object types: each rendering backend
# decides what a "camera" and a "visual" concretely are.
Camera = TypeVar("Camera")
Visual = TypeVar("Visual")

# Pose representation of a visual as produced by
# `_compute_transform_per_visual` (1) and after `_postprocess_transforms` (2).
VisualPosOri1 = PyTree
VisualPosOri2 = PyTree
24
+
25
+
26
class Scene(ABC):
    """Backend-agnostic base class for rendering a list of geometries.

    Subclasses provide camera handling (`_get_camera`, `_set_camera`),
    rendering (`_render`), creation/removal of backend visuals (`_add_*`,
    `_remove_visual`) and the conversion of link transforms into the backend's
    pose representation (`_compute_transform_per_visual`,
    `_postprocess_transforms`, `_init_visual`).

    Example:
        >> renderer = Renderer()
        >> renderer.init(sys.geoms)
        >> for t in range(xs.shape()):
        >>     renderer.update(xs[t])
        >>     image = renderer.render()
    """

    # Add one coordinate-axes visual per link that carries at least one geom.
    _xyz: bool = True
    # Add one coordinate-axes visual for the root frame.
    _xyz_root: bool = True
    # If False, `base.XYZ` geoms use `base.Transform.zero()` instead of their
    # own geom transform.
    _xyz_transform1: bool = True
    # Rebound to a fresh list in `init`; this class-level list is never mutated.
    visuals: list[Visual] = []
    # Root-frame axes visual created by `init`; tracked so a later `init` can
    # remove it instead of leaking it.
    _root_xyz_visual: Optional[Visual] = None

    @abstractmethod
    def _get_camera(self) -> Camera:
        pass

    @abstractmethod
    def _set_camera(self, camera: Camera) -> None:
        pass

    @abstractmethod
    def _render(self) -> jax.Array:
        pass

    def enable_xyz(self, enable_root: bool = True) -> None:
        """Enable per-link axes (and the root axes if `enable_root`).

        Takes effect on the next `init` call.
        """
        self._xyz = True
        if enable_root:
            self._xyz_root = True

    def disable_xyz(self, disable_root: bool = True) -> None:
        """Disable per-link axes (and the root axes if `disable_root`).

        Takes effect on the next `init` call.
        """
        self._xyz = False
        if disable_root:
            self._xyz_root = False

    def enable_xyz_transform1(self):
        """Let `base.XYZ` geoms use their own geom transform."""
        self._xyz_transform1 = True

    def disable_xyz_transform1(self):
        """Place `base.XYZ` geoms with `base.Transform.zero()` instead."""
        self._xyz_transform1 = False

    # Backwards-compatible alias for the original (misspelled) method name.
    disable_xyz_tranform1 = disable_xyz_transform1

    def render(
        self, camera: Optional[Camera | list[Camera]] = None
    ) -> jax.Array | list[jax.Array]:
        """Render the scene.

        Args:
            camera: A camera, a list of cameras, or None to use the current one.

        Returns:
            RGBA Array of Shape = (M, N, 4), or a list of such arrays (one per
            camera) when `camera` is a list.
        """
        if camera is None:
            camera = self._get_camera()

        if not isinstance(camera, list):
            self._set_camera(camera)
            return self._render()

        images = []
        for cam in camera:
            self._set_camera(cam)
            images.append(self._render())
        return images

    def _add_box(self, box: base.Box) -> Visual:
        raise NotImplementedError

    def _add_sphere(self, sphere: base.Sphere) -> Visual:
        raise NotImplementedError

    def _add_cylinder(self, cyl: base.Cylinder) -> Visual:
        raise NotImplementedError

    def _add_capsule(self, cap: base.Capsule) -> Visual:
        raise NotImplementedError

    def _add_xyz(self) -> Visual:
        raise NotImplementedError

    @abstractmethod
    def _remove_visual(self, visual: Visual) -> None:
        pass

    def _remove_all_visuals(self):
        """Remove every visual created by a previous `init` call."""
        for visual in self.visuals:
            self._remove_visual(visual)
        if self._root_xyz_visual is not None:
            self._remove_visual(self._root_xyz_visual)
            self._root_xyz_visual = None

    def init(self, geoms: list[base.Geometry]):
        """Create one visual per geometry, plus optional axes visuals.

        Visuals from a previous call are removed first, so `init` may be called
        repeatedly on the same scene. The next `update` call initializes the
        visuals' transforms (`_init_visual`), later calls update them.

        Raises:
            TypeError: if a geometry's type is not supported.
        """
        self._remove_all_visuals()

        self.geoms = list(geoms)
        self._fresh_init = True

        geom_link_idx = []
        geom_transform = []
        self.visuals = []
        for geom in geoms:
            geom_link_idx.append(geom.link_idx)
            geom_transform.append(geom.transform)
            if isinstance(geom, base.Box):
                visual = self._add_box(geom)
            elif isinstance(geom, base.Sphere):
                visual = self._add_sphere(geom)
            elif isinstance(geom, base.Cylinder):
                visual = self._add_cylinder(geom)
            elif isinstance(geom, base.Capsule):
                visual = self._add_capsule(geom)
            elif isinstance(geom, base.XYZ):
                visual = self._add_xyz()
                if not self._xyz_transform1:
                    geom_transform[-1] = base.Transform.zero()
            else:
                raise TypeError(f"Unknown geom type: {type(geom)}")
            self.visuals.append(visual)

        if self._xyz:
            unique_link_indices = np.unique(np.array(geom_link_idx))
            for unique_link_idx in unique_link_indices:
                geom_link_idx.append(unique_link_idx)
                geom_transform.append(base.Transform.zero())
                self.visuals.append(self._add_xyz())
                # otherwise the .update function won't iterate
                # over all visuals since it uses a zip(...)
                self.geoms.append(None)

        if self._xyz_root:
            # One final axes visual for the root frame. It gets no per-frame
            # transform update, but is tracked so re-`init` can remove it.
            self._root_xyz_visual = self._add_xyz()

        self.geom_link_idx = tree_batch(geom_link_idx, backend="jax")
        self.geom_transform = tree_batch(geom_transform, backend="jax")

    @staticmethod
    @abstractmethod
    def _compute_transform_per_visual(
        x_links: base.Transform,
        x_link_to_geom: base.Transform,
        geom_link_idx: int,
    ) -> VisualPosOri1:
        "This can easily account for possible convention differences"
        pass

    @staticmethod
    @abstractmethod
    def _postprocess_transforms(transform: VisualPosOri1) -> VisualPosOri2:
        pass

    @abstractmethod
    def _init_visual(
        self, visual: Visual, transform: VisualPosOri2, geom: None | base.Geometry
    ):
        pass

    def _update_visual(
        self, visual: Visual, transform: VisualPosOri2, geom: None | base.Geometry
    ):
        """Default: re-run `_init_visual`; backends may override for speed."""
        self._init_visual(visual, transform, geom)

    def update(self, x: base.Transform):
        """Move all visuals to the poses implied by the link transforms.

        Args:
            x: `(n_links,)` Transforms of the links.
        """
        # step 1: pre-compute all required transforms (jit + vmap per visual)
        transform_per_visual = _compile_staticmethod(
            self._compute_transform_per_visual,
            x,
            self.geom_transform,
            self.geom_link_idx,
        )

        # step 2: postprocess all transforms once
        transform_per_visual = self._postprocess_transforms(transform_per_visual)

        # step 3: update visuals
        for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
            t = jax.tree_util.tree_map(lambda arr: arr[i], transform_per_visual)
            if self._fresh_init:
                self._init_visual(visual, t, geom)
            else:
                self._update_visual(visual, t, geom)

        # step 4: unset flag
        self._fresh_init = False
203
+
204
+
205
@partial(jax.jit, static_argnums=0)
def _compile_staticmethod(static_method, x, geom_transform, geom_link_idx):
    """JIT-compile `static_method` vectorized over visuals.

    `x` is shared by every call (broadcast), while `geom_transform` and
    `geom_link_idx` are mapped over their leading axis. `static_method` is a
    static jit argument, so it must be hashable; each distinct function
    triggers its own compilation.
    """
    per_visual = jax.vmap(static_method, in_axes=(None, 0, 0))
    return per_visual(x, geom_transform, geom_link_idx)
210
+
211
+
212
class VispyScene(Scene):
    """`Scene` implementation rendering through a vispy `SceneCanvas`."""

    def __init__(
        self,
        show_cs=False,
        show_cs_root=True,
        width: int = 320,
        height: int = 240,
        camera: Optional[scene.cameras.BaseCamera] = None,
        **kwargs,
    ):
        """Scene which can be rendered.

        Args:
            show_cs (bool, optional): Show coordinate system of every link.
                Defaults to False.
            show_cs_root (bool, optional): Show coordinate system of the root
                (earth) frame. Defaults to True.
            width (int, optional): Canvas width in pixels. Defaults to 320.
            height (int, optional): Canvas height in pixels. Defaults to 240.
            camera (scene.cameras.BaseCamera, optional): The camera. Defaults
                to a new `scene.TurntableCamera(elevation=25, distance=4.0,
                azimuth=25)` for each scene.
            **kwargs: Forwarded to `scene.SceneCanvas`.

        Example:
            >> scene = VispyScene()
            >> scene.init(sys.geoms)
            >> scene.update(state.x)
            >> image = scene.render()
        """
        if camera is None:
            # Created per instance: a default evaluated in the signature would
            # be one camera object shared by every VispyScene.
            camera = scene.TurntableCamera(elevation=25, distance=4.0, azimuth=25)

        self.canvas = scene.SceneCanvas(
            keys="interactive", size=(width, height), show=True, **kwargs
        )
        self.view = self.canvas.central_widget.add_view()
        self._set_camera(camera)
        # Set both flags explicitly so `show_cs_root` is honoured even when
        # `show_cs` is True.
        self._xyz = show_cs
        self._xyz_root = show_cs_root

    def _set_camera(self, camera: scene.cameras.BaseCamera) -> None:
        self.view.camera = camera

    def _get_camera(self) -> scene.cameras.BaseCamera:
        return self.view.camera

    def _render(self) -> jax.Array:
        return self.canvas.render(alpha=True)

    def _add_box(self, box: base.Box) -> Visual:
        # NOTE(review): y and z dimensions are passed swapped; presumably this
        # compensates for an axis-convention difference — confirm.
        return vispy_visuals.Box(
            box.dim_x,
            box.dim_z,
            box.dim_y,
            color=box.color,
            edge_color=box.edge_color,
            parent=self.view.scene,
        )

    def _add_sphere(self, sphere: base.Sphere) -> Visual:
        return vispy_visuals.Sphere(
            sphere.radius,
            color=sphere.color,
            edge_color=sphere.edge_color,
            parent=self.view.scene,
        )

    def _add_cylinder(self, cyl: base.Cylinder) -> Visual:
        return vispy_visuals.Cylinder(
            cyl.radius,
            cyl.length,
            color=cyl.color,
            edge_color=cyl.edge_color,
            parent=self.view.scene,
        )

    def _add_capsule(self, cap: base.Capsule) -> Visual:
        return vispy_visuals.Capsule(
            cap.radius,
            cap.length,
            color=cap.color,
            edge_color=cap.edge_color,
            parent=self.view.scene,
        )

    def _add_xyz(self) -> Visual:
        return scene.visuals.XYZAxis(parent=self.view.scene)

    def _remove_visual(self, visual: scene.visuals.VisualNode) -> None:
        # Detaching from the scene graph removes the visual from the canvas.
        visual.parent = None

    @staticmethod
    def _compute_transform_per_visual(
        x_links: base.Transform,
        x_link_to_geom: base.Transform,
        geom_link_idx: int,
    ) -> jax.Array:
        """Return the 4x4 pose matrix of one visual.

        The link transform (or `base.Transform.zero()` when `geom_link_idx` is
        -1) is composed with the geom's own transform. The rotation fills the
        upper-left 3x3 block and the translation the last row (row-vector
        convention used by vispy's `MatrixTransform`).
        """
        x = jax.lax.cond(
            geom_link_idx == -1,
            lambda: base.Transform.zero(),
            lambda: x_links[geom_link_idx],
        )
        x = algebra.transform_mul(x_link_to_geom, x)
        E = maths.quat_to_3x3(x.rot)
        M = jnp.eye(4)
        M = M.at[:3, :3].set(E)
        T = jnp.eye(4)
        T = T.at[3, :3].set(x.pos)
        return M @ T

    @staticmethod
    def _postprocess_transforms(transform: jax.Array) -> np.ndarray:
        # vispy expects host (numpy) arrays.
        return np.asarray(transform)

    def _init_visual(
        self,
        visual: scene.visuals.VisualNode,
        transform: np.ndarray,
        geom: base.Geometry,
    ):
        visual.transform = MatrixTransform(transform)

    def _update_visual(
        self,
        visual: scene.visuals.VisualNode,
        transform: np.ndarray,
        geom: base.Geometry,
    ):
        # Reuse the MatrixTransform created in `_init_visual`.
        visual.transform.matrix = transform
@@ -0,0 +1,290 @@
1
+ import numpy as np
2
+ from ring.base import Color
3
+ from vispy.geometry.meshdata import MeshData
4
+ from vispy.scene.visuals import create_visual_node
5
+ from vispy.scene.visuals import Mesh
6
+ from vispy.visuals import CompoundVisual
7
+ from vispy.visuals import SphereVisual as _SphereVisual
8
+ from vispy.visuals import TubeVisual
9
+
10
# Mesh vertex density per unit length; tessellation of spheres, cylinders and
# capsules scales with their size through this factor. (Name kept as-is for
# compatibility.)
_vectices_per_unit_length = 10

# Fallbacks used when a visual is created without an explicit face / edge color.
_default_color = (1, 0.8, 0.7, 1)
_default_edge_color = "black"
15
+
16
+
17
class DoubleMeshVisual(CompoundVisual):
    """Compound visual made of a filled face mesh and a line mesh of its edges.

    Pass `color=None` to omit the faces or `edge_color=None` to omit the edges
    (an empty `Mesh()` takes the omitted part's place). If both are None the
    faces are drawn with `_default_color` and no edges are drawn.
    """

    _edges: Mesh
    _faces: Mesh

    def __init__(
        self, verts, edges, faces, *, color: Color = None, edge_color: Color = None
    ):
        if color is None and edge_color is None:
            color = _default_color

        if color is not None:
            self._faces = Mesh(verts, faces, color=color, shading=None)
            # NOTE(review): faces use `shading=None`; confirm whether
            # `light_dir` has any effect here.
            self.light_dir = np.array([0, -1, 0])
        else:
            self._faces = Mesh()

        if edge_color is not None:
            self._edges = Mesh(verts, edges, color=edge_color, mode="lines")
        else:
            self._edges = Mesh()

        super().__init__([self._faces, self._edges])
        # Offset the filled faces in depth, presumably so the edge lines drawn
        # on the same geometry are not hidden by z-fighting.
        self._faces.set_gl_state(
            polygon_offset_fill=True, polygon_offset=(1, 1), depth_test=True
        )
42
+
43
+
44
class SphereVisual(_SphereVisual):
    """Smooth-shaded latitude sphere whose tessellation grows with its radius."""

    def __init__(self, radius: float, color: Color = None, edge_color: Color = None):
        no_colors_given = color is None and edge_color is None
        color = _default_color if no_colors_given else color

        radius = float(radius)

        # At least 10 rings and 20 segments, more for larger spheres.
        rows = max(int(np.pi * radius * _vectices_per_unit_length), 10)
        cols = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

        super().__init__(
            radius,
            color=color,
            edge_color=edge_color,
            rows=rows,
            cols=cols,
            method="latitude",
            shading="smooth",
        )


Sphere = create_visual_node(SphereVisual)
66
+
67
+
68
def box_mesh(
    dim_x: float, dim_y: float, dim_z: float
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build an axis-aligned box mesh centered at the origin.

    Args:
        dim_x, dim_y, dim_z: Full edge lengths of the box along each axis.

    Returns:
        `(verts, edges, faces)`: `(8, 3)` float32 corner positions, `(12, 2)`
        uint32 vertex-index pairs for the edges, and `(12, 3)` uint32
        vertex-index triples (two triangles per side).
    """
    verts = np.array(
        [
            (-dim_x, -dim_y, -dim_z),
            (dim_x, -dim_y, -dim_z),
            (-dim_x, dim_y, -dim_z),
            (dim_x, dim_y, -dim_z),
            (-dim_x, -dim_y, dim_z),
            (dim_x, -dim_y, dim_z),
            (-dim_x, dim_y, dim_z),
            (dim_x, dim_y, dim_z),
        ],
        dtype=np.float32,
    )
    # Corners at +-dim/2 so that the full edge length equals `dim`.
    verts /= 2

    edges = np.array(
        [
            (0, 1),
            (0, 2),
            (0, 4),
            (1, 3),
            (1, 5),
            (2, 3),
            (2, 6),
            (3, 7),
            (4, 5),
            (4, 6),
            (5, 7),
            (6, 7),
        ],
        dtype=np.uint32,
    )

    # Same index dtype as `edges` (previously defaulted to int64).
    faces = np.array(
        [
            (0, 1, 2),
            (1, 2, 3),
            (0, 1, 4),
            (1, 4, 5),
            (0, 2, 4),
            (2, 4, 6),
            (1, 3, 5),
            (3, 5, 7),
            (2, 3, 6),
            (3, 6, 7),
            (4, 5, 6),
            (5, 6, 7),
        ],
        dtype=np.uint32,
    )

    return verts, edges, faces
123
+
124
+
125
class BoxVisual(DoubleMeshVisual):
    """Box drawn as filled faces plus outline edges.

    A custom class is used because vispy.scene.visuals.Box does not support
    shading. Missing colors fall back to `_default_color` for the faces and
    `_default_edge_color` for the edges.
    """

    def __init__(
        self,
        dim_x: float,
        dim_y: float,
        dim_z: float,
        *,
        color: Color = None,
        edge_color: Color = None
    ):
        color = _default_color if color is None else color
        edge_color = _default_edge_color if edge_color is None else edge_color

        self.dim_x, self.dim_y, self.dim_z = float(dim_x), float(dim_y), float(dim_z)

        mesh_verts, mesh_edges, mesh_faces = box_mesh(
            self.dim_x, self.dim_y, self.dim_z
        )
        super().__init__(
            mesh_verts, mesh_edges, mesh_faces, color=color, edge_color=edge_color
        )


Box = create_visual_node(BoxVisual)
158
+
159
+
160
class CylinderVisual(TubeVisual):
    """Cylinder along the x-axis, centered at the origin, built as a tube.

    NOTE(review): `edge_color` is accepted for API symmetry with the other
    visuals but is not forwarded to `TubeVisual`. Also, vispy documents
    `closed=True` as joining the last path point to the first (a loop), not as
    capping the ends — confirm this is intended.
    """

    def __init__(
        self,
        radius: float,
        length: float,
        *,
        color: Color = None,
        edge_color: Color = None
    ):
        color = _default_color if (color is None and edge_color is None) else color

        radius, length = float(radius), float(length)

        n_along = 10 * max(int(length * _vectices_per_unit_length), 10)
        n_around = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

        # Tube centre-line: evenly spaced points on the x-axis.
        centerline = np.zeros((n_along, 3))
        centerline[:, 0] = np.linspace(-length / 2, length / 2, n_along)

        self.radius, self.length = radius, length

        super().__init__(
            centerline,
            radius,
            tube_points=n_around,
            closed=True,
            color=color,
            shading="smooth",
        )


Cylinder = create_visual_node(CylinderVisual)
195
+
196
+
197
def capsule_mesh(
    radius: float, length: float, offset: bool = True
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a capsule mesh along the x-axis, centered at the origin.

    The capsule is a cylinder of length `length - 2 * radius` closed by two
    hemispherical caps, so its total extent along x is `length`.

    Args:
        radius: Radius of the cylinder and of the caps.
        length: Total length including both caps.
        offset: If True, rotate each successive ring of vertices by half a
            column to stagger the triangulation.

    Returns:
        `(vertices, edges, faces)` as produced by vispy's `MeshData`.

    Raises:
        ValueError: if `length < 2 * radius`.
    """
    if length < 2 * radius:
        raise ValueError("length must be at least 2 * radius")

    # number of cap vertices in x direction
    num_sphere_rows = max(int(radius * _vectices_per_unit_length), 10)

    # length without caps
    cyl_length = length - 2 * radius
    # number of cylinder vertices in x direction
    num_cyl_rows = max(int(cyl_length * _vectices_per_unit_length), 10)

    num_total_rows = 2 * num_sphere_rows + num_cyl_rows

    # number of radial vertices
    num_cols = max(int(2 * np.pi * radius * _vectices_per_unit_length), 20)

    verts = np.empty((num_total_rows, num_cols, 3), dtype=np.float32)

    # polar angle
    theta_top = np.linspace(0.0, np.pi / 2, num_sphere_rows)
    theta_bottom = np.linspace(np.pi / 2, np.pi, num_sphere_rows)

    # fill in x coordinate: top cap, cylinder body (descending), bottom cap
    verts[:num_sphere_rows, :, 0] = radius * np.cos(theta_top[:, None]) + cyl_length / 2

    verts[num_sphere_rows:-num_sphere_rows, :, 0] = np.linspace(
        -cyl_length / 2, cyl_length / 2, num_cyl_rows
    )[::-1, None]

    verts[-num_sphere_rows:, :, 0] = (
        radius * np.cos(theta_bottom[:, None]) - cyl_length / 2
    )

    # azimuth angle; `endpoint=False` because the faces below wrap around with
    # `% num_cols` -- including 2*pi would duplicate the seam column.
    phi = np.linspace(0, 2 * np.pi, num_cols, endpoint=False)[None, :]

    if offset:
        # rotate each row by 1/2 column (column spacing is 2*pi / num_cols)
        phi = phi + (np.pi / num_cols) * np.arange(num_total_rows)[:, None]

    # y and z coordinates
    verts[..., 1] = radius * np.cos(phi)
    verts[..., 2] = radius * np.sin(phi)

    # for caps: bend inwards to close
    verts[:num_sphere_rows, :, 1:3] *= np.sin(theta_top[:, None, None])
    verts[-num_sphere_rows:, :, 1:3] *= np.sin(theta_bottom[:, None, None])

    verts = verts.reshape(-1, 3)

    # compute faces: two triangles per quad between consecutive rings
    faces = np.empty(((num_total_rows - 1) * num_cols * 2, 3), dtype=np.uint32)

    rowtemplate1 = (
        (np.arange(num_cols).reshape(num_cols, 1) + np.array([[0, 1, 0]])) % num_cols
    ) + np.array([[num_cols, 0, 0]])

    rowtemplate2 = (
        (np.arange(num_cols).reshape(num_cols, 1) + np.array([[1, 1, 0]])) % num_cols
    ) + np.array([[num_cols, 0, num_cols]])

    for row in range(num_total_rows - 1):
        start = row * num_cols * 2

        faces[start : start + num_cols] = rowtemplate1 + row * num_cols
        faces[start + num_cols : start + (num_cols * 2)] = rowtemplate2 + row * num_cols

    mesh = MeshData(vertices=verts, faces=faces)

    return mesh.get_vertices(), mesh.get_edges(), mesh.get_faces()
268
+
269
+
270
class CapsuleVisual(DoubleMeshVisual):
    """Capsule along the x-axis drawn as filled faces plus mesh edges.

    Color defaults are resolved by `DoubleMeshVisual`.
    """

    def __init__(
        self,
        radius: float,
        length: float,
        *,
        color: Color = None,
        edge_color: Color = None
    ):
        radius, length = float(radius), float(length)
        self.radius, self.length = radius, length

        mesh_verts, mesh_edges, mesh_faces = capsule_mesh(radius, length)
        super().__init__(
            mesh_verts, mesh_edges, mesh_faces, color=color, edge_color=edge_color
        )


Capsule = create_visual_node(CapsuleVisual)
@@ -0,0 +1,7 @@
1
+ from .sim2real import delete_to_world_pos_rot
2
+ from .sim2real import match_xs
3
+ from .sim2real import randomize_to_world_pos_rot
4
+ from .sim2real import scale_xs
5
+ from .sim2real import unzip_xs
6
+ from .sim2real import xs_from_raw
7
+ from .sim2real import zip_xs