imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,288 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import jax
4
+ from ring import algebra
5
+ from ring import base
6
+ from ring import io
7
+ from ring import maths
8
+ from ring.algorithms import generator
9
+ from ring.algorithms import jcalc
10
+ import tree_utils
11
+
12
+
13
def xs_from_raw(
    sys: base.System,
    link_name_pos_rot: dict,
    eps_frame: Optional[str] = None,
    qinv: bool = False,
) -> base.Transform:
    """Build time-series of maximal coordinates `xs` from raw position and
    quaternion trajectory data.

    This function scans through each link (as defined by `sys`) and looks up
    the raw data in `link_name_pos_rot` using the `link_name` as identifier.
    It inverts the quaternions if `qinv`. Then, it creates a `Transform` that
    transforms from epsilon (as defined by `eps_frame`) to the link for each
    timestep. Finally, it stacks all transforms in the order defined by `sys`
    along the 1-th axis. The 0-th axis is the time axis.

    Args:
        sys (ring.base.System): System which defines ordering of returned `xs`.
        link_name_pos_rot (dict): Dictionary of `link_name` ->
            {'pos': ..., 'quat': ...}. Obtained, e.g., using `process_omc`.
            The 'pos' entry is optional for every link (including the
            `eps_frame` link); a missing 'pos' means no translation.
        eps_frame (str, optional): Move into this segment's frame at time zero
            as eps frame. Defaults to `None`.
            If `None`: Don't move into a specific eps-frame.
        qinv (bool, optional): If `True`, every quaternion (including the one
            of `eps_frame`) is inverted with `maths.quat_inv` before use.
            Defaults to `False`.

    Returns:
        ring.base.Transform: Time-series of eps-to-link transformations.
    """
    # determine `eps_frame` transform (taken at time zero)
    if eps_frame is not None:
        eps = link_name_pos_rot[eps_frame]
        q_eps = eps["quat"][0]
        if qinv:
            q_eps = maths.quat_inv(q_eps)
        # 'pos' is optional for the links handled in `f` below, so treat it as
        # optional for the eps-frame too instead of failing with a KeyError
        pos_eps = eps.get("pos", None)
        if pos_eps is None:
            t_eps = base.Transform.create(rot=q_eps)
        else:
            t_eps = base.Transform(pos_eps[0], q_eps)
    else:
        t_eps = base.Transform.zero()

    # build `xs` from optical motion capture data
    xs = []

    def f(_, __, link_name: str):
        q = link_name_pos_rot[link_name]["quat"]
        pos = link_name_pos_rot[link_name].get("pos", None)
        if qinv:
            q = maths.quat_inv(q)
        t = base.Transform.create(pos, q)
        # re-express relative to the eps-frame transform computed above
        t = algebra.transform_mul(t, algebra.transform_inv(t_eps))
        xs.append(t)

    sys.scan(f, "l", sys.link_names)

    # stack and permute such that time-axis is 0-th axis
    xs = xs[0].batch(*xs[1:])
    xs = xs.transpose((1, 0, 2))
    return xs
67
+
68
+
69
def match_xs(
    sys: base.System, xs: base.Transform, sys_xs: base.System
) -> base.Transform:
    """Match transforms `xs` of the larger system `sys_xs` to subsystem `sys`.

    The per-link `pos`/`rot` time-series of `xs` are looked up by link name in
    `sys_xs` and re-assembled in the link order of `sys` using `xs_from_raw`
    (without eps-frame change and without quaternion inversion).

    Args:
        sys (System): Smaller system. Every link in `sys` must be in `sys_xs`.
        xs (Transform): Time-series of transforms of the larger system
            `sys_xs`.
        sys_xs (System): Larger system.

    Returns:
        Transform: Transforms of smaller system.
    """
    _checks_time_series_of_xs(sys_xs, xs)

    link_name_pos_rot = {}
    for name in sys_xs.link_names:
        # look up the link index once and reuse it for `pos` and `rot`
        idx = sys_xs.name_to_idx(name)
        link_name_pos_rot[name] = {
            "pos": xs.pos[:, idx],
            "quat": xs.rot[:, idx],
        }

    xs_small = xs_from_raw(
        sys,
        link_name_pos_rot,
        eps_frame=None,
        qinv=False,
    )
    return xs_small
