imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,222 @@
1
+ import warnings
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from ring import base
6
+ from ring import io
7
+ import tree_utils
8
+
9
+
10
def imu_reference_link_name(imu_link_name: str) -> str:
    """Return the name of the reference link belonging to an IMU link.

    The reference link name is the IMU link name with a single leading
    underscore, e.g. ``"imu1"`` -> ``"_imu1"``.
    """
    prefix = "_"
    return prefix + imu_link_name
12
+
13
+
14
def unactuated_subsystem(sys) -> list[str]:
    """Return the reference-link name of every IMU found in `sys`.

    The names are produced by `imu_reference_link_name`, in the order returned
    by `sys.findall_imus()`.
    """
    return list(map(imu_reference_link_name, sys.findall_imus()))
16
+
17
+
18
def _subsystem_factory(imu_name: str, pos_min_max: float) -> base.System:
    """Build the single-body system that is injected below each IMU.

    The system has one body named `imu_name`, attached to the world by a `p3d`
    joint with `spring_stiff="50 50 50"` and `damping="5 5 5"`. If
    `pos_min_max` is non-zero, the body also gets `pos_min="-v -v -v"` and
    `pos_max="v v v"` attributes, where `v = pos_min_max`.

    Args:
        imu_name: Name of the single body.
        pos_min_max: Non-negative bound written to `pos_min`/`pos_max`.

    Returns:
        The system parsed by `io.load_sys_from_str`.
    """
    assert pos_min_max >= 0
    pos = f'pos_min="-{pos_min_max} -{pos_min_max} -{pos_min_max}" pos_max="{pos_min_max} {pos_min_max} {pos_min_max}"'  # noqa: E501
    stiff = 'spring_stiff="50 50 50"'
    damping = 'damping="5 5 5"'
    # the position limits are only written to the XML if they are non-zero
    return io.load_sys_from_str(
        f"""
        <x_xy>
            <worldbody>
                <body name="{imu_name}" joint="p3d" {pos if pos_min_max != 0.0 else ""} {stiff} {damping}/>
            </worldbody>
        </x_xy>
        """  # noqa: E501
    )
32
+
33
+
34
def inject_subsystems(
    sys: base.System,
    pos_min_max: float = 0.0,
    **kwargs,
) -> base.System:
    """Give every IMU of `sys` a compliant mounting via injected sub-systems.

    For each IMU `imu` returned by `sys.findall_imus()`:
      1. its joint is unfrozen to a `spherical` joint, with spring stiffness
         0.3 and damping 0.03 on each of its degrees of freedom;
      2. the link is renamed to `imu_reference_link_name(imu)` (i.e. `_imu`);
      3. a new body named `imu` (built by `_subsystem_factory`) is injected
         below `_imu`.

    Geoms that were attached to the original IMU links are moved to the newly
    injected `imu` bodies. Then all `sys.links.joint_params` are zeroed (a
    warning is emitted), and the system is round-tripped through its XML string
    representation.

    Args:
        sys: The system to modify.
        pos_min_max: Forwarded to `_subsystem_factory` for every injected body.
        **kwargs: Accepted but not used by this function.

    Returns:
        The modified system.
    """
    # Link indices of the IMUs in the *unmodified* system. Renaming and
    # injecting below do not move these links, so the indices can later be
    # used to find the geoms that must be re-attached.
    imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}

    default_spher_stif = jnp.ones((3,)) * 0.3
    default_spher_damp = default_spher_stif * 0.1
    # `sys.findall_imus()` is evaluated once, on the unmodified system
    for imu in sys.findall_imus():
        sys = sys.unfreeze(imu, "spherical")
        # set default stiffness and damping of spherical joint
        # this won't override anything because the frozen joint can not have any values
        qd_slice = sys.idx_map("d")[imu]
        stiffne = sys.link_spring_stiffness.at[qd_slice].set(default_spher_stif)
        damping = sys.link_damping.at[qd_slice].set(default_spher_damp)
        sys = sys.replace(link_spring_stiffness=stiffne, link_damping=damping)

        # rename the original IMU link to `_imu`, then inject a new body that
        # takes over the name `imu`
        _imu = imu_reference_link_name(imu)
        sys = sys.change_link_name(imu, _imu)
        sys = sys.inject_system(_subsystem_factory(imu, pos_min_max), _imu)

    # attach geoms to newly injected link
    new_geoms = []

    for geom in sys.geoms:
        if geom.link_idx in imu_idx_to_name_map:
            # after the renaming above, `imu_name` refers to the injected body
            imu_name = imu_idx_to_name_map[geom.link_idx]
            new_link_idx = sys.name_to_idx(imu_name)
            geom = geom.replace(link_idx=new_link_idx)
        new_geoms.append(geom)

    sys = sys.replace(geoms=new_geoms)

    # TODO investigate whether this parse is needed; I don't think so
    # re-calculate the inertia matrices because the geoms have been re-attached
    sys = sys.parse()

    # TODO set all joint_params to zeros; they can not be preserved anyway and
    # would otherwise raise many warnings;
    # instead warn explicitly once now and move on
    warnings.warn(
        "`sys.links.joint_params` has been set to zero, this might lead to "
        "unexpected behaviour unless you use `randomize_joint_params`"
    )
    joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
    sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))

    # double load; this fixes the issue that injected links got appended at the end
    sys = io.load_sys_from_str(io.save_sys_to_str(sys))

    return sys
