imt-ring 1.6.37__tar.gz → 1.6.39__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.37 → imt_ring-1.6.39}/PKG-INFO +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/pyproject.toml +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/custom_joints/suntay.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/dynamics.py +27 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/base.py +82 -2
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/batch.py +2 -2
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/finalize_fns.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/pd_control.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/jcalc.py +198 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/kinematics.py +2 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/sensors.py +12 -10
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/base.py +356 -27
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/xml/from_xml.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/base.py +4 -3
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/ml_utils.py +3 -3
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/ringnet.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/train.py +2 -2
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/rendering/mujoco_render.py +11 -7
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/rendering/vispy_render.py +5 -4
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/sys_composer/inject_sys.py +3 -2
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/batchsize.py +3 -3
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/dataloader.py +4 -3
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/dataloader_torch.py +14 -5
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/hdf5.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/normalizer.py +6 -5
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_ml_utils.py +1 -1
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_sim2real.py +3 -2
- {imt_ring-1.6.37 → imt_ring-1.6.39}/readme.md +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/setup.cfg +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/imt_ring.egg-info/SOURCES.txt +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algebra.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/setup_fns.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/examples.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/maths.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/spatial.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/path.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/randomize_sys.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/register_gym_envs/__init__.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/register_gym_envs/saddle.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/src/ring/utils/utils.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_algebra.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_base.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_dynamics.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_generator.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_jcalc.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_jit.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_kinematics.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_maths.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_pd_control.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_quickstart_example.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_random.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_randomize.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_rcmg.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_render.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_sensors.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_train.py +0 -0
- {imt_ring-1.6.37 → imt_ring-1.6.39}/tests/test_utils.py +0 -0
@@ -184,7 +184,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
184
184
|
|
185
185
|
suntay_link_name = _utils_find_suntay_joint(sys)
|
186
186
|
|
187
|
-
params = jax.
|
187
|
+
params = jax.tree.map(
|
188
188
|
lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
|
189
189
|
sys.links.joint_params[name],
|
190
190
|
)
|
@@ -303,7 +303,33 @@ def step(
|
|
303
303
|
taus: Optional[jax.Array] = None,
|
304
304
|
n_substeps: int = 1,
|
305
305
|
) -> base.State:
|
306
|
-
"
|
306
|
+
"""
|
307
|
+
Advances the system dynamics by a single timestep using semi-implicit Euler integration.
|
308
|
+
|
309
|
+
This function updates the system's state by integrating the equations of motion
|
310
|
+
over a timestep, potentially with multiple substeps for improved numerical stability.
|
311
|
+
The method ensures that the system's kinematics are updated before each integration step.
|
312
|
+
|
313
|
+
Args:
|
314
|
+
sys (base.System):
|
315
|
+
The system to simulate, containing link information, joint dynamics, and integration parameters.
|
316
|
+
state (base.State):
|
317
|
+
The current state of the system, including joint positions (`q`), velocities (`qd`), and transforms (`x`).
|
318
|
+
taus (Optional[jax.Array], optional):
|
319
|
+
The control torques applied to the system joints. If `None`, zero torques are applied.
|
320
|
+
Defaults to `None`.
|
321
|
+
n_substeps (int, optional):
|
322
|
+
The number of integration substeps per timestep to improve numerical accuracy.
|
323
|
+
Defaults to `1`.
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
base.State:
|
327
|
+
The updated state of the system after integration.
|
328
|
+
|
329
|
+
Raises:
|
330
|
+
AssertionError: If the system's degrees of freedom (`q` and `qd`) do not match expectations.
|
331
|
+
AssertionError: If an unsupported integration method is specified in `sys.integration_method`.
|
332
|
+
""" # noqa: E501
|
307
333
|
assert sys.q_size() == state.q.size
|
308
334
|
if taus is None:
|
309
335
|
taus = jnp.zeros_like(state.qd)
|
@@ -53,7 +53,85 @@ class RCMG:
|
|
53
53
|
cor: bool = False,
|
54
54
|
disable_tqdm: bool = False,
|
55
55
|
) -> None:
|
56
|
-
"
|
56
|
+
"""
|
57
|
+
Initializes the Random Chain Motion Generator (RCMG).
|
58
|
+
|
59
|
+
The RCMG generates synthetic joint motion sequences for kinematic and dynamic
|
60
|
+
systems based on predefined motion configurations. It allows for system
|
61
|
+
randomization, augmentation with IMU and joint axis data, and optional
|
62
|
+
dynamic simulation.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
sys (base.System | list[base.System]):
|
66
|
+
The system(s) for which motion should be generated.
|
67
|
+
config (jcalc.MotionConfig | list[jcalc.MotionConfig], optional):
|
68
|
+
Motion configuration(s) defining velocity limits, interpolation methods,
|
69
|
+
and range constraints. Defaults to `jcalc.MotionConfig()`.
|
70
|
+
setup_fn (Optional[types.SETUP_FN], optional):
|
71
|
+
A function to modify the system before motion generation. Defaults to `None`.
|
72
|
+
finalize_fn (Optional[types.FINALIZE_FN], optional):
|
73
|
+
A function to modify outputs after motion generation. Defaults to `None`.
|
74
|
+
add_X_imus (bool, optional):
|
75
|
+
Whether to add IMU sensor data to the output. Defaults to `False`.
|
76
|
+
add_X_imus_kwargs (dict, optional):
|
77
|
+
Additional keyword arguments for IMU data processing. Defaults to `{}`.
|
78
|
+
add_X_jointaxes (bool, optional):
|
79
|
+
Whether to add joint axis data to the output. Defaults to `False`.
|
80
|
+
add_X_jointaxes_kwargs (dict, optional):
|
81
|
+
Additional keyword arguments for joint axis data processing. Defaults to `{}`.
|
82
|
+
add_y_relpose (bool, optional):
|
83
|
+
Whether to add relative pose targets to the output. Defaults to `False`.
|
84
|
+
add_y_rootincl (bool, optional):
|
85
|
+
Whether to add root inclination targets to the output. Defaults to `False`.
|
86
|
+
add_y_rootincl_kwargs (dict, optional):
|
87
|
+
Additional keyword arguments for root inclination processing. Defaults to `{}`.
|
88
|
+
add_y_rootfull (bool, optional):
|
89
|
+
Whether to add full root state targets to the output. Defaults to `False`.
|
90
|
+
add_y_rootfull_kwargs (dict, optional):
|
91
|
+
Additional keyword arguments for full root state processing. Defaults to `{}`.
|
92
|
+
sys_ml (Optional[base.System], optional):
|
93
|
+
System that defines the graph and naming structure of the `X` and `y` outputs. Defaults to `None` which then uses the first provided system.
|
94
|
+
randomize_positions (bool, optional):
|
95
|
+
Whether to randomised positions based on `pos_min` and `pos_max`. Defaults to `False`.
|
96
|
+
randomize_motion_artifacts (bool, optional):
|
97
|
+
Whether to randomize the IMU motion artifact simulation. This randomises the spring stiffness and spring damping parameters of the passive free joint that is added between nonrigid and rigid IMU. Defaults to `False`.
|
98
|
+
randomize_joint_params (bool, optional):
|
99
|
+
Whether to randomize joint parameters by calling `JointModel.init_joint_params` before every sequence generation. Defaults to `False`.
|
100
|
+
randomize_hz (bool, optional):
|
101
|
+
Whether to randomize the sampling frequency of the generated data. Defaults to `False`.
|
102
|
+
randomize_hz_kwargs (dict, optional):
|
103
|
+
Additional keyword arguments for sampling frequency randomization. Defaults to `{}`.
|
104
|
+
imu_motion_artifacts (bool, optional):
|
105
|
+
Whether to simulate nonrigid IMU motion artifacts. Defaults to `False`.
|
106
|
+
imu_motion_artifacts_kwargs (dict, optional):
|
107
|
+
Additional keyword arguments for IMU motion artifact simulation. Defaults to `{}`.
|
108
|
+
dynamic_simulation (bool, optional):
|
109
|
+
Whether to use a physics-based simulation to generate motion instead of purely
|
110
|
+
kinematic methods. Defaults to `False`.
|
111
|
+
dynamic_simulation_kwargs (dict, optional):
|
112
|
+
Additional keyword arguments for dynamic simulation. Defaults to `{}`.
|
113
|
+
output_transform (Optional[Callable], optional):
|
114
|
+
A function to transform the generated output data. Defaults to `None`.
|
115
|
+
keep_output_extras (bool, optional):
|
116
|
+
Whether to keep additional output metadata. Defaults to `False`.
|
117
|
+
use_link_number_in_Xy (bool, optional):
|
118
|
+
Whether to replace joint names with numerical indices in the output. Defaults to `False`.
|
119
|
+
cor (bool, optional):
|
120
|
+
Whether to replace free joints with center-of-rotation (COR) 9D free joint. Defaults to `False`.
|
121
|
+
disable_tqdm (bool, optional):
|
122
|
+
Whether to disable progress bars during generation. Defaults to `False`.
|
123
|
+
|
124
|
+
Raises:
|
125
|
+
AssertionError: If any of the provided `MotionConfig` instances are infeasible.
|
126
|
+
|
127
|
+
Notes:
|
128
|
+
- This class supports batch generation, lazy and eager data loading, and
|
129
|
+
motion augmentation.
|
130
|
+
- If `randomize_hz=True`, the time step (`dt`) varies according to the specified
|
131
|
+
sampling rates.
|
132
|
+
- When `cor=True`, free joints are replaced with center-of-rotation models,
|
133
|
+
affecting joint motion behavior.
|
134
|
+
""" # noqa: E501
|
57
135
|
|
58
136
|
# add some default values
|
59
137
|
randomize_hz_kwargs_defaults = dict(add_dt=True)
|
@@ -139,6 +217,7 @@ class RCMG:
|
|
139
217
|
def to_lazy_gen(
|
140
218
|
self, sizes: int | list[int] = 1, jit: bool = True
|
141
219
|
) -> types.BatchedGenerator:
|
220
|
+
"Returns a generator `X, y = gen(key)` that lazily generates batched sequences."
|
142
221
|
return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
|
143
222
|
|
144
223
|
@staticmethod
|
@@ -201,7 +280,7 @@ class RCMG:
|
|
201
280
|
),
|
202
281
|
verbose: bool = True,
|
203
282
|
):
|
204
|
-
|
283
|
+
"Stores unbatched sequences as numpy arrays into folder."
|
205
284
|
i = 0
|
206
285
|
|
207
286
|
def callback(data: list[PyTree[np.ndarray]]) -> None:
|
@@ -237,6 +316,7 @@ class RCMG:
|
|
237
316
|
shuffle: bool = True,
|
238
317
|
transform=None,
|
239
318
|
) -> types.BatchedGenerator:
|
319
|
+
"Returns a generator `X, y = gen(key)` that returns precomputed batched sequences." # noqa: E501
|
240
320
|
data = self.to_list(sizes, seed)
|
241
321
|
assert len(data) >= batchsize
|
242
322
|
return self.eager_gen_from_list(data, batchsize, shuffle, transform)
|
@@ -80,11 +80,11 @@ def generators_eager(
|
|
80
80
|
# converts also to numpy; but with np.array.flags.writeable = False
|
81
81
|
sample = jax.device_get(sample)
|
82
82
|
# this then sets this flag to True
|
83
|
-
sample = jax.
|
83
|
+
sample = jax.tree.map(np.array, sample)
|
84
84
|
|
85
85
|
sample_flat, _ = jax.tree_util.tree_flatten(sample)
|
86
86
|
size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
|
87
|
-
callback([jax.
|
87
|
+
callback([jax.tree.map(lambda a: a[i].copy(), sample) for i in range(size)])
|
88
88
|
|
89
89
|
# cleanup
|
90
90
|
del sample, sample_flat
|
@@ -86,7 +86,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
|
|
86
86
|
controller_state: PDControllerState, sys: base.System, state: base.State
|
87
87
|
) -> jax.Array:
|
88
88
|
taus = jnp.zeros((sys.qd_size()))
|
89
|
-
q_ref, qd_ref = jax.
|
89
|
+
q_ref, qd_ref = jax.tree.map(
|
90
90
|
lambda arr: jax.lax.dynamic_index_in_dim(
|
91
91
|
arr, controller_state.i, keepdims=False
|
92
92
|
),
|
@@ -19,6 +19,87 @@ from ring.algorithms._random import TimeDependentFloat
|
|
19
19
|
|
20
20
|
@dataclass
|
21
21
|
class MotionConfig:
|
22
|
+
"""
|
23
|
+
Configuration for joint motion generation in kinematic and dynamic simulations.
|
24
|
+
|
25
|
+
This class defines the constraints and parameters for generating random joint motions,
|
26
|
+
including angular and positional velocity limits, interpolation methods, and range
|
27
|
+
restrictions for various joint types.
|
28
|
+
|
29
|
+
Attributes:
|
30
|
+
T (float): Total duration of the motion sequence (in seconds).
|
31
|
+
t_min (float): Minimum time interval between two generated joint states.
|
32
|
+
t_max (float | TimeDependentFloat): Maximum time interval between two generated joint states.
|
33
|
+
|
34
|
+
dang_min (float | TimeDependentFloat): Minimum angular velocity (rad/s).
|
35
|
+
dang_max (float | TimeDependentFloat): Maximum angular velocity (rad/s).
|
36
|
+
dang_min_free_spherical (float | TimeDependentFloat): Minimum angular velocity for free and spherical joints.
|
37
|
+
dang_max_free_spherical (float | TimeDependentFloat): Maximum angular velocity for free and spherical joints.
|
38
|
+
|
39
|
+
delta_ang_min (float | TimeDependentFloat): Minimum allowed change in joint angle (radians).
|
40
|
+
delta_ang_max (float | TimeDependentFloat): Maximum allowed change in joint angle (radians).
|
41
|
+
delta_ang_min_free_spherical (float | TimeDependentFloat): Minimum allowed change in angle for free/spherical joints.
|
42
|
+
delta_ang_max_free_spherical (float | TimeDependentFloat): Maximum allowed change in angle for free/spherical joints.
|
43
|
+
|
44
|
+
dpos_min (float | TimeDependentFloat): Minimum translational velocity.
|
45
|
+
dpos_max (float | TimeDependentFloat): Maximum translational velocity.
|
46
|
+
pos_min (float | TimeDependentFloat): Minimum position constraint.
|
47
|
+
pos_max (float | TimeDependentFloat): Maximum position constraint.
|
48
|
+
|
49
|
+
pos_min_p3d_x (float | TimeDependentFloat): Minimum position in x-direction for P3D joints.
|
50
|
+
pos_max_p3d_x (float | TimeDependentFloat): Maximum position in x-direction for P3D joints.
|
51
|
+
pos_min_p3d_y (float | TimeDependentFloat): Minimum position in y-direction for P3D joints.
|
52
|
+
pos_max_p3d_y (float | TimeDependentFloat): Maximum position in y-direction for P3D joints.
|
53
|
+
pos_min_p3d_z (float | TimeDependentFloat): Minimum position in z-direction for P3D joints.
|
54
|
+
pos_max_p3d_z (float | TimeDependentFloat): Maximum position in z-direction for P3D joints.
|
55
|
+
|
56
|
+
cdf_bins_min (int): Minimum number of bins for cumulative distribution function (CDF)-based random sampling.
|
57
|
+
cdf_bins_max (Optional[int]): Maximum number of bins for CDF-based sampling.
|
58
|
+
|
59
|
+
randomized_interpolation_angle (bool): Whether to use randomized interpolation for angular motion.
|
60
|
+
randomized_interpolation_position (bool): Whether to use randomized interpolation for positional motion.
|
61
|
+
interpolation_method (str): Interpolation method to be used (default: "cosine").
|
62
|
+
|
63
|
+
range_of_motion_hinge (bool): Whether to enforce range-of-motion constraints on hinge joints.
|
64
|
+
range_of_motion_hinge_method (str): Method used for range-of-motion enforcement (e.g., "uniform", "sigmoid").
|
65
|
+
|
66
|
+
rom_halfsize (float | TimeDependentFloat): Half-size of the range of motion restriction.
|
67
|
+
|
68
|
+
ang0_min (float): Minimum initial joint angle.
|
69
|
+
ang0_max (float): Maximum initial joint angle.
|
70
|
+
pos0_min (float): Minimum initial joint position.
|
71
|
+
pos0_max (float): Maximum initial joint position.
|
72
|
+
|
73
|
+
cor_t_min (float): Minimum time step for center-of-rotation (COR) joints.
|
74
|
+
cor_t_max (float | TimeDependentFloat): Maximum time step for COR joints.
|
75
|
+
cor_dpos_min (float | TimeDependentFloat): Minimum velocity for COR translation.
|
76
|
+
cor_dpos_max (float | TimeDependentFloat): Maximum velocity for COR translation.
|
77
|
+
cor_pos_min (float | TimeDependentFloat): Minimum position for COR translation.
|
78
|
+
cor_pos_max (float | TimeDependentFloat): Maximum position for COR translation.
|
79
|
+
cor_pos0_min (float): Initial minimum position for COR translation.
|
80
|
+
cor_pos0_max (float): Initial maximum position for COR translation.
|
81
|
+
|
82
|
+
joint_type_specific_overwrites (dict[str, dict[str, Any]]):
|
83
|
+
A dictionary mapping joint types to specific motion configuration overrides.
|
84
|
+
|
85
|
+
Methods:
|
86
|
+
is_feasible:
|
87
|
+
Checks if the motion configuration satisfies all constraints.
|
88
|
+
|
89
|
+
to_nomotion_config:
|
90
|
+
Returns a new `MotionConfig` where all velocities and angle changes are set to zero.
|
91
|
+
|
92
|
+
overwrite_for_joint_type:
|
93
|
+
Applies specific configuration changes for a given joint type.
|
94
|
+
Note: These changes affect all instances of `MotionConfig` for this joint type.
|
95
|
+
|
96
|
+
overwrite_for_subsystem:
|
97
|
+
Modifies the motion configuration for all joints in a subsystem rooted at `link_name`.
|
98
|
+
|
99
|
+
from_register:
|
100
|
+
Retrieves a predefined `MotionConfig` from the global registry.
|
101
|
+
""" # noqa: E501
|
102
|
+
|
22
103
|
T: float = 60.0 # length of random motion
|
23
104
|
t_min: float = 0.05 # min time between two generated angles
|
24
105
|
t_max: float | TimeDependentFloat = 0.30 # max time ..
|
@@ -412,6 +493,30 @@ def _find_interval(t: jax.Array, boundaries: jax.Array):
|
|
412
493
|
def join_motionconfigs(
|
413
494
|
configs: list[MotionConfig], boundaries: list[float]
|
414
495
|
) -> MotionConfig:
|
496
|
+
"""
|
497
|
+
Joins multiple `MotionConfig` objects in time, transitioning between them at specified boundaries.
|
498
|
+
|
499
|
+
This function takes a list of `MotionConfig` instances and a corresponding list of boundary times,
|
500
|
+
and constructs a new `MotionConfig` that varies in time according to the provided segments.
|
501
|
+
|
502
|
+
Args:
|
503
|
+
configs (list[MotionConfig]): A list of `MotionConfig` objects to be joined.
|
504
|
+
boundaries (list[float]): A list of time values where transitions between `configs` occur.
|
505
|
+
Must have one element less than `configs`, as each boundary defines the transition point
|
506
|
+
between two consecutive configurations.
|
507
|
+
|
508
|
+
Returns:
|
509
|
+
MotionConfig: A new `MotionConfig` object where time-dependent fields transition based on the
|
510
|
+
specified boundaries.
|
511
|
+
|
512
|
+
Raises:
|
513
|
+
AssertionError: If the number of boundaries does not match `len(configs) - 1`.
|
514
|
+
AssertionError: If time-independent fields have differing values across `configs`.
|
515
|
+
|
516
|
+
Notes:
|
517
|
+
- Only fields that are time-dependent (`float | TimeDependentFloat`) will change over time.
|
518
|
+
- Time-independent fields must be the same in all `configs`, or an error is raised.
|
519
|
+
""" # noqa: E501
|
415
520
|
# to avoid a circular import due to `ring.utils.randomize_sys` importing `jcalc`
|
416
521
|
from ring.utils import tree_equal
|
417
522
|
|
@@ -517,6 +622,55 @@ INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
|
|
517
622
|
|
518
623
|
@dataclass
|
519
624
|
class JointModel:
|
625
|
+
"""
|
626
|
+
Represents the kinematic and dynamic properties of a joint type.
|
627
|
+
|
628
|
+
A `JointModel` defines the mathematical functions required to compute joint
|
629
|
+
transformations, motion, control terms, and inverse kinematics. It is used to
|
630
|
+
describe the behavior of various joint types, including revolute, prismatic,
|
631
|
+
spherical, and free joints.
|
632
|
+
|
633
|
+
Attributes:
|
634
|
+
transform (Callable[[jax.Array, jax.Array], base.Transform]):
|
635
|
+
Computes the transformation (position and orientation) of the joint
|
636
|
+
given the joint state `q` and joint parameters.
|
637
|
+
|
638
|
+
motion (list[base.Motion | Callable[[jax.Array], base.Motion]]):
|
639
|
+
Defines the joint motion model. It can be a list of `Motion` objects
|
640
|
+
or callables that return `Motion` based on joint parameters.
|
641
|
+
|
642
|
+
rcmg_draw_fn (Optional[DRAW_FN]):
|
643
|
+
Function used to generate a reference motion trajectory for the joint
|
644
|
+
using Randomized Control Motion Generation (RCMG).
|
645
|
+
|
646
|
+
p_control_term (Optional[P_CONTROL_TERM]):
|
647
|
+
Function that computes the proportional control term for the joint.
|
648
|
+
|
649
|
+
qd_from_q (Optional[QD_FROM_Q]):
|
650
|
+
Function to compute joint velocity (`qd`) from joint positions (`q`).
|
651
|
+
|
652
|
+
coordinate_vector_to_q (Optional[COORDINATE_VECTOR_TO_Q]):
|
653
|
+
Function that maps a coordinate vector to a valid joint state `q`,
|
654
|
+
ensuring constraints (e.g., wrapping angles or normalizing quaternions).
|
655
|
+
|
656
|
+
inv_kin (Optional[INV_KIN]):
|
657
|
+
Function that computes the inverse kinematics for the joint, mapping
|
658
|
+
a desired transform to joint coordinates `q`.
|
659
|
+
|
660
|
+
init_joint_params (Optional[INIT_JOINT_PARAMS]):
|
661
|
+
Function that initializes joint-specific parameters.
|
662
|
+
|
663
|
+
utilities (Optional[dict[str, Any]]):
|
664
|
+
Additional utility functions or metadata related to the joint model.
|
665
|
+
|
666
|
+
Notes:
|
667
|
+
- The `transform` function is essential for computing the joint's spatial
|
668
|
+
transformation based on its generalized coordinates.
|
669
|
+
- The `motion` attribute describes how forces and torques affect the joint.
|
670
|
+
- The `rcmg_draw_fn` is used for RCMG motion generation.
|
671
|
+
- The `coordinate_vector_to_q` is critical for maintaining valid joint states.
|
672
|
+
""" # noqa: E501
|
673
|
+
|
520
674
|
# (q, params) -> Transform
|
521
675
|
transform: Callable[[jax.Array, jax.Array], base.Transform]
|
522
676
|
# len(motion) == len(qd)
|
@@ -1079,6 +1233,50 @@ def register_new_joint_type(
|
|
1079
1233
|
qd_width: Optional[int] = None,
|
1080
1234
|
overwrite: bool = False,
|
1081
1235
|
):
|
1236
|
+
"""
|
1237
|
+
Registers a new joint type with its corresponding `JointModel` and kinematic properties.
|
1238
|
+
|
1239
|
+
This function allows the addition of custom joint types to the system by associating
|
1240
|
+
them with a `JointModel`, specifying their state and velocity dimensions, and optionally
|
1241
|
+
overwriting existing joint definitions.
|
1242
|
+
|
1243
|
+
Args:
|
1244
|
+
joint_type (str):
|
1245
|
+
Name of the new joint type to register.
|
1246
|
+
joint_model (JointModel):
|
1247
|
+
The `JointModel` instance defining the kinematic and dynamic properties of the joint.
|
1248
|
+
q_width (int):
|
1249
|
+
Number of generalized coordinates (degrees of freedom) required to represent the joint.
|
1250
|
+
qd_width (Optional[int], default=None):
|
1251
|
+
Number of velocity coordinates associated with the joint. Defaults to `q_width`.
|
1252
|
+
overwrite (bool, default=False):
|
1253
|
+
If `True`, allows overwriting an existing joint type. Otherwise, raises an error if
|
1254
|
+
the joint type already exists.
|
1255
|
+
|
1256
|
+
Raises:
|
1257
|
+
AssertionError:
|
1258
|
+
- If `joint_type` is `"default"` (reserved name).
|
1259
|
+
- If `joint_type` already exists and `overwrite=False`.
|
1260
|
+
- If `qd_width` is not provided and does not default to `q_width`.
|
1261
|
+
- If `joint_model.motion` length does not match `qd_width`.
|
1262
|
+
|
1263
|
+
Notes:
|
1264
|
+
- The function updates global dictionaries that store joint properties, including:
|
1265
|
+
- `_joint_types`: Maps joint type names to `JointModel` instances.
|
1266
|
+
- `base.Q_WIDTHS`: Stores the number of state coordinates for each joint type.
|
1267
|
+
- `base.QD_WIDTHS`: Stores the number of velocity coordinates for each joint type.
|
1268
|
+
- If `overwrite=True`, existing entries are removed before adding the new joint type.
|
1269
|
+
- Ensures consistency between motion definitions and velocity coordinate dimensions.
|
1270
|
+
|
1271
|
+
Example:
|
1272
|
+
```python
|
1273
|
+
new_joint = JointModel(
|
1274
|
+
transform=my_transform_fn,
|
1275
|
+
motion=[base.Motion.create(ang=jnp.array([1, 0, 0]))],
|
1276
|
+
)
|
1277
|
+
register_new_joint_type("custom_hinge", new_joint, q_width=1)
|
1278
|
+
```
|
1279
|
+
""" # noqa: E501
|
1082
1280
|
# this name is used
|
1083
1281
|
assert joint_type != "default", "Please use another name."
|
1084
1282
|
|
@@ -4,6 +4,7 @@ import jax
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jaxopt
|
6
6
|
from jaxopt._src.base import Solver
|
7
|
+
|
7
8
|
from ring import algebra
|
8
9
|
from ring import base
|
9
10
|
from ring import maths
|
@@ -171,7 +172,7 @@ def inverse_kinematics_endeffector(
|
|
171
172
|
|
172
173
|
# find result of best q0 initial value
|
173
174
|
best_q_index = jnp.argmin(values)
|
174
|
-
best_q, best_q_value = jax.
|
175
|
+
best_q, best_q_value = jax.tree.map(
|
175
176
|
lambda arr: jax.lax.dynamic_index_in_dim(
|
176
177
|
arr, best_q_index, keepdims=False
|
177
178
|
),
|
@@ -244,7 +244,7 @@ def imu(
|
|
244
244
|
measurements["mag"] = magnetometer(xs.rot, magvec)
|
245
245
|
|
246
246
|
if smoothen_degree is not None:
|
247
|
-
measurements = jax.
|
247
|
+
measurements = jax.tree.map(
|
248
248
|
lambda arr: _moving_average(arr, smoothen_degree),
|
249
249
|
measurements,
|
250
250
|
)
|
@@ -257,7 +257,7 @@ def imu(
|
|
257
257
|
delay = half_window
|
258
258
|
|
259
259
|
if delay is not None and delay > 0:
|
260
|
-
measurements = jax.
|
260
|
+
measurements = jax.tree.map(
|
261
261
|
lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
|
262
262
|
)
|
263
263
|
|
@@ -473,7 +473,7 @@ def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
|
|
473
473
|
X[name] = {"joint_axes": joint_axes}
|
474
474
|
|
475
475
|
sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
|
476
|
-
X = jax.
|
476
|
+
X = jax.tree.map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
|
477
477
|
return X
|
478
478
|
|
479
479
|
|
@@ -498,12 +498,12 @@ _quasi_physical_sys_str = r"""
|
|
498
498
|
<x_xy>
|
499
499
|
<options gravity="0 0 0"/>
|
500
500
|
<worldbody>
|
501
|
-
<body name="IMU" joint="
|
502
|
-
<geom type="box" mass="
|
501
|
+
<body name="IMU" joint="free" damping="1 1 1 10 10 10" spring_stiff="20 20 20 500 500 500">
|
502
|
+
<geom type="box" mass="1" dim="0.01 0.01 0.01"/>
|
503
503
|
</body>
|
504
504
|
</worldbody>
|
505
505
|
</x_xy>
|
506
|
-
"""
|
506
|
+
""" # noqa: E501
|
507
507
|
|
508
508
|
|
509
509
|
def _quasi_physical_simulation_beautiful(
|
@@ -512,12 +512,14 @@ def _quasi_physical_simulation_beautiful(
|
|
512
512
|
sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)
|
513
513
|
|
514
514
|
def step_dynamics(state: base.State, x):
|
515
|
-
state = algorithms.step(
|
515
|
+
state = algorithms.step(
|
516
|
+
sys.replace(link_spring_zeropoint=jnp.concatenate((x.rot, x.pos))), state
|
517
|
+
)
|
516
518
|
return state, state.q
|
517
519
|
|
518
|
-
state = base.State.create(sys, q=xs.pos[0])
|
519
|
-
_,
|
520
|
-
return xs.replace(pos=
|
520
|
+
state = base.State.create(sys, q=jnp.concatenate((xs.rot[0], xs.pos[0])))
|
521
|
+
_, qs = jax.lax.scan(step_dynamics, state, xs)
|
522
|
+
return xs.replace(rot=qs[:, :4], pos=qs[:, 4:])
|
521
523
|
|
522
524
|
|
523
525
|
_constants = {
|