97
+
98
+
99
def unzip_xs(
    sys: base.System, xs: base.Transform
) -> Tuple[base.Transform, base.Transform]:
    """Split eps-to-link transforms into parent-to-child pure
    translational `transform1` and pure rotational `transform2`.

    For every timestep (mapped with `jax.vmap`) and every link, the
    parent-to-link transform is
    `algebra.transform_mul(xs[link], algebra.transform_inv(xs[parent]))`,
    or simply `xs[link]` when the parent is the worldbody (`-1`). Its `pos`
    becomes `transform1` and its `rot` becomes `transform2`.

    Args:
        sys (System): Defines the tree traversed by `sys.scan`.
        xs (Transform): Time-series of eps-to-link transforms.

    Returns:
        Tuple[Transform, Transform]: transform1, transform2
    """
    _checks_time_series_of_xs(sys, xs)

    def _split_single_timestep(xs_t):
        def _split_link(_, __, link_idx: int, parent_idx: int):
            eps_to_link = xs_t[link_idx]
            if parent_idx != -1:
                parent_to_link = algebra.transform_mul(
                    eps_to_link, algebra.transform_inv(xs_t[parent_idx])
                )
            else:
                parent_to_link = eps_to_link

            translation_only = base.Transform.create(pos=parent_to_link.pos)
            rotation_only = base.Transform.create(rot=parent_to_link.rot)
            return translation_only, rotation_only

        link_indices = list(range(sys.num_links()))
        return sys.scan(_split_link, "ll", link_indices, sys.link_parents)

    return jax.vmap(_split_single_timestep)(xs)
131
+
132
+
133
def zip_xs(
    sys: base.System,
    xs_transform1: base.Transform,
    xs_transform2: base.Transform,
) -> base.Transform:
    """Performs forward kinematics using `transform1` and `transform2`.

    Counterpart of `unzip_xs`. For every timestep (mapped with `jax.vmap`) and
    every link, the parent-to-link transform is rebuilt as
    `algebra.transform_mul(transform2[link], transform1[link])`. It is then
    chained onto the eps-to-parent transform, i.e.
    `algebra.transform_mul(parent_to_link, eps_to_parent)`, walking the tree
    in `sys.scan` order. The worldbody (`-1`) starts from
    `base.Transform.zero()`.

    Args:
        sys (ring.base.System): Defines the tree traversed by `sys.scan`.
        xs_transform1 (ring.base.Transform): Time-series of per-link
            `transform1` (e.g. the pure translational output of `unzip_xs`).
            It is the right-hand operand in the composition above.
        xs_transform2 (ring.base.Transform): Time-series of per-link
            `transform2` (e.g. the pure rotational output of `unzip_xs`).
            It is the left-hand operand in the composition above.

    Returns:
        ring.base.Transform: Time-series of eps-to-link transformations
    """
    _checks_time_series_of_xs(sys, xs_transform1)
    _checks_time_series_of_xs(sys, xs_transform2)

    @jax.vmap
    def _zip_xs(xs_transform1, xs_transform2):
        # eps-to-link transforms computed so far; parents are visited first
        eps_to_l = {-1: base.Transform.zero()}

        def f(_, __, i: int, p: int):
            transform = algebra.transform_mul(xs_transform2[i], xs_transform1[i])
            eps_to_l[i] = algebra.transform_mul(transform, eps_to_l[p])
            return eps_to_l[i]

        return sys.scan(f, "ll", list(range(sys.num_links())), sys.link_parents)

    return _zip_xs(xs_transform1, xs_transform2)