86
+
87
+
88
# Bounds of the log-uniformly sampled spring stiffness of the spherical joint of
# the `_imu` reference links (non-rigid case).
_STIF_MIN_SPH = 0.2
_STIF_MAX_SPH = 10.0
# Bounds of the log-uniformly sampled spring stiffness of the `p3d` joint of the
# injected `imu` bodies (non-rigid case).
_STIF_MIN_P3D = 25.0
_STIF_MAX_P3D = 1e3
# damping = factor * stiffness
# Bounds of the log-uniformly sampled damping factor (non-rigid case).
_DAMP_MIN = 0.05
_DAMP_MAX = 0.5
95
+
96
+
97
+ def _log_uniform(key, shape, minval, maxval):
98
+ assert 0 <= minval <= maxval
99
+ minval, maxval = map(jnp.log, (minval, maxval))
100
+ return jnp.exp(jax.random.uniform(key, shape, minval=minval, maxval=maxval))
101
+
102
+
103
def setup_fn_randomize_damping_stiffness_factory(
    prob_rigid: float,
    all_imus_either_rigid_or_flex: bool,
    imus_surely_rigid: list[str],
):
    """Create a setup function that randomizes the compliance of IMU mountings.

    The returned function expects a system in which every IMU `imu` hangs off a
    reference link `imu_reference_link_name(imu)` (as produced by
    `inject_subsystems`). For every IMU it overwrites the entries of
    `sys.link_spring_stiffness` and `sys.link_damping` that belong to the
    degrees of freedom of `_imu` (spherical joint, first three values) and
    `imu` (p3d joint, last three values):

    * rigid: stiffness 200 (spherical) / 2e4 (p3d), damping = 0.2 * stiffness;
    * non-rigid: stiffness drawn log-uniformly from
      [_STIF_MIN_SPH, _STIF_MAX_SPH] (spherical) and
      [_STIF_MIN_P3D, _STIF_MAX_P3D] (p3d); damping = factor * stiffness with
      the factor drawn log-uniformly from [_DAMP_MIN, _DAMP_MAX].

    Args:
        prob_rigid: Probability in [0, 1) that an IMU is rigidly attached.
            If 0, every IMU is non-rigid.
        all_imus_either_rigid_or_flex: If True, one single draw decides for all
            IMUs not listed in `imus_surely_rigid` whether they are rigid;
            otherwise every IMU gets its own draw.
        imus_surely_rigid: IMUs that are always rigid. Must be empty if
            `prob_rigid` is 0, and every entry must be an IMU of the system
            passed to the returned function.

    Returns:
        A function `(key, sys) -> sys`.
    """
    assert 0 <= prob_rigid <= 1
    assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
    if prob_rigid == 0.0:
        assert len(imus_surely_rigid) == 0

    def stif_damp_rigid(key):
        # `key` is unused; the signature matches `stif_damp_nonrigid` so that
        # both functions can be used as branches of `jax.lax.cond`
        stif_sph = 200.0 * jnp.ones((3,))
        stif_p3d = 2e4 * jnp.ones((3,))
        stif = jnp.concatenate((stif_sph, stif_p3d))
        return stif, stif * 0.2

    def stif_damp_nonrigid(key):
        keys = jax.random.split(key, 3)
        stif_sph = _log_uniform(keys[0], (3,), _STIF_MIN_SPH, _STIF_MAX_SPH)
        stif_p3d = _log_uniform(keys[1], (3,), _STIF_MIN_P3D, _STIF_MAX_P3D)
        stif = jnp.concatenate((stif_sph, stif_p3d))
        damp = _log_uniform(keys[2], (6,), _DAMP_MIN, _DAMP_MAX)
        return stif, stif * damp

    def setup_fn_randomize_damping_stiffness(key, sys: base.System) -> base.System:
        link_damping = sys.link_damping
        link_spring_stiffness = sys.link_spring_stiffness

        idx_map = sys.idx_map("d")
        imus = sys.findall_imus()

        # initialize this RV because it might not get redrawn if
        # `all_imus_either_rigid_or_flex` is set
        key, consume = jax.random.split(key)
        is_rigid = jax.random.bernoulli(consume, prob_rigid)

        # only used for the consistency check after the loop
        triggered_surely_rigid = []

        for imu in imus:
            # `_imu` has the spherical joint and `imu` the p3d joint; together
            # they provide the six entries that are set below
            dof_idxs = jnp.r_[idx_map[imu_reference_link_name(imu)], idx_map[imu]]
            key, c1, c2 = jax.random.split(key, 3)

            if prob_rigid > 0:
                if imu in imus_surely_rigid:
                    triggered_surely_rigid.append(imu)
                    stif, damp = stif_damp_rigid(c2)
                else:
                    if not all_imus_either_rigid_or_flex:
                        is_rigid = jax.random.bernoulli(c1, prob_rigid)
                    stif, damp = jax.lax.cond(
                        is_rigid, stif_damp_rigid, stif_damp_nonrigid, c2
                    )
            else:
                stif, damp = stif_damp_nonrigid(c2)
            link_spring_stiffness = link_spring_stiffness.at[dof_idxs].set(stif)
            link_damping = link_damping.at[dof_idxs].set(damp)

        # every IMU listed as surely rigid must exist in `sys`
        assert len(imus_surely_rigid) == len(triggered_surely_rigid)
        for imu_surely_rigid in imus_surely_rigid:
            assert imu_surely_rigid in triggered_surely_rigid

        return sys.replace(
            link_damping=link_damping, link_spring_stiffness=link_spring_stiffness
        )

    return setup_fn_randomize_damping_stiffness