163
+
164
+
165
def _checks_time_series_of_xs(sys, xs):
    """Assert that `xs` is a 3-dimensional time-series whose axis 1 has one
    entry per link of `sys`.

    Raises:
        AssertionError: If `tree_utils.tree_ndim(xs) != 3` or if the size of
            axis 1 differs from `sys.num_links()`.
    """
    assert tree_utils.tree_ndim(xs) == 3, f"pos.shape={xs.pos.shape}"

    links_in_xs = tree_utils.tree_shape(xs, axis=1)
    links_in_sys = sys.num_links()
    assert links_in_xs == links_in_sys, f"{links_in_xs} != {links_in_sys}"
169
+
170
+
171
def delete_to_world_pos_rot(sys: base.System, xs: base.Transform) -> base.Transform:
    """Replace the transforms of all links that connect to the worldbody
    by `base.Transform.zero` (unity) transforms.

    Each such root link gets its parent-to-link transform overwritten, and the
    transforms of its descendants are then recomputed via forward kinematics
    (see `_overwrite_transform_of_link_then_update`).

    Args:
        sys (System): System only used for structure (in `sys.scan`).
        xs (Transform): Time-series of transforms to be modified.

    Returns:
        Transform: Time-series of modified transforms.
    """
    _checks_time_series_of_xs(sys, xs)

    unity = base.Transform.zero((xs.shape(),))
    root_links = [
        link_idx for link_idx, parent in enumerate(sys.link_parents) if parent == -1
    ]
    for root in root_links:
        xs = _overwrite_transform_of_link_then_update(sys, xs, unity, root)
    return xs
189
+
190
+
191
def randomize_to_world_pos_rot(
    key: jax.Array, sys: base.System, xs: base.Transform, config: jcalc.MotionConfig
) -> base.Transform:
    """Replace the transform of the link that connects to the worldbody
    by a randomized transform.

    A one-body system with a single `free` joint is built from an XML string.
    Its motion is generated with `generator.RCMG(free_sys, config, ...)`, and
    the first batch element is used. That series replaces the parent-to-link
    transform of the root link of `sys`, and the descendants are recomputed
    via forward kinematics.

    The free system uses `sys.dt` as its timestep, so the generated motion is
    sampled at the same rate as `sys` rather than at a fixed 0.01s.
    NOTE(review): the number of generated timesteps is still determined by
    `config` and must match the time axis of `xs`.

    Args:
        key (jax.Array): PRNG Key.
        sys (System): System only used for structure (in `sys.scan`).
        xs (Transform): Time-series of transforms to be modified.
        config (MotionConfig): Defines the randomization.

    Returns:
        Transform: Time-series of modified transforms.

    Raises:
        AssertionError: If `xs` does not match `sys`, or if `sys` does not
            have exactly one link connected to the worldbody.
    """
    _checks_time_series_of_xs(sys, xs)
    assert sys.link_parents.count(-1) == 1, "Found multiple connections to world"

    # use the timestep of `sys` so the generated series has the same rate
    free_sys_str = f"""
<x_xy>
    <options dt="{float(sys.dt)}"/>
    <worldbody>
        <body name="free" joint="free"/>
    </worldbody>
</x_xy>
"""

    free_sys = io.load_sys_from_str(free_sys_str)
    _, xs_free = generator.RCMG(
        free_sys, config, finalize_fn=lambda key, q, x, sys: (q, x)
    ).to_lazy_gen()(key)
    # first (and only) batch element, then the `free` link's transforms
    xs_free = xs_free.take(0, axis=0)
    xs_free = xs_free.take(free_sys.name_to_idx("free"), axis=1)
    link_idx_to_world = sys.link_parents.index(-1)
    return _overwrite_transform_of_link_then_update(sys, xs, xs_free, link_idx_to_world)