172
+
173
+
174
def _match_q_x_between_sys(
    sys_small: base.System,
    q_large: jax.Array,
    x_large: base.Transform,
    sys_large: base.System,
    q_large_skip: list[str],
) -> tree_utils.PyTree:
    """Restrict `q_large` / `x_large`, recorded on `sys_large`, to `sys_small`.

    Links are matched by name. For every link of `sys_small`, visited via
    `sys_small.scan`, the matching entry along axis 1 of `x_large` is selected.
    The matching `q` columns of `q_large` are selected too, unless the link name
    is in `q_large_skip`.

    Args:
        sys_small: System whose links define the output.
        q_large: 2D array whose second axis has size `sys_large.q_size()`.
        x_large: Transforms whose axis 1 has size `sys_large.num_links()`.
        sys_large: System on which `q_large` and `x_large` were recorded.
        q_large_skip: Link names whose `q` entries are not copied.

    Returns:
        The tuple `(q_small, x_small)`.
    """
    assert q_large.ndim == 2
    assert q_large.shape[1] == sys_large.q_size()
    assert x_large.shape(1) == sys_large.num_links()

    link_indices_in_large = []
    q_columns = []
    large_q_slices = sys_large.idx_map("q")

    def collect(_, __, link_name: str):
        link_indices_in_large.append(sys_large.name_to_idx(link_name))
        # The IMU links have a spherical joint in `sys_large` but a frozen joint
        # in `sys_small` (no q entries there), so their q slices are skipped.
        if link_name not in q_large_skip:
            q_columns.append(q_large[:, large_q_slices[link_name]])

    sys_small.scan(collect, "l", sys_small.link_names)

    x_small = tree_utils.tree_indices(
        x_large, jnp.array(link_indices_in_large), axis=1
    )
    q_small = jnp.concatenate(q_columns, axis=1)
    return q_small, x_small