226
+
227
+
228
def _overwrite_transform_of_link_then_update(
    sys: base.System, xs: base.Transform, xs_new_link: base.Transform, new_link_idx: int
):
    """Replace the parent-to-link transform of one link, then redo forward
    kinematics.

    The link's translational part (`transform1`) is set to `xs_new_link`, and
    its rotational part (`transform2`) is set to `base.Transform.zero`.
    The eps-to-link series is rebuilt with `zip_xs`.
    """
    assert xs_new_link.ndim() == (xs.ndim() - 1) == 2
    translations, rotations = unzip_xs(sys, xs)

    translations = _replace_transform_of_link(translations, xs_new_link, new_link_idx)
    unity = base.Transform.zero((xs_new_link.shape(),))
    rotations = _replace_transform_of_link(rotations, unity, new_link_idx)

    return zip_xs(sys, translations, rotations)
238
+
239
+
240
def _replace_transform_of_link(
    xs: base.Transform, xs_new_link: base.Transform, link_idx
):
    """Return `xs` with the entry at `link_idx` along axis 1 set to
    `xs_new_link` (axes 0 and 1 are swapped temporarily for `index_set`)."""
    link_major = xs.transpose((1, 0, 2))
    link_major = link_major.index_set(link_idx, xs_new_link)
    return link_major.transpose((1, 0, 2))
244
+
245
+
246
def scale_xs(
    sys: base.System,
    xs: base.Transform,
    factor: float,
    exclude: Optional[list[str]] = None,
) -> base.Transform:
    """Increase / decrease transforms by scaling their positional / rotational
    components based on the system's link type, i.e. the `xs` should
    conceptually be `transform2` objects.

    Args:
        sys (System): System defining structure (for `sys.scan`).
        xs (Transform): Time-series of transforms to be modified.
        factor (float): Multiplicative factor.
        exclude (list[str], optional): Skip scaling of transforms if their
            link_type is one of those. Defaults to `None`, which means
            ["px", "py", "pz", "free"].

    Returns:
        Transform: Time-series of scaled transforms.
    """
    _checks_time_series_of_xs(sys, xs)
    # `None` sentinel instead of a list default that is shared between calls
    if exclude is None:
        exclude = ["px", "py", "pz", "free"]

    @jax.vmap
    def _scale_xs(xs):
        # `link_type` (not `type`) to avoid shadowing the builtin
        def f(_, __, i: int, link_type: str):
            x_link = xs[i]
            if link_type not in exclude:
                x_link = _scale_transform_based_on_type(x_link, link_type, factor)
            return x_link

        return sys.scan(f, "ll", list(range(sys.num_links())), sys.link_types)

    return _scale_xs(xs)
279
+
280
+
281
def _scale_transform_based_on_type(x: base.Transform, link_type: str, factor: float):
    """Scale the translational and/or rotational part of `x` by `factor`.

    Translational link types ("px", "py", "pz", "free") get `pos * factor`.
    Rotational link types ("rx", "ry", "rz", "spherical", "free") get their
    rotation angle multiplied by `factor` about the same axis, via
    `maths.quat_to_rot_axis` / `maths.quat_rot_axis`.
    """
    scales_translation = link_type in ["px", "py", "pz", "free"]
    scales_rotation = link_type in ["rx", "ry", "rz", "spherical", "free"]

    new_pos = x.pos * factor if scales_translation else x.pos

    new_rot = x.rot
    if scales_rotation:
        rot_axis, rot_angle = maths.quat_to_rot_axis(new_rot)
        new_rot = maths.quat_rot_axis(rot_axis, rot_angle * factor)

    return base.Transform(new_pos, new_rot)
ring/spatial.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ Implements `Table 1` from
3
+ `A Beginner's Guide to 6-D Vectors (Part 2)`
4
+ by Roy Featherstone.
5
+
6
+ `A small but sufficient set of spatial arithmetic operations.`
7
+ """
8
+
9
+
10
+ import jax.numpy as jnp
11
+
12
+
13
def rx(theta):
    """Coordinate rotation matrix about the x-axis.

    Returns the 3x3 matrix::

        [[1,  0, 0],
         [0,  c, s],
         [0, -s, c]]

    with ``c = cos(theta)`` and ``s = sin(theta)``.
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    first_row = [1, 0, 0]
    second_row = [0, cos_t, sin_t]
    third_row = [0, -sin_t, cos_t]
    return jnp.array([first_row, second_row, third_row])
25
+
26
+
27
def ry(theta):
    """Coordinate rotation matrix about the y-axis.

    Returns the 3x3 matrix::

        [[c, 0, -s],
         [0, 1,  0],
         [s, 0,  c]]

    with ``c = cos(theta)`` and ``s = sin(theta)``.
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    first_row = [cos_t, 0, -sin_t]
    second_row = [0, 1, 0]
    third_row = [sin_t, 0, cos_t]
    return jnp.array([first_row, second_row, third_row])
39
+
40
+
41
def rz(theta):
    """Coordinate rotation matrix about the z-axis.

    Returns the 3x3 matrix::

        [[ c, s, 0],
         [-s, c, 0],
         [ 0, 0, 1]]

    with ``c = cos(theta)`` and ``s = sin(theta)``.
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    first_row = [cos_t, sin_t, 0]
    second_row = [-sin_t, cos_t, 0]
    third_row = [0, 0, 1]
    return jnp.array([first_row, second_row, third_row])
53
+
54
+
55
def cross(r):
    """Skew-symmetric 3x3 matrix of `r`, so that `cross(r) @ b == r x b`."""
    assert r.shape == (3,)
    x, y, z = r[0], r[1], r[2]
    return jnp.array(
        [
            [0, -z, y],
            [z, 0, -x],
            [-y, x, 0],
        ]
    )
58
+
59
+
60
def quadrants(aa=None, ab=None, ba=None, bb=None, default=jnp.zeros):
    """Assemble a 6x6 matrix from four optional 3x3 blocks.

    Starts from `default((6, 6))`. Then writes `aa` (top-left), `ab`
    (top-right), `ba` (bottom-left) and `bb` (bottom-right) into it; blocks
    that are `None` keep the default values.
    """
    top, bottom = slice(None, 3), slice(3, None)
    placements = (
        (top, top, aa),
        (top, bottom, ab),
        (bottom, top, ba),
        (bottom, bottom, bb),
    )

    matrix = default((6, 6))
    for rows, cols, block in placements:
        if block is None:
            continue
        matrix = matrix.at[rows, cols].set(block)
    return matrix
71
+
72
+
73
def crm(v):
    """Spatial cross-product operator for motion vectors (Featherstone `crm`).

    For a 6-vector `v = [w; u]` (two 3-vectors), returns the 6x6 matrix
    `[[cross(w), 0], [cross(u), cross(w)]]`.
    """
    assert v.shape == (6,)
    w_skew = cross(v[:3])
    u_skew = cross(v[3:])
    return quadrants(w_skew, ba=u_skew, bb=w_skew)
76
+
77
+
78
def crf(v):
    """Spatial cross-product operator for force vectors, i.e. `-crm(v).T`."""
    motion_operator = crm(v)
    return -motion_operator.T
80
+
81
+
82
def _rotxyz(E):
    """Place the 3x3 matrix `E` on both diagonal blocks of a 6x6 matrix."""
    return quadrants(aa=E, bb=E)
84
+
85
+
86
def rotx(theta):
    """6x6 spatial coordinate transform for a rotation by `theta` about x."""
    E = rx(theta)
    return _rotxyz(E)
88
+
89
+
90
def roty(theta):
    """6x6 spatial coordinate transform for a rotation by `theta` about y."""
    E = ry(theta)
    return _rotxyz(E)
92
+
93
+
94
def rotz(theta):
    """6x6 spatial coordinate transform for a rotation by `theta` about z."""
    E = rz(theta)
    return _rotxyz(E)