203
+
204
+
205
class GeneratorTrafoHideInjectedBodies:
    """Generator transformation that removes the bodies added by
    `inject_subsystems` from the generator's auxiliary output.

    The wrapped generator must return `(X, y), (key, q, x, sys_x)`. The
    transformed generator returns the same structure, but with `sys_x` replaced
    by a system in which:
      * the IMU bodies of `sys_x` have been deleted;
      * the reference links `_imu` are renamed back to `imu` and frozen.
    `q` and `x` are restricted to the links of that system.
    """

    def __call__(self, gen):
        def _gen(*args):
            (X, y), (key, q, x, sys_x) = gen(*args)

            injected_imus = sys_x.findall_imus()
            sys = sys_x.delete_system(injected_imus)
            # rename `_imu` back to `imu` and freeze its joint
            for imu in injected_imus:
                reference_name = imu_reference_link_name(imu)
                sys = sys.change_link_name(reference_name, imu)
                sys = sys.change_joint_type(imu, "frozen")

            # match q and x to `sys`; the second axis is the link axis
            q, x = _match_q_x_between_sys(
                sys, q, x, sys_x, q_large_skip=injected_imus
            )
            return (X, y), (key, q, x, sys)

        return _gen
@@ -0,0 +1,182 @@
1
+ from types import SimpleNamespace
2
+ from typing import Optional
3
+
4
+ from flax import struct
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from ring import base
8
+ from ring.algorithms import dynamics
9
+ from ring.algorithms import jcalc
10
+
11
+
12
@struct.dataclass
class PDControllerState:
    """State of the controller created by `_pd_control`, passed between calls."""

    # index of the reference time step used by the next `apply` call
    i: int
    # link name -> reference positions of that link, time along the first axis
    q_ref_as_dict: dict
    # link name -> reference velocities from `qd_from_q`; empty without D gains
    qd_ref_as_dict: dict
    # link name -> P gains of that link's degrees of freedom
    P_gains: dict
    # link name -> D gains of that link's degrees of freedom; empty without D gains
    D_gains: dict