96
+
97
+
98
def xlt(r):
    """6x6 spatial coordinate transform for a translation by the 3-vector `r`.

    Returns `[[1, 0], [-cross(r), 1]]` in 3x3 blocks.
    """
    assert r.shape == (3,)
    identity = jnp.eye(3)
    return quadrants(aa=identity, ba=-cross(r), bb=identity)
101
+
102
+
103
def X_transform(E, r):
    """6x6 spatial transform `_rotxyz(E) @ xlt(r)`.

    Applied to a spatial vector, the translation by `r` acts first, followed
    by the rotation `E`.
    """
    rotation = _rotxyz(E)
    translation = xlt(r)
    return rotation @ translation
105
+
106
+
107
def mcI(m, c, Ic):
    """6x6 spatial inertia from mass `m`, 3-vector `c` and 3x3 matrix `Ic`
    (Featherstone `mcI`).

    Blocks: `[[Ic - m*cross(c)@cross(c), m*cross(c)], [-m*cross(c), m*1]]`.
    """
    assert c.shape == (3,)
    assert Ic.shape == (3, 3)
    c_skew = cross(c)
    top_left = Ic - m * c_skew @ c_skew
    top_right = m * c_skew
    bottom_left = -m * c_skew
    bottom_right = m * jnp.eye(3)
    return quadrants(top_left, top_right, bottom_left, bottom_right)
113
+
114
+
115
def XtoV(X):
    """Extract a displacement 6-vector from a 6x6 spatial transform `X`.

    Each entry is half the difference of an antisymmetric element pair of `X`.
    For `X = eye(6) - crm(v)` this recovers `v`. The result is a column of
    shape (6, 1), not a flat (6,) vector.
    """
    assert X.shape == (6, 6)
    element_pairs = (
        ((1, 2), (2, 1)),
        ((2, 0), (0, 2)),
        ((0, 1), (1, 0)),
        ((4, 2), (5, 1)),
        ((5, 0), (3, 2)),
        ((3, 1), (4, 0)),
    )
    column = [[X[first] - X[second]] for first, second in element_pairs]
    return 0.5 * jnp.array(column)
@@ -0,0 +1,5 @@
1
+ from .delete_sys import delete_subsystem
2
+ from .delete_sys import make_sys_noimu
3
+ from .inject_sys import inject_system
4
+ from .morph_sys import identify_system
5
+ from .morph_sys import morph_system
@@ -0,0 +1,114 @@
1
+ from typing import Optional
2
+
3
+ import jax.numpy as jnp
4
+ from ring import base
5
+ import tree_utils
6
+
7
+
8
def _autodetermine_imu_names(sys: base.System) -> list[str]:
    """Return the IMU link names reported by `sys.findall_imus()`."""
    imu_names = sys.findall_imus()
    return imu_names
10
+
11
+
12
def make_sys_noimu(sys: base.System, imu_link_names: Optional[list[str]] = None):
    """Delete IMU links from `sys` and record where they were attached.

    Args:
        sys: System containing the IMU links.
        imu_link_names: Names of the IMU links to remove. If `None`, they are
            determined with `sys.findall_imus()`.

    Returns:
        Tuple `(sys_noimu, imu_attachment)`. `imu_attachment` maps each IMU
        link name to its parent link name (looked up before deletion), e.g.
        imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}.
    """
    names = (
        _autodetermine_imu_names(sys) if imu_link_names is None else imu_link_names
    )

    imu_attachment = {}
    for imu_name in names:
        imu_attachment[imu_name] = sys.parent_name(imu_name)

    return delete_subsystem(sys, names), imu_attachment
19
+
20
+
21
def delete_subsystem(
    sys: base.System, link_name: str | list[str], strict: bool = True
) -> base.System:
    """Cut subsystem starting at `link_name` (inclusive) from tree.

    Args:
        sys: System to cut from.
        link_name: Name of the root link of the subsystem to delete, or a list
            of such names which are deleted one after another.
        strict: If `True`, a `link_name` that is not in `sys.link_names`
            raises an `AssertionError`. If `False`, such a name is ignored and
            the system is returned unchanged for it.

    Returns:
        base.System: New system without the deleted links (passed through
        `System.parse()`).

    Raises:
        AssertionError: If `strict` and `link_name` is not in `sys.link_names`.
    """
    if isinstance(link_name, list):
        for ln in link_name:
            sys = delete_subsystem(sys, ln, strict)
        return sys

    # Only a *missing* link is tolerated in non-strict mode; other failures
    # during deletion are no longer silently swallowed. An explicit raise
    # (instead of `assert`) keeps the check under `python -O`, with the same
    # exception type as before.
    if link_name not in sys.link_names:
        if not strict:
            return sys
        raise AssertionError(f"link {link_name} not found in {sys.link_names}")

    subsys = _find_subsystem_indices(sys.link_parents, sys.name_to_idx(link_name))
    idx_map, keep = _idx_map_and_keepers(sys.link_parents, subsys)
    # `keep` stays an ordered list (needed for `tree_indices`); the set gives
    # O(1) membership tests
    keep_set = set(keep)

    def take(elements):
        return [ele for i, ele in enumerate(elements) if i in keep_set]

    d, a, ss, sz = [], [], [], []

    def filter_arrays(_, __, damp, arma, stiff, zero, i: int):
        if i in keep_set:
            d.append(damp)
            a.append(arma)
            ss.append(stiff)
            sz.append(zero)

    sys.scan(
        filter_arrays,
        "dddql",
        sys.link_damping,
        sys.link_armature,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
        list(range(sys.num_links())),
    )

    d, a, ss, sz = map(jnp.concatenate, (d, a, ss, sz))

    new_sys = base.System(
        link_parents=_reindex_parent_array(sys.link_parents, subsys),
        links=tree_utils.tree_indices(sys.links, jnp.array(keep, dtype=int)),
        link_types=take(sys.link_types),
        link_damping=d,
        link_armature=a,
        link_spring_stiffness=ss,
        link_spring_zeropoint=sz,
        dt=sys.dt,
        geoms=[
            geom.replace(link_idx=idx_map[geom.link_idx])
            for geom in sys.geoms
            if geom.link_idx in keep_set
        ],
        gravity=sys.gravity,
        integration_method=sys.integration_method,
        mass_mat_iters=sys.mass_mat_iters,
        link_names=take(sys.link_names),
        model_name=sys.model_name,
        omc=take(sys.omc),
    )

    return new_sys.parse()