19
+
20
+
21
def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
    """Computes tau using a PD controller. Returns a pair of (init, apply) functions.

    NOTE: Gains around ~10_000 are good for spherical joints, everything else ~250-300
    works just fine. Damping should be about 2500 for spherical joints, and
    about 25 for everything else.

    Args:
        P: jax.Array of P gains. Shape: (sys_init.qd_size())
        D: jax.Array of D gains. Shape: (sys_init.qd_size()) where `sys_init` is the
            system that recorded the reference trajectory `q_ref`.
            If not given, then no D control is applied.

    Returns: Namespace with the two functions
        init: (sys, q_ref) -> controller_state
        apply: (controller_state, sys, state) -> (controller_state, tau)

    Raises:
        AssertionError: In `init`, if `q_ref`, `P` or `D` do not match the sizes
            of `sys`.
        NotImplementedError: If a joint type lacks `JointModel.qd_from_q` (only
            checked when `D` is given) or `JointModel.p_control_term`.

    Example:
        >>> gains = jnp.array([250.0] * sys1.qd_size())
        >>> controller = _pd_control(gains, gains)
        >>> q_ref = rcmg(sys1)
        >>> cs = controller.init(sys1, q_ref)
        >>> for t in range(1000):
        >>>     cs, tau = controller.apply(cs, sys2, state)
        >>>     state = dynamics.step(sys2, state, tau)
    """

    def init(sys: base.System, q_ref: jax.Array) -> PDControllerState:
        """Split the reference trajectory and the gains per link of `sys`."""
        assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
        assert sys.qd_size() == P.size
        if D is not None:
            # This used to be a bare comparison whose result was discarded, so a
            # wrongly sized `D` went unnoticed.
            assert sys.qd_size() == D.size

        q_ref_as_dict = {}
        qd_ref_as_dict = {}
        P_as_dict = {}
        D_as_dict = {}

        def f(_, __, q_ref_link, name, typ, P_link, D_link):
            P_as_dict[name] = P_link
            # `q_ref.T` is scanned, so the per-link slice presumably has the q
            # axis first; transpose back so time is the leading axis
            q_ref_link = q_ref_link.T
            q_ref_as_dict[name] = q_ref_link

            if D is not None:
                qd_from_q = jcalc.get_joint_model(typ).qd_from_q
                if qd_from_q is None:
                    raise NotImplementedError(
                        f"Please specify `JointModel.qd_from_q` for joint type `{typ}`"
                    )
                qd_ref_as_dict[name] = qd_from_q(q_ref_link, sys.dt)
                D_as_dict[name] = D_link

        sys.scan(
            f,
            "qlldd",
            q_ref.T,
            sys.link_names,
            sys.link_types,
            P,
            D if D is not None else jnp.zeros((sys.qd_size(),)),
        )
        return PDControllerState(0, q_ref_as_dict, qd_ref_as_dict, P_as_dict, D_as_dict)

    def apply(
        controller_state: PDControllerState, sys: base.System, state: base.State
    ) -> tuple[PDControllerState, jax.Array]:
        """Compute the torques for reference step `controller_state.i`.

        Returns the controller state with the step counter advanced by one, and
        the torque vector of size `sys.qd_size()`.
        """
        taus = jnp.zeros((sys.qd_size(),))
        # select the reference at time step `i` for every link
        q_ref, qd_ref = jax.tree_util.tree_map(
            lambda arr: jax.lax.dynamic_index_in_dim(
                arr, controller_state.i, keepdims=False
            ),
            (controller_state.q_ref_as_dict, controller_state.qd_ref_as_dict),
        )

        def f(_, idx_map, idx, name, typ, q_curr, qd_curr):
            nonlocal taus

            # links without a reference trajectory receive zero torque
            if name not in controller_state.q_ref_as_dict:
                return

            p_control_term = jcalc.get_joint_model(typ).p_control_term
            if p_control_term is None:
                raise NotImplementedError(
                    f"Please specify `JointModel.p_control_term` for joint type `{typ}`"
                )
            P_term = p_control_term(q_curr, q_ref[name])
            tau = P_term * controller_state.P_gains[name]

            if name in controller_state.qd_ref_as_dict:
                D_term = (qd_ref[name] - qd_curr) * controller_state.D_gains[name]
                tau += D_term

            taus = taus.at[idx_map["d"](idx)].set(tau)

        sys.scan(
            f,
            "lllqd",
            list(range(sys.num_links())),
            sys.link_names,
            sys.link_types,
            state.q,
            state.qd,
        )

        return controller_state.replace(i=controller_state.i + 1), taus

    return SimpleNamespace(init=init, apply=apply)
128
+
129
+
130
def _unroll_dynamics_pd_control(
    sys: base.System,
    q_ref: jax.Array,
    P: jax.Array,
    D: Optional[jax.Array] = None,
    nograv: bool = False,
    sys_q_ref: Optional[base.System] = None,
    initial_sim_state_is_zeros: bool = False,
    clip_taus: Optional[float] = None,
):
    """Simulate `sys` while a PD controller tracks the reference `q_ref`.

    Args:
        sys: System that is simulated.
        q_ref: Reference trajectory of `sys_q_ref`, shape
            (T, sys_q_ref.q_size()).
        P: P gains, size `sys_q_ref.qd_size()`.
        D: Optional D gains, size `sys_q_ref.qd_size()`.
        nograv: If True, the gravity of `sys` is set to zero.
        sys_q_ref: System on which `q_ref` was recorded. Defaults to `sys`.
        initial_sim_state_is_zeros: If True, start from the default state
            `base.State.create(sys)`; otherwise start from `q_ref[0]`.
        clip_taus: If given, must be > 0; torques are clipped to
            [-clip_taus, clip_taus] before every step.

    Returns:
        The states after each of the T simulation steps, stacked by
        `jax.lax.scan` along a leading axis.
    """
    assert q_ref.ndim == 2
    if clip_taus is not None:
        # validate once, up front, instead of inside the scanned step function
        assert clip_taus > 0.0

    if sys_q_ref is None:
        sys_q_ref = sys

    if nograv:
        sys = sys.replace(gravity=sys.gravity * 0.0)

    if initial_sim_state_is_zeros:
        state = base.State.create(sys)
    else:
        state = _initial_q_is_q_ref(sys, sys_q_ref, q_ref[0])

    controller = _pd_control(P, D)
    cs = controller.init(sys_q_ref, q_ref)

    def step(carry, _):
        state, cs = carry
        cs, taus = controller.apply(cs, sys, state)
        if clip_taus is not None:
            taus = jnp.clip(taus, -clip_taus, clip_taus)
        state = dynamics.step(sys, state, taus)
        return (state, cs), state

    _, states = jax.lax.scan(step, (state, cs), None, length=q_ref.shape[0])
    return states