90
+
91
+
92
+ def _find_subsystem_indices(parents: list[int], k: int) -> list[int]:
93
+ subsys = [k]
94
+ for i, p in enumerate(parents):
95
+ if p in subsys:
96
+ subsys.append(i)
97
+ return subsys
98
+
99
+
100
+ def _idx_map_and_keepers(parents: list[int], subsys: list[int]):
101
+ num_links = len(parents)
102
+ # keep must be in ascending order
103
+ keep = []
104
+ for i in range(num_links):
105
+ if i not in subsys:
106
+ keep.append(i)
107
+
108
+ idx_map = dict(zip([-1] + keep, range(-1, len(keep))))
109
+ return idx_map, keep
110
+
111
+
112
+ def _reindex_parent_array(parents: list[int], subsys: list[int]) -> list[int]:
113
+ idx_map, keep = _idx_map_and_keepers(parents, subsys)
114
+ return [idx_map[p] for i, p in enumerate(parents) if i in keep]
@@ -0,0 +1,110 @@
1
+ from typing import Optional
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from ring import base
6
+ from tree_utils import tree_batch
7
+
8
+
9
+ def _tree_nan_like(tree, repeats: int):
10
+ return jax.tree_map(
11
+ lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
12
+ )
13
+
14
+
15
# TODO
# right now this function won't really keep index ordering
# as one might expect.
# It simply appends the `sub_sys` at index end of `sys`, even
# though it might be injected in the middle of the index range
# of `sys`.
# This will be fixed once we have a `dump_sys_to_xml` function
def inject_system(
    sys: base.System,
    sub_sys: base.System,
    at_body: Optional[str] = None,
) -> base.System:
    """Combine two systems into one.

    The links of `sub_sys` are appended after all links of `sys`. Links (and
    geoms) of `sub_sys` that were attached to its worldbody are attached to
    `at_body` instead. Joint-parameter types present in only one of the two
    systems are filled with NaNs for the other system's links. The input
    systems (including their `links.joint_params` dicts) are not modified.

    Args:
        sys (base.System): Large system.
        sub_sys (base.System): Small system that will be included into the
            large system `sys`.
        at_body (Optional[str], optional): Into which body of the large system
            the small system will be included. Defaults to `worldbody`.

    Returns:
        base.System: The combined system (passed through `System.parse()`).
    """

    # index in `sys` that replaces the worldbody of `sub_sys`
    if at_body is None:
        new_world = -1
    else:
        new_world = sys.name_to_idx(at_body)

    # append sub_sys at index end and replace sub_sys world with `at_body`
    N = sys.num_links()

    def new_parent(old_parent: int):
        if old_parent != -1:
            return old_parent + N
        else:
            return new_world

    sub_sys = sub_sys.replace(
        link_parents=[new_parent(p) for p in sub_sys.link_parents]
    )

    # replace link indices of geoms in sub_sys
    sub_sys = sub_sys.replace(
        geoms=[
            geom.replace(link_idx=new_parent(geom.link_idx)) for geom in sub_sys.geoms
        ]
    )

    # build union of two joint_params dictionaries because each system might
    # have custom joints that the other does not have. Work on copies: writing
    # into `sys.links.joint_params` directly would mutate the caller's systems.
    sys_joint_params = dict(sys.links.joint_params)
    subsys_joint_params = dict(sub_sys.links.joint_params)

    missing_in_sys = set(subsys_joint_params) - set(sys_joint_params)
    sys_n_links = sys.num_links()
    for typ in missing_in_sys:
        sys_joint_params[typ] = _tree_nan_like(subsys_joint_params[typ], sys_n_links)

    missing_in_subsys = set(sys_joint_params) - set(subsys_joint_params)
    subsys_n_links = sub_sys.num_links()
    for typ in missing_in_subsys:
        subsys_joint_params[typ] = _tree_nan_like(
            sys_joint_params[typ], subsys_n_links
        )

    sys_links = sys.links.replace(joint_params=sys_joint_params)
    subsys_links = sub_sys.links.replace(joint_params=subsys_joint_params)

    # merge two systems
    def concat(a1, a2):
        return tree_batch([a1, a2], True, "jax")

    combined_sys = base.System(
        link_parents=sys.link_parents + sub_sys.link_parents,
        links=concat(sys_links, subsys_links),
        link_types=sys.link_types + sub_sys.link_types,
        link_damping=concat(sys.link_damping, sub_sys.link_damping),
        link_armature=concat(sys.link_armature, sub_sys.link_armature),
        link_spring_stiffness=concat(
            sys.link_spring_stiffness, sub_sys.link_spring_stiffness
        ),
        link_spring_zeropoint=concat(
            sys.link_spring_zeropoint, sub_sys.link_spring_zeropoint
        ),
        dt=sys.dt,
        geoms=sys.geoms + sub_sys.geoms,
        gravity=sys.gravity,
        integration_method=sys.integration_method,
        mass_mat_iters=sys.mass_mat_iters,
        link_names=sys.link_names + sub_sys.link_names,
        model_name=sys.model_name,
        omc=sys.omc + sub_sys.omc,
    )

    return combined_sys.parse()