168
+
169
+
170
def _initial_q_is_q_ref(sys: base.System, sys_q_ref: base.System, q_ref):
    """Create the initial state of `sys` from a single reference sample.

    Starts from the `q` of `base.State.create(sys)`. For every link of
    `sys_q_ref`, the `q` entries of the equally named link of `sys` are
    overwritten with the matching entries of `q_ref`.

    Args:
        sys: System to create the state for.
        sys_q_ref: System on which `q_ref` is defined.
        q_ref: Generalized coordinates of `sys_q_ref` at one time step.
    """
    # q cannot be pre-allocated with zeros because of quaternions, hence start
    # from the default state
    q_init = base.State.create(sys).q
    q_slices = sys.idx_map("q")

    def overwrite(_, __, link_name, q_link):
        nonlocal q_init
        q_init = q_init.at[q_slices[link_name]].set(q_link)

    sys_q_ref.scan(overwrite, "lq", sys_q_ref.link_names, q_ref)
    return base.State.create(sys, q=q_init)
@@ -0,0 +1,119 @@
1
+ """Randomization by modifying System and MotionConfig objects before building
2
+ generator."""
3
+
4
+ from dataclasses import replace
5
+ import itertools
6
+ from typing import Optional
7
+ import warnings
8
+
9
+ import jax.numpy as jnp
10
+ from ring import base
11
+ from ring.algorithms import jcalc
12
+ from ring.algorithms.generator import types
13
+
14
+
15
+ def _find_children(lam: list[int], body: int) -> list[int]:
16
+
17
+ children = []
18
+
19
+ def _children(body: int) -> None:
20
+ for i in range(len(lam)):
21
+ if lam[i] == body:
22
+ children.append(i)
23
+ _children(i)
24
+
25
+ _children(body)
26
+ return children
27
+
28
+
29
def _find_root_of_subsys_that_contains_body(sys: base.System, body: str) -> str:
    """Return the name of the world-attached root link whose subtree contains
    the link `body` (or `body` itself if it is such a root).

    Returns None if no root link's subtree contains `body`.
    """
    target = sys.name_to_idx(body)
    parents = sys.link_parents
    for root, parent in enumerate(parents):
        if parent != -1:
            continue
        if target == root or target in _find_children(parents, root):
            return sys.idx_to_name(root)
    return None
35
+
36
+
37
def _assign_anchors_to_subsys(sys: base.System, anchors: list[str]) -> list[list[str]]:
    """Group `anchors` by the sub-system (tree below a world-attached link) that
    contains them.

    Returns one list per root link of `sys`, in link index order. Each list
    holds the entries of `anchors` lying in that sub-system, in their order in
    `anchors`. If there are none, it holds only the name of the root link.
    """
    parents = sys.link_parents
    grouped = []
    for root, parent in enumerate(parents):
        if parent != -1:
            continue
        member_idxs = [root, *_find_children(parents, root)]
        member_names = {sys.idx_to_name(idx) for idx in member_idxs}
        matched = [name for name in anchors if name in member_names]
        grouped.append(matched if matched else [sys.idx_to_name(root)])
    return grouped
50
+
51
+
52
def _morph_extract_subsys(sys: base.System, anchor: str):
    """Extract the sub-system that contains `anchor` and re-root it there.

    Every other world-attached sub-system is deleted from `sys`. The remaining
    one is then passed through `morph_system(new_anchor=anchor)`.
    """
    root = _find_root_of_subsys_that_contains_body(sys, anchor)
    roots = sys.findall_bodies_to_world(names=True)
    # Deduplicate while keeping the order of `findall_bodies_to_world`, rather
    # than iterating a set of strings: that order changes between interpreter
    # runs because of string hash randomization.
    other_roots = [name for name in dict.fromkeys(roots) if name != root]
    subsys = sys.delete_system(other_roots)
    return subsys.morph_system(new_anchor=anchor)
57
+
58
+
59
def randomize_anchors(
    sys: base.System, anchors: Optional[list[str]] = None
) -> list[base.System]:
    """Build one system per combination of anchors (one anchor per sub-system).

    Args:
        sys: System that may consist of several world-attached sub-systems.
        anchors: Candidate anchor link names. Defaults to
            `sys.findall_segments()`. A sub-system without any candidate is
            anchored at its root link.

    Returns:
        One system per element of the Cartesian product of the per-sub-system
        anchor lists. Each is obtained by extracting every sub-system,
        re-rooting it at its anchor, and injecting the results into the first.
    """
    candidates = sys.findall_segments() if anchors is None else anchors
    per_subsys = _assign_anchors_to_subsys(sys, candidates)

    syss = []
    for combination in itertools.product(*per_subsys):
        combined = _morph_extract_subsys(sys, combination[0])
        for other_anchor in combination[1:]:
            combined = combined.inject_system(_morph_extract_subsys(sys, other_anchor))
        syss.append(combined)
    return syss
75
+
76
+
77
# `randomize_hz` warns for every sampling rate (in Hz) below this value.
_WARN_HZ_Threshold: float = 40.0
78
+
79
+
80
def randomize_hz(
    sys: list[base.System],
    configs: list[jcalc.MotionConfig],
    sampling_rates: list[float],
) -> tuple[list[base.System], list[jcalc.MotionConfig]]:
    """Create copies of systems and motion configs for several sampling rates.

    For every combination of system `_sys`, config and sampling rate `hz` (in
    that nesting order), this creates:
      * a copy of `_sys` with `dt = 1 / hz`;
      * a copy of the config whose `T` is rescaled so that `T / dt` equals
        `T_global / _sys.dt`, i.e. the number of time steps is preserved.

    Args:
        sys: Systems to copy.
        configs: Motion configs; all must have the same `T` (`T_global`).
        sampling_rates: Sampling rates in Hz.

    Returns:
        Two lists of length `len(sys) * len(configs) * len(sampling_rates)`.

    Raises:
        AssertionError: If `configs` is empty or the configs differ in `T`.
    """
    Ts = [c.T for c in configs]
    # `len(set(Ts))` alone is truthy for any non-empty list, so mismatching
    # durations were never detected; require exactly one distinct value.
    assert len(set(Ts)) == 1, f"Time length between configs does not agree {Ts}"
    T_global = Ts[0]

    for hz in sampling_rates:
        if hz < _WARN_HZ_Threshold:
            warnings.warn(
                f"The sampling rate {hz} is below the warning threshold of "
                f"{_WARN_HZ_Threshold}. This might lead to NaNs."
            )

    sys_out, configs_out = [], []
    for _sys in sys:
        for _config in configs:
            for hz in sampling_rates:
                dt = 1 / hz
                # keep the number of time steps T / dt constant
                T = (T_global / _sys.dt) * dt

                sys_out.append(_sys.replace(dt=dt))
                configs_out.append(replace(_config, T=T))
    return sys_out, configs_out
106
+
107
+
108
def randomize_hz_finalize_fn_factory(finalize_fn_user: types.FINALIZE_FN | None):
    """Wrap an optional user finalize function so that `sys.dt` is added to X.

    The returned `finalize_fn(key, q, x, sys)` obtains `(X, y)` from
    `finalize_fn_user` if it is given, otherwise starts from two empty dicts.
    It then stores `sys.dt` as a float32 array of shape (1,) under `X["dt"]`.

    Raises:
        AssertionError: (from the returned function) if `X` already contains
            a `"dt"` entry.
    """

    def finalize_fn(key, q, x, sys: base.System):
        if finalize_fn_user is None:
            X, y = {}, {}
        else:
            X, y = finalize_fn_user(key, q, x, sys)

        assert "dt" not in X
        X["dt"] = jnp.array([sys.dt], dtype=jnp.float32)
        return X, y

    return finalize_fn