imt-ring 1.6.38__py3-none-any.whl → 1.6.42__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.38.dist-info → imt_ring-1.6.42.dist-info}/METADATA +1 -1
- {imt_ring-1.6.38.dist-info → imt_ring-1.6.42.dist-info}/RECORD +10 -10
- ring/algorithms/dynamics.py +27 -1
- ring/algorithms/generator/base.py +82 -2
- ring/algorithms/jcalc.py +198 -0
- ring/base.py +355 -26
- ring/ml/base.py +5 -0
- ring/sim2real/sim2real.py +29 -5
- {imt_ring-1.6.38.dist-info → imt_ring-1.6.42.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.38.dist-info → imt_ring-1.6.42.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
|
|
1
1
|
ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=zromjIuMpNBoyiwHa9OCyZvAz7jHjXHZIdRt8fN8PoA,50481
|
4
4
|
ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
7
|
ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
|
8
|
-
ring/algorithms/dynamics.py,sha256=
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
8
|
+
ring/algorithms/dynamics.py,sha256=NFOZawjrFoS5RgiWOpG1pQCC8l7RBOEZIi9ok6gvf9U,12268
|
9
|
+
ring/algorithms/jcalc.py,sha256=l6BXOmXwrZ_AKKRm4gEHq_k2LSUQ4wd--1qL1qNTcKk,46794
|
10
10
|
ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
|
11
11
|
ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
|
@@ -15,7 +15,7 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
|
|
15
15
|
ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
|
16
16
|
ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
|
17
17
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
18
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
+
ring/algorithms/generator/base.py,sha256=klWYt6TlMluLu0ihGzmmPXBm47DOTpjXJylZVNXHVEk,22419
|
19
19
|
ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
|
20
20
|
ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
|
21
21
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
@@ -52,7 +52,7 @@ ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,
|
|
52
52
|
ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
|
53
53
|
ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
|
54
54
|
ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
|
55
|
-
ring/ml/base.py,sha256=
|
55
|
+
ring/ml/base.py,sha256=HAAM6ehXiyV53cvh1bLvPHIrlM7S4pgN-xcGTI8Mvsw,10238
|
56
56
|
ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
|
57
57
|
ring/ml/ml_utils.py,sha256=hu189AnHcmkhkpEPZZ19O0gWz3T-YKpWQW9buqDTMow,10915
|
58
58
|
ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
|
@@ -68,7 +68,7 @@ ring/rendering/mujoco_render.py,sha256=HMvZc04I0-lXPBL3hcnBzV2bNiXQAQM7QcHlG_Obm
|
|
68
68
|
ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
|
69
69
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
70
70
|
ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
|
71
|
-
ring/sim2real/sim2real.py,sha256=
|
71
|
+
ring/sim2real/sim2real.py,sha256=4MtxsyQmfnSi9llzL0ZB5wmJ5zfAXBv705RbSpI26gY,10373
|
72
72
|
ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E,193
|
73
73
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
74
74
|
ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
|
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
|
|
86
86
|
ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
|
87
87
|
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
88
88
|
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
89
|
-
imt_ring-1.6.
|
90
|
-
imt_ring-1.6.
|
91
|
-
imt_ring-1.6.
|
92
|
-
imt_ring-1.6.
|
89
|
+
imt_ring-1.6.42.dist-info/METADATA,sha256=xpcG74pMBIr3v0CQkG9zNZ0BCefDZAVhrOPu31Pb4Uk,4251
|
90
|
+
imt_ring-1.6.42.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
91
|
+
imt_ring-1.6.42.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
92
|
+
imt_ring-1.6.42.dist-info/RECORD,,
|
ring/algorithms/dynamics.py
CHANGED
@@ -303,7 +303,33 @@ def step(
|
|
303
303
|
taus: Optional[jax.Array] = None,
|
304
304
|
n_substeps: int = 1,
|
305
305
|
) -> base.State:
|
306
|
-
"
|
306
|
+
"""
|
307
|
+
Advances the system dynamics by a single timestep using the integration method selected by `sys.integration_method` (e.g., semi-implicit Euler).
|
308
|
+
|
309
|
+
This function updates the system's state by integrating the equations of motion
|
310
|
+
over a timestep, potentially with multiple substeps for improved numerical stability.
|
311
|
+
The method ensures that the system's kinematics are updated before each integration step.
|
312
|
+
|
313
|
+
Args:
|
314
|
+
sys (base.System):
|
315
|
+
The system to simulate, containing link information, joint dynamics, and integration parameters.
|
316
|
+
state (base.State):
|
317
|
+
The current state of the system, including joint positions (`q`), velocities (`qd`), and transforms (`x`).
|
318
|
+
taus (Optional[jax.Array], optional):
|
319
|
+
The control torques applied to the system joints. If `None`, zero torques are applied.
|
320
|
+
Defaults to `None`.
|
321
|
+
n_substeps (int, optional):
|
322
|
+
The number of integration substeps per timestep to improve numerical accuracy.
|
323
|
+
Defaults to `1`.
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
base.State:
|
327
|
+
The updated state of the system after integration.
|
328
|
+
|
329
|
+
Raises:
|
330
|
+
AssertionError: If the system's degrees of freedom (`q` and `qd`) do not match expectations.
|
331
|
+
AssertionError: If an unsupported integration method is specified in `sys.integration_method`.
|
332
|
+
""" # noqa: E501
|
307
333
|
assert sys.q_size() == state.q.size
|
308
334
|
if taus is None:
|
309
335
|
taus = jnp.zeros_like(state.qd)
|
@@ -53,7 +53,85 @@ class RCMG:
|
|
53
53
|
cor: bool = False,
|
54
54
|
disable_tqdm: bool = False,
|
55
55
|
) -> None:
|
56
|
-
"
|
56
|
+
"""
|
57
|
+
Initializes the Random Chain Motion Generator (RCMG).
|
58
|
+
|
59
|
+
The RCMG generates synthetic joint motion sequences for kinematic and dynamic
|
60
|
+
systems based on predefined motion configurations. It allows for system
|
61
|
+
randomization, augmentation with IMU and joint axis data, and optional
|
62
|
+
dynamic simulation.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
sys (base.System | list[base.System]):
|
66
|
+
The system(s) for which motion should be generated.
|
67
|
+
config (jcalc.MotionConfig | list[jcalc.MotionConfig], optional):
|
68
|
+
Motion configuration(s) defining velocity limits, interpolation methods,
|
69
|
+
and range constraints. Defaults to `jcalc.MotionConfig()`.
|
70
|
+
setup_fn (Optional[types.SETUP_FN], optional):
|
71
|
+
A function to modify the system before motion generation. Defaults to `None`.
|
72
|
+
finalize_fn (Optional[types.FINALIZE_FN], optional):
|
73
|
+
A function to modify outputs after motion generation. Defaults to `None`.
|
74
|
+
add_X_imus (bool, optional):
|
75
|
+
Whether to add IMU sensor data to the output. Defaults to `False`.
|
76
|
+
add_X_imus_kwargs (dict, optional):
|
77
|
+
Additional keyword arguments for IMU data processing. Defaults to `{}`.
|
78
|
+
add_X_jointaxes (bool, optional):
|
79
|
+
Whether to add joint axis data to the output. Defaults to `False`.
|
80
|
+
add_X_jointaxes_kwargs (dict, optional):
|
81
|
+
Additional keyword arguments for joint axis data processing. Defaults to `{}`.
|
82
|
+
add_y_relpose (bool, optional):
|
83
|
+
Whether to add relative pose targets to the output. Defaults to `False`.
|
84
|
+
add_y_rootincl (bool, optional):
|
85
|
+
Whether to add root inclination targets to the output. Defaults to `False`.
|
86
|
+
add_y_rootincl_kwargs (dict, optional):
|
87
|
+
Additional keyword arguments for root inclination processing. Defaults to `{}`.
|
88
|
+
add_y_rootfull (bool, optional):
|
89
|
+
Whether to add full root state targets to the output. Defaults to `False`.
|
90
|
+
add_y_rootfull_kwargs (dict, optional):
|
91
|
+
Additional keyword arguments for full root state processing. Defaults to `{}`.
|
92
|
+
sys_ml (Optional[base.System], optional):
|
93
|
+
System that defines the graph and naming structure of the `X` and `y` outputs. Defaults to `None` which then uses the first provided system.
|
94
|
+
randomize_positions (bool, optional):
|
95
|
+
Whether to randomize positions based on `pos_min` and `pos_max`. Defaults to `False`.
|
96
|
+
randomize_motion_artifacts (bool, optional):
|
97
|
+
Whether to randomize the IMU motion artifact simulation. This randomizes the spring stiffness and spring damping parameters of the passive free joint that is added between the nonrigid and rigid IMUs. Defaults to `False`.
|
98
|
+
randomize_joint_params (bool, optional):
|
99
|
+
Whether to randomize joint parameters by calling `JointModel.init_joint_params` before every sequence generation. Defaults to `False`.
|
100
|
+
randomize_hz (bool, optional):
|
101
|
+
Whether to randomize the sampling frequency of the generated data. Defaults to `False`.
|
102
|
+
randomize_hz_kwargs (dict, optional):
|
103
|
+
Additional keyword arguments for sampling frequency randomization. Defaults to `{}`.
|
104
|
+
imu_motion_artifacts (bool, optional):
|
105
|
+
Whether to simulate nonrigid IMU motion artifacts. Defaults to `False`.
|
106
|
+
imu_motion_artifacts_kwargs (dict, optional):
|
107
|
+
Additional keyword arguments for IMU motion artifact simulation. Defaults to `{}`.
|
108
|
+
dynamic_simulation (bool, optional):
|
109
|
+
Whether to use a physics-based simulation to generate motion instead of purely
|
110
|
+
kinematic methods. Defaults to `False`.
|
111
|
+
dynamic_simulation_kwargs (dict, optional):
|
112
|
+
Additional keyword arguments for dynamic simulation. Defaults to `{}`.
|
113
|
+
output_transform (Optional[Callable], optional):
|
114
|
+
A function to transform the generated output data. Defaults to `None`.
|
115
|
+
keep_output_extras (bool, optional):
|
116
|
+
Whether to keep additional output metadata. Defaults to `False`.
|
117
|
+
use_link_number_in_Xy (bool, optional):
|
118
|
+
Whether to replace joint names with numerical indices in the output. Defaults to `False`.
|
119
|
+
cor (bool, optional):
|
120
|
+
Whether to replace free joints with center-of-rotation (COR) 9D free joint. Defaults to `False`.
|
121
|
+
disable_tqdm (bool, optional):
|
122
|
+
Whether to disable progress bars during generation. Defaults to `False`.
|
123
|
+
|
124
|
+
Raises:
|
125
|
+
AssertionError: If any of the provided `MotionConfig` instances are infeasible.
|
126
|
+
|
127
|
+
Notes:
|
128
|
+
- This class supports batch generation, lazy and eager data loading, and
|
129
|
+
motion augmentation.
|
130
|
+
- If `randomize_hz=True`, the time step (`dt`) varies according to the specified
|
131
|
+
sampling rates.
|
132
|
+
- When `cor=True`, free joints are replaced with center-of-rotation models,
|
133
|
+
affecting joint motion behavior.
|
134
|
+
""" # noqa: E501
|
57
135
|
|
58
136
|
# add some default values
|
59
137
|
randomize_hz_kwargs_defaults = dict(add_dt=True)
|
@@ -139,6 +217,7 @@ class RCMG:
|
|
139
217
|
def to_lazy_gen(
|
140
218
|
self, sizes: int | list[int] = 1, jit: bool = True
|
141
219
|
) -> types.BatchedGenerator:
|
220
|
+
"Returns a generator `X, y = gen(key)` that lazily generates batched sequences."
|
142
221
|
return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
|
143
222
|
|
144
223
|
@staticmethod
|
@@ -201,7 +280,7 @@ class RCMG:
|
|
201
280
|
),
|
202
281
|
verbose: bool = True,
|
203
282
|
):
|
204
|
-
|
283
|
+
"Stores unbatched sequences as numpy arrays into folder."
|
205
284
|
i = 0
|
206
285
|
|
207
286
|
def callback(data: list[PyTree[np.ndarray]]) -> None:
|
@@ -237,6 +316,7 @@ class RCMG:
|
|
237
316
|
shuffle: bool = True,
|
238
317
|
transform=None,
|
239
318
|
) -> types.BatchedGenerator:
|
319
|
+
"Returns a generator `X, y = gen(key)` that returns precomputed batched sequences." # noqa: E501
|
240
320
|
data = self.to_list(sizes, seed)
|
241
321
|
assert len(data) >= batchsize
|
242
322
|
return self.eager_gen_from_list(data, batchsize, shuffle, transform)
|
ring/algorithms/jcalc.py
CHANGED
@@ -19,6 +19,87 @@ from ring.algorithms._random import TimeDependentFloat
|
|
19
19
|
|
20
20
|
@dataclass
|
21
21
|
class MotionConfig:
|
22
|
+
"""
|
23
|
+
Configuration for joint motion generation in kinematic and dynamic simulations.
|
24
|
+
|
25
|
+
This class defines the constraints and parameters for generating random joint motions,
|
26
|
+
including angular and positional velocity limits, interpolation methods, and range
|
27
|
+
restrictions for various joint types.
|
28
|
+
|
29
|
+
Attributes:
|
30
|
+
T (float): Total duration of the motion sequence (in seconds).
|
31
|
+
t_min (float): Minimum time interval between two generated joint states.
|
32
|
+
t_max (float | TimeDependentFloat): Maximum time interval between two generated joint states.
|
33
|
+
|
34
|
+
dang_min (float | TimeDependentFloat): Minimum angular velocity (rad/s).
|
35
|
+
dang_max (float | TimeDependentFloat): Maximum angular velocity (rad/s).
|
36
|
+
dang_min_free_spherical (float | TimeDependentFloat): Minimum angular velocity for free and spherical joints.
|
37
|
+
dang_max_free_spherical (float | TimeDependentFloat): Maximum angular velocity for free and spherical joints.
|
38
|
+
|
39
|
+
delta_ang_min (float | TimeDependentFloat): Minimum allowed change in joint angle (radians).
|
40
|
+
delta_ang_max (float | TimeDependentFloat): Maximum allowed change in joint angle (radians).
|
41
|
+
delta_ang_min_free_spherical (float | TimeDependentFloat): Minimum allowed change in angle for free/spherical joints.
|
42
|
+
delta_ang_max_free_spherical (float | TimeDependentFloat): Maximum allowed change in angle for free/spherical joints.
|
43
|
+
|
44
|
+
dpos_min (float | TimeDependentFloat): Minimum translational velocity.
|
45
|
+
dpos_max (float | TimeDependentFloat): Maximum translational velocity.
|
46
|
+
pos_min (float | TimeDependentFloat): Minimum position constraint.
|
47
|
+
pos_max (float | TimeDependentFloat): Maximum position constraint.
|
48
|
+
|
49
|
+
pos_min_p3d_x (float | TimeDependentFloat): Minimum position in x-direction for P3D joints.
|
50
|
+
pos_max_p3d_x (float | TimeDependentFloat): Maximum position in x-direction for P3D joints.
|
51
|
+
pos_min_p3d_y (float | TimeDependentFloat): Minimum position in y-direction for P3D joints.
|
52
|
+
pos_max_p3d_y (float | TimeDependentFloat): Maximum position in y-direction for P3D joints.
|
53
|
+
pos_min_p3d_z (float | TimeDependentFloat): Minimum position in z-direction for P3D joints.
|
54
|
+
pos_max_p3d_z (float | TimeDependentFloat): Maximum position in z-direction for P3D joints.
|
55
|
+
|
56
|
+
cdf_bins_min (int): Minimum number of bins for cumulative distribution function (CDF)-based random sampling.
|
57
|
+
cdf_bins_max (Optional[int]): Maximum number of bins for CDF-based sampling.
|
58
|
+
|
59
|
+
randomized_interpolation_angle (bool): Whether to use randomized interpolation for angular motion.
|
60
|
+
randomized_interpolation_position (bool): Whether to use randomized interpolation for positional motion.
|
61
|
+
interpolation_method (str): Interpolation method to be used (default: "cosine").
|
62
|
+
|
63
|
+
range_of_motion_hinge (bool): Whether to enforce range-of-motion constraints on hinge joints.
|
64
|
+
range_of_motion_hinge_method (str): Method used for range-of-motion enforcement (e.g., "uniform", "sigmoid").
|
65
|
+
|
66
|
+
rom_halfsize (float | TimeDependentFloat): Half-size of the range of motion restriction.
|
67
|
+
|
68
|
+
ang0_min (float): Minimum initial joint angle.
|
69
|
+
ang0_max (float): Maximum initial joint angle.
|
70
|
+
pos0_min (float): Minimum initial joint position.
|
71
|
+
pos0_max (float): Maximum initial joint position.
|
72
|
+
|
73
|
+
cor_t_min (float): Minimum time step for center-of-rotation (COR) joints.
|
74
|
+
cor_t_max (float | TimeDependentFloat): Maximum time step for COR joints.
|
75
|
+
cor_dpos_min (float | TimeDependentFloat): Minimum velocity for COR translation.
|
76
|
+
cor_dpos_max (float | TimeDependentFloat): Maximum velocity for COR translation.
|
77
|
+
cor_pos_min (float | TimeDependentFloat): Minimum position for COR translation.
|
78
|
+
cor_pos_max (float | TimeDependentFloat): Maximum position for COR translation.
|
79
|
+
cor_pos0_min (float): Initial minimum position for COR translation.
|
80
|
+
cor_pos0_max (float): Initial maximum position for COR translation.
|
81
|
+
|
82
|
+
joint_type_specific_overwrites (dict[str, dict[str, Any]]):
|
83
|
+
A dictionary mapping joint types to specific motion configuration overrides.
|
84
|
+
|
85
|
+
Methods:
|
86
|
+
is_feasible:
|
87
|
+
Checks if the motion configuration satisfies all constraints.
|
88
|
+
|
89
|
+
to_nomotion_config:
|
90
|
+
Returns a new `MotionConfig` where all velocities and angle changes are set to zero.
|
91
|
+
|
92
|
+
overwrite_for_joint_type:
|
93
|
+
Applies specific configuration changes for a given joint type.
|
94
|
+
Note: These changes affect all instances of `MotionConfig` for this joint type.
|
95
|
+
|
96
|
+
overwrite_for_subsystem:
|
97
|
+
Modifies the motion configuration for all joints in a subsystem rooted at `link_name`.
|
98
|
+
|
99
|
+
from_register:
|
100
|
+
Retrieves a predefined `MotionConfig` from the global registry.
|
101
|
+
""" # noqa: E501
|
102
|
+
|
22
103
|
T: float = 60.0 # length of random motion
|
23
104
|
t_min: float = 0.05 # min time between two generated angles
|
24
105
|
t_max: float | TimeDependentFloat = 0.30 # max time ..
|
@@ -412,6 +493,30 @@ def _find_interval(t: jax.Array, boundaries: jax.Array):
|
|
412
493
|
def join_motionconfigs(
|
413
494
|
configs: list[MotionConfig], boundaries: list[float]
|
414
495
|
) -> MotionConfig:
|
496
|
+
"""
|
497
|
+
Joins multiple `MotionConfig` objects in time, transitioning between them at specified boundaries.
|
498
|
+
|
499
|
+
This function takes a list of `MotionConfig` instances and a corresponding list of boundary times,
|
500
|
+
and constructs a new `MotionConfig` that varies in time according to the provided segments.
|
501
|
+
|
502
|
+
Args:
|
503
|
+
configs (list[MotionConfig]): A list of `MotionConfig` objects to be joined.
|
504
|
+
boundaries (list[float]): A list of time values where transitions between `configs` occur.
|
505
|
+
Must have one element less than `configs`, as each boundary defines the transition point
|
506
|
+
between two consecutive configurations.
|
507
|
+
|
508
|
+
Returns:
|
509
|
+
MotionConfig: A new `MotionConfig` object where time-dependent fields transition based on the
|
510
|
+
specified boundaries.
|
511
|
+
|
512
|
+
Raises:
|
513
|
+
AssertionError: If the number of boundaries does not match `len(configs) - 1`.
|
514
|
+
AssertionError: If time-independent fields have differing values across `configs`.
|
515
|
+
|
516
|
+
Notes:
|
517
|
+
- Only fields that are time-dependent (`float | TimeDependentFloat`) will change over time.
|
518
|
+
- Time-independent fields must be the same in all `configs`, or an error is raised.
|
519
|
+
""" # noqa: E501
|
415
520
|
# to avoid a circular import due to `ring.utils.randomize_sys` importing `jcalc`
|
416
521
|
from ring.utils import tree_equal
|
417
522
|
|
@@ -517,6 +622,55 @@ INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
|
|
517
622
|
|
518
623
|
@dataclass
|
519
624
|
class JointModel:
|
625
|
+
"""
|
626
|
+
Represents the kinematic and dynamic properties of a joint type.
|
627
|
+
|
628
|
+
A `JointModel` defines the mathematical functions required to compute joint
|
629
|
+
transformations, motion, control terms, and inverse kinematics. It is used to
|
630
|
+
describe the behavior of various joint types, including revolute, prismatic,
|
631
|
+
spherical, and free joints.
|
632
|
+
|
633
|
+
Attributes:
|
634
|
+
transform (Callable[[jax.Array, jax.Array], base.Transform]):
|
635
|
+
Computes the transformation (position and orientation) of the joint
|
636
|
+
given the joint state `q` and joint parameters.
|
637
|
+
|
638
|
+
motion (list[base.Motion | Callable[[jax.Array], base.Motion]]):
|
639
|
+
Defines the joint motion model. It can be a list of `Motion` objects
|
640
|
+
or callables that return `Motion` based on joint parameters.
|
641
|
+
|
642
|
+
rcmg_draw_fn (Optional[DRAW_FN]):
|
643
|
+
Function used to generate a reference motion trajectory for the joint
|
644
|
+
using the Random Chain Motion Generator (RCMG).
|
645
|
+
|
646
|
+
p_control_term (Optional[P_CONTROL_TERM]):
|
647
|
+
Function that computes the proportional control term for the joint.
|
648
|
+
|
649
|
+
qd_from_q (Optional[QD_FROM_Q]):
|
650
|
+
Function to compute joint velocity (`qd`) from joint positions (`q`).
|
651
|
+
|
652
|
+
coordinate_vector_to_q (Optional[COORDINATE_VECTOR_TO_Q]):
|
653
|
+
Function that maps a coordinate vector to a valid joint state `q`,
|
654
|
+
ensuring constraints (e.g., wrapping angles or normalizing quaternions).
|
655
|
+
|
656
|
+
inv_kin (Optional[INV_KIN]):
|
657
|
+
Function that computes the inverse kinematics for the joint, mapping
|
658
|
+
a desired transform to joint coordinates `q`.
|
659
|
+
|
660
|
+
init_joint_params (Optional[INIT_JOINT_PARAMS]):
|
661
|
+
Function that initializes joint-specific parameters.
|
662
|
+
|
663
|
+
utilities (Optional[dict[str, Any]]):
|
664
|
+
Additional utility functions or metadata related to the joint model.
|
665
|
+
|
666
|
+
Notes:
|
667
|
+
- The `transform` function is essential for computing the joint's spatial
|
668
|
+
transformation based on its generalized coordinates.
|
669
|
+
- The `motion` attribute describes how forces and torques affect the joint.
|
670
|
+
- The `rcmg_draw_fn` is used for RCMG motion generation.
|
671
|
+
- The `coordinate_vector_to_q` is critical for maintaining valid joint states.
|
672
|
+
""" # noqa: E501
|
673
|
+
|
520
674
|
# (q, params) -> Transform
|
521
675
|
transform: Callable[[jax.Array, jax.Array], base.Transform]
|
522
676
|
# len(motion) == len(qd)
|
@@ -1079,6 +1233,50 @@ def register_new_joint_type(
|
|
1079
1233
|
qd_width: Optional[int] = None,
|
1080
1234
|
overwrite: bool = False,
|
1081
1235
|
):
|
1236
|
+
"""
|
1237
|
+
Registers a new joint type with its corresponding `JointModel` and kinematic properties.
|
1238
|
+
|
1239
|
+
This function allows the addition of custom joint types to the system by associating
|
1240
|
+
them with a `JointModel`, specifying their state and velocity dimensions, and optionally
|
1241
|
+
overwriting existing joint definitions.
|
1242
|
+
|
1243
|
+
Args:
|
1244
|
+
joint_type (str):
|
1245
|
+
Name of the new joint type to register.
|
1246
|
+
joint_model (JointModel):
|
1247
|
+
The `JointModel` instance defining the kinematic and dynamic properties of the joint.
|
1248
|
+
q_width (int):
|
1249
|
+
Number of generalized coordinates (degrees of freedom) required to represent the joint.
|
1250
|
+
qd_width (Optional[int], default=None):
|
1251
|
+
Number of velocity coordinates associated with the joint. Defaults to `q_width`.
|
1252
|
+
overwrite (bool, default=False):
|
1253
|
+
If `True`, allows overwriting an existing joint type. Otherwise, raises an error if
|
1254
|
+
the joint type already exists.
|
1255
|
+
|
1256
|
+
Raises:
|
1257
|
+
AssertionError:
|
1258
|
+
- If `joint_type` is `"default"` (reserved name).
|
1259
|
+
- If `joint_type` already exists and `overwrite=False`.
|
1260
|
+
- If `qd_width` is not provided and does not default to `q_width`.
|
1261
|
+
- If `joint_model.motion` length does not match `qd_width`.
|
1262
|
+
|
1263
|
+
Notes:
|
1264
|
+
- The function updates global dictionaries that store joint properties, including:
|
1265
|
+
- `_joint_types`: Maps joint type names to `JointModel` instances.
|
1266
|
+
- `base.Q_WIDTHS`: Stores the number of state coordinates for each joint type.
|
1267
|
+
- `base.QD_WIDTHS`: Stores the number of velocity coordinates for each joint type.
|
1268
|
+
- If `overwrite=True`, existing entries are removed before adding the new joint type.
|
1269
|
+
- Ensures consistency between motion definitions and velocity coordinate dimensions.
|
1270
|
+
|
1271
|
+
Example:
|
1272
|
+
```python
|
1273
|
+
new_joint = JointModel(
|
1274
|
+
transform=my_transform_fn,
|
1275
|
+
motion=[base.Motion.create(ang=jnp.array([1, 0, 0]))],
|
1276
|
+
)
|
1277
|
+
register_new_joint_type("custom_hinge", new_joint, q_width=1)
|
1278
|
+
```
|
1279
|
+
""" # noqa: E501
|
1082
1280
|
# this name is used
|
1083
1281
|
assert joint_type != "default", "Please use another name."
|
1084
1282
|
|
ring/base.py
CHANGED
@@ -112,17 +112,71 @@ class _Base:
|
|
112
112
|
|
113
113
|
@struct.dataclass
|
114
114
|
class Transform(_Base):
|
115
|
-
"""Represents the Transformation from Plücker A to Plücker B,
|
116
|
-
where B is located relative to A at `pos` in frame A and `rot` is the
|
117
|
-
relative quaternion from A to B.
|
118
|
-
Create using `Transform.create(pos=..., rot=...)
|
119
115
|
"""
|
116
|
+
Represents a spatial transformation between two coordinate frames using Plücker coordinates.
|
117
|
+
|
118
|
+
The `Transform` class defines the relative position and orientation of one frame (`B`)
|
119
|
+
with respect to another frame (`A`). The position (`pos`) is given in the coordinate frame
|
120
|
+
of `A`, and the rotation (`rot`) is expressed as a unit quaternion representing the relative
|
121
|
+
rotation from frame `A` to frame `B`.
|
122
|
+
|
123
|
+
Attributes:
|
124
|
+
pos (Vector):
|
125
|
+
The translation vector (position of `B` relative to `A`) in the coordinate frame of `A`.
|
126
|
+
Shape: `(..., 3)`, where `...` represents optional batch dimensions.
|
127
|
+
rot (Quaternion):
|
128
|
+
The unit quaternion representing the orientation of `B` relative to `A`.
|
129
|
+
Shape: `(..., 4)`, where `...` represents optional batch dimensions.
|
130
|
+
|
131
|
+
Methods:
|
132
|
+
create(pos: Optional[Vector] = None, rot: Optional[Quaternion] = None) -> Transform:
|
133
|
+
Creates a `Transform` instance with optional position and rotation.
|
134
|
+
|
135
|
+
zero(shape: Sequence[int] = ()) -> Transform:
|
136
|
+
Returns a zero transform with a given batch shape.
|
137
|
+
|
138
|
+
as_matrix() -> jax.Array:
|
139
|
+
Returns the 6x6 spatial (Plücker) transformation matrix representation of this transform.
|
140
|
+
|
141
|
+
Usage:
|
142
|
+
>>> pos = jnp.array([1.0, 2.0, 3.0])
|
143
|
+
>>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
|
144
|
+
>>> T = Transform.create(pos, rot)
|
145
|
+
>>> print(T.pos) # Output: [1. 2. 3.]
|
146
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
147
|
+
>>> print(T.as_matrix()) # 6x6 spatial (Plücker) transformation matrix
|
148
|
+
""" # noqa: E501
|
120
149
|
|
121
150
|
pos: Vector
|
122
151
|
rot: Quaternion
|
123
152
|
|
124
153
|
@classmethod
|
125
154
|
def create(cls, pos=None, rot=None):
|
155
|
+
"""
|
156
|
+
Creates a `Transform` instance with the specified position and rotation.
|
157
|
+
|
158
|
+
At least one of `pos` or `rot` must be provided. If only `pos` is given, the rotation
|
159
|
+
defaults to the identity quaternion `[1, 0, 0, 0]`. If only `rot` is given, the position
|
160
|
+
defaults to `[0, 0, 0]`.
|
161
|
+
|
162
|
+
Args:
|
163
|
+
pos (Optional[Vector], default=None):
|
164
|
+
The position of frame `B` relative to frame `A`, expressed in frame `A` coordinates.
|
165
|
+
If `None`, defaults to a zero vector of shape `(3,)`.
|
166
|
+
rot (Optional[Quaternion], default=None):
|
167
|
+
The unit quaternion representing the orientation of `B` relative to `A`.
|
168
|
+
If `None`, defaults to the identity quaternion `(1, 0, 0, 0)`.
|
169
|
+
|
170
|
+
Returns:
|
171
|
+
Transform: A new `Transform` instance with the specified position and rotation.
|
172
|
+
|
173
|
+
Example:
|
174
|
+
>>> pos = jnp.array([1.0, 2.0, 3.0])
|
175
|
+
>>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
|
176
|
+
>>> T = Transform.create(pos, rot)
|
177
|
+
>>> print(T.pos) # Output: [1. 2. 3.]
|
178
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
179
|
+
""" # noqa: E501
|
126
180
|
assert not (pos is None and rot is None), "One must be given."
|
127
181
|
shape_rot = rot.shape[:-1] if rot is not None else ()
|
128
182
|
shape_pos = pos.shape[:-1] if pos is not None else ()
|
@@ -139,12 +193,49 @@ class Transform(_Base):
|
|
139
193
|
|
140
194
|
@classmethod
|
141
195
|
def zero(cls, shape=()) -> "Transform":
|
142
|
-
"""
|
196
|
+
"""
|
197
|
+
Returns a zero transform with a given batch shape.
|
198
|
+
|
199
|
+
This creates a transform with position `(0, 0, 0)` and an identity quaternion `(1, 0, 0, 0)`,
|
200
|
+
which represents no translation or rotation.
|
201
|
+
|
202
|
+
Args:
|
203
|
+
shape (Sequence[int], default=()):
|
204
|
+
The batch shape for the transform. Defaults to a scalar transform.
|
205
|
+
|
206
|
+
Returns:
|
207
|
+
Transform: A zero transform with the specified batch shape.
|
208
|
+
|
209
|
+
Example:
|
210
|
+
>>> T = Transform.zero()
|
211
|
+
>>> print(T.pos) # Output: [0. 0. 0.]
|
212
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
213
|
+
""" # noqa: E501
|
143
214
|
pos = jnp.zeros(shape + (3,))
|
144
215
|
rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
|
145
216
|
return Transform(pos, rot)
|
146
217
|
|
147
218
|
def as_matrix(self) -> jax.Array:
|
219
|
+
"""
|
220
|
+
Returns the 6x6 spatial (Plücker) transformation matrix representation of this transform.
|
221
|
+
|
222
|
+
The spatial transformation matrix is computed as a block-diagonal rotation times a translation:
|
223
|
+
|
224
|
+
```
|
225
|
+
[ E 0 ]  @  xlt(pos)
|
226
|
+
[ 0 E ]
|
227
|
+
```
|
228
|
+
|
229
|
+
where `E` is the 3x3 rotation matrix converted from the quaternion and `xlt(pos)` is the
|
230
|
+
6x6 spatial translation transform for the position vector `pos`.
|
231
|
+
|
232
|
+
Returns:
|
233
|
+
jax.Array: A `(6, 6)` spatial (Plücker) transformation matrix.
|
234
|
+
|
235
|
+
Example:
|
236
|
+
>>> T = Transform.create(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 0.0, 0.0]))
|
237
|
+
>>> print(T.as_matrix()) # Output: 6x6 matrix
|
238
|
+
""" # noqa: E501
|
148
239
|
E = maths.quat_to_3x3(self.rot)
|
149
240
|
return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
|
150
241
|
|
@@ -402,7 +493,175 @@ QD_WIDTHS = {
|
|
402
493
|
|
403
494
|
@struct.dataclass
|
404
495
|
class System(_Base):
|
405
|
-
"
|
496
|
+
"""
|
497
|
+
Represents a robotic system consisting of interconnected links and joints. Create it using `System.create(...)`
|
498
|
+
|
499
|
+
The `System` class models the kinematic and dynamic properties of a multibody
|
500
|
+
system, providing methods for state representation, transformations, joint
|
501
|
+
configuration management, and rendering. It supports both minimal and maximal
|
502
|
+
coordinate representations and can be parsed from or saved to XML files.
|
503
|
+
|
504
|
+
Attributes:
|
505
|
+
link_parents (list[int]):
|
506
|
+
A list specifying the parent index for each link. The root link has a parent index of `-1`.
|
507
|
+
links (Link):
|
508
|
+
A data structure containing information about all links in the system.
|
509
|
+
link_types (list[str]):
|
510
|
+
A list specifying the joint type for each link (e.g., "free", "hinge", "prismatic").
|
511
|
+
link_damping (jax.Array):
|
512
|
+
Joint damping coefficients for each link.
|
513
|
+
link_armature (jax.Array):
|
514
|
+
Armature inertia values for each joint.
|
515
|
+
link_spring_stiffness (jax.Array):
|
516
|
+
Stiffness values for joint springs.
|
517
|
+
link_spring_zeropoint (jax.Array):
|
518
|
+
Rest position of joint springs.
|
519
|
+
dt (float):
|
520
|
+
Simulation time step size.
|
521
|
+
geoms (list[Geometry]):
|
522
|
+
List of geometries associated with the system.
|
523
|
+
gravity (jax.Array):
|
524
|
+
Gravity vector applied to the system (default: `[0, 0, -9.81]`).
|
525
|
+
integration_method (str):
|
526
|
+
Integration method for simulation (default: "semi_implicit_euler").
|
527
|
+
mass_mat_iters (int):
|
528
|
+
Number of iterations for mass matrix calculations.
|
529
|
+
link_names (list[str]):
|
530
|
+
Names of the links in the system.
|
531
|
+
model_name (Optional[str]):
|
532
|
+
Name of the system model (if available).
|
533
|
+
omc (list[MaxCoordOMC | None]):
|
534
|
+
List of optional Maximal Coordinate representations.
|
535
|
+
|
536
|
+
Methods:
|
537
|
+
num_links() -> int:
|
538
|
+
Returns the number of links in the system.
|
539
|
+
|
540
|
+
q_size() -> int:
|
541
|
+
Returns the total number of generalized coordinates (`q`) in the system.
|
542
|
+
|
543
|
+
qd_size() -> int:
|
544
|
+
Returns the total number of generalized velocities (`qd`) in the system.
|
545
|
+
|
546
|
+
name_to_idx(name: str) -> int:
|
547
|
+
Returns the index of a link given its name.
|
548
|
+
|
549
|
+
idx_to_name(idx: int, allow_world: bool = False) -> str:
|
550
|
+
Returns the name of a link given its index. If `allow_world` is `True`,
|
551
|
+
returns `"world"` for index `-1`.
|
552
|
+
|
553
|
+
idx_map(type: str) -> dict:
|
554
|
+
Returns a dictionary mapping link names to their indices for a specified type
|
555
|
+
(`"l"`, `"q"`, or `"d"`).
|
556
|
+
|
557
|
+
parent_name(name: str) -> str:
|
558
|
+
Returns the name of the parent link for a given link.
|
559
|
+
|
560
|
+
change_model_name(new_name: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
|
561
|
+
Changes the name of the system model.
|
562
|
+
|
563
|
+
change_link_name(old_name: str, new_name: str) -> "System":
|
564
|
+
Renames a specific link.
|
565
|
+
|
566
|
+
add_prefix_suffix(prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
|
567
|
+
Adds both a prefix and suffix to all link names.
|
568
|
+
|
569
|
+
freeze(name: str | list[str]) -> "System":
|
570
|
+
Freezes the specified link(s), making them immovable.
|
571
|
+
|
572
|
+
unfreeze(name: str, new_joint_type: str) -> "System":
|
573
|
+
Unfreezes a frozen link and assigns it a new joint type.
|
574
|
+
|
575
|
+
change_joint_type(name: str, new_joint_type: str, **kwargs) -> "System":
|
576
|
+
Changes the joint type of a specified link.
|
577
|
+
|
578
|
+
joint_type_simplification(typ: str) -> str:
|
579
|
+
Returns a simplified representation of the given joint type.
|
580
|
+
|
581
|
+
joint_type_is_free_or_cor(typ: str) -> bool:
|
582
|
+
Checks if a joint type is either "free" or "cor".
|
583
|
+
|
584
|
+
joint_type_is_spherical(typ: str) -> bool:
|
585
|
+
Checks if a joint type is "spherical".
|
586
|
+
|
587
|
+
joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
|
588
|
+
Checks if a joint type is "free", "cor", or "spherical".
|
589
|
+
|
590
|
+
findall_imus(names: bool = True) -> list[str] | list[int]:
|
591
|
+
Finds all IMU sensors in the system.
|
592
|
+
|
593
|
+
findall_segments(names: bool = True) -> list[str] | list[int]:
|
594
|
+
Finds all non-IMU segments in the system.
|
595
|
+
|
596
|
+
findall_bodies_to_world(names: bool = False) -> list[int] | list[str]:
|
597
|
+
Returns all bodies directly connected to the world.
|
598
|
+
|
599
|
+
find_body_to_world(name: bool = False) -> int | str:
|
600
|
+
Returns the root body connected to the world.
|
601
|
+
|
602
|
+
findall_bodies_with_jointtype(typ: str, names: bool = False) -> list[int] | list[str]:
|
603
|
+
Returns all bodies with the specified joint type.
|
604
|
+
|
605
|
+
children(name: str, names: bool = False) -> list[int] | list[str]:
|
606
|
+
Returns the direct children of a given body.
|
607
|
+
|
608
|
+
findall_bodies_subsystem(name: str, names: bool = False) -> list[int] | list[str]:
|
609
|
+
Finds all bodies in the subsystem rooted at a given link.
|
610
|
+
|
611
|
+
scan(f: Callable, in_types: str, *args, reverse: bool = False):
|
612
|
+
Iterates over system elements while applying a function.
|
613
|
+
|
614
|
+
parse() -> "System":
|
615
|
+
Parses the system, performing consistency checks and computing spatial inertia tensors.
|
616
|
+
|
617
|
+
render(xs: Optional[Transform | list[Transform]] = None, **kwargs) -> list[np.ndarray]:
|
618
|
+
Renders frames of the system using maximal coordinates.
|
619
|
+
|
620
|
+
render_prediction(xs: Transform | list[Transform], yhat: dict | jax.Array | np.ndarray, **kwargs):
|
621
|
+
Renders a predicted state transformation.
|
622
|
+
|
623
|
+
delete_system(link_name: str | list[str], strict: bool = True):
|
624
|
+
Removes a subsystem from the system.
|
625
|
+
|
626
|
+
make_sys_noimu(imu_link_names: Optional[list[str]] = None):
|
627
|
+
Returns a version of the system without IMU sensors.
|
628
|
+
|
629
|
+
inject_system(other_system: "System", at_body: Optional[str] = None):
|
630
|
+
Merges another system into this one.
|
631
|
+
|
632
|
+
morph_system(new_parents: Optional[list[int | str]] = None, new_anchor: Optional[int | str] = None):
|
633
|
+
Reorders the system’s link hierarchy.
|
634
|
+
|
635
|
+
from_xml(path: str, seed: int = 1) -> "System":
|
636
|
+
Loads a system from an XML file.
|
637
|
+
|
638
|
+
from_str(xml: str, seed: int = 1) -> "System":
|
639
|
+
Loads a system from an XML string.
|
640
|
+
|
641
|
+
to_str(warn: bool = True) -> str:
|
642
|
+
Serializes the system to an XML string.
|
643
|
+
|
644
|
+
to_xml(path: str) -> None:
|
645
|
+
Saves the system as an XML file.
|
646
|
+
|
647
|
+
create(path_or_str: str, seed: int = 1) -> "System":
|
648
|
+
Creates a `System` instance from an XML file or string.
|
649
|
+
|
650
|
+
coordinate_vector_to_q(q: jax.Array, custom_joints: dict[str, Callable] = {}) -> jax.Array:
|
651
|
+
Converts a coordinate vector to minimal coordinates (`q`), applying
|
652
|
+
constraints such as quaternion normalization.
|
653
|
+
|
654
|
+
Raises:
|
655
|
+
AssertionError: If the system structure is invalid (e.g., duplicate link names, incorrect parent-child relationships).
|
656
|
+
InvalidSystemError: If an operation results in an inconsistent system state.
|
657
|
+
|
658
|
+
Notes:
|
659
|
+
- The system must be parsed before use to ensure consistency.
|
660
|
+
- The system supports batch operations using JAX for efficient computations.
|
661
|
+
- Joint types include revolute ("rx", "ry", "rz"), prismatic ("px", "py", "pz"), spherical, free, and more.
|
662
|
+
- Inertial properties of links are computed automatically from associated geometries.
|
663
|
+
""" # noqa: E501
|
664
|
+
|
406
665
|
link_parents: list[int] = struct.field(False)
|
407
666
|
links: Link
|
408
667
|
link_types: list[str] = struct.field(False)
|
@@ -429,25 +688,32 @@ class System(_Base):
|
|
429
688
|
omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
|
430
689
|
|
431
690
|
def num_links(self) -> int:
|
691
|
+
"Returns the number of links in the system."
|
432
692
|
return len(self.link_parents)
|
433
693
|
|
434
694
|
def q_size(self) -> int:
|
695
|
+
"Returns the total number of generalized coordinates (`q`) in the system."
|
435
696
|
return sum([Q_WIDTHS[typ] for typ in self.link_types])
|
436
697
|
|
437
698
|
def qd_size(self) -> int:
|
699
|
+
"Returns the total number of generalized velocities (`qd`) in the system."
|
438
700
|
return sum([QD_WIDTHS[typ] for typ in self.link_types])
|
439
701
|
|
440
702
|
def name_to_idx(self, name: str) -> int:
|
703
|
+
"Returns the index of a link given its name."
|
441
704
|
return self.link_names.index(name)
|
442
705
|
|
443
706
|
def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
|
707
|
+
"""Returns the name of a link given its index. If `allow_world` is `True`,
|
708
|
+
returns `"world"` for index `-1`."""
|
444
709
|
if allow_world and idx == -1:
|
445
710
|
return "world"
|
446
711
|
assert idx >= 0, "Worldbody index has no name."
|
447
712
|
return self.link_names[idx]
|
448
713
|
|
449
714
|
def idx_map(self, type: str) -> dict:
|
450
|
-
"
|
715
|
+
"""Returns a dictionary mapping link names to their indices for a specified type
|
716
|
+
(`"l"`, `"q"`, or `"d"`)."""
|
451
717
|
dict_int_slices = {}
|
452
718
|
|
453
719
|
def f(_, idx_map, name: str, link_idx: int):
|
@@ -458,10 +724,11 @@ class System(_Base):
|
|
458
724
|
return dict_int_slices
|
459
725
|
|
460
726
|
def parent_name(self, name: str) -> str:
|
727
|
+
"Returns the name of the parent link for a given link."
|
461
728
|
return self.idx_to_name(self.link_parents[self.name_to_idx(name)])
|
462
729
|
|
463
730
|
def add_prefix(self, prefix: str = "") -> "System":
|
464
|
-
return self.
|
731
|
+
return self.add_prefix_suffix(prefix=prefix)
|
465
732
|
|
466
733
|
def change_model_name(
|
467
734
|
self,
|
@@ -469,6 +736,7 @@ class System(_Base):
|
|
469
736
|
prefix: Optional[str] = None,
|
470
737
|
suffix: Optional[str] = None,
|
471
738
|
) -> "System":
|
739
|
+
"Changes the name of the system model."
|
472
740
|
if prefix is None:
|
473
741
|
prefix = ""
|
474
742
|
if suffix is None:
|
@@ -479,6 +747,7 @@ class System(_Base):
|
|
479
747
|
return self.replace(model_name=name)
|
480
748
|
|
481
749
|
def change_link_name(self, old_name: str, new_name: str) -> "System":
|
750
|
+
"Renames a specific link."
|
482
751
|
old_idx = self.name_to_idx(old_name)
|
483
752
|
new_link_names = self.link_names.copy()
|
484
753
|
new_link_names[old_idx] = new_name
|
@@ -487,6 +756,7 @@ class System(_Base):
|
|
487
756
|
def add_prefix_suffix(
|
488
757
|
self, prefix: Optional[str] = None, suffix: Optional[str] = None
|
489
758
|
) -> "System":
|
759
|
+
"Adds either or, or both a prefix and suffix to all link names."
|
490
760
|
if prefix is None:
|
491
761
|
prefix = ""
|
492
762
|
if suffix is None:
|
@@ -526,6 +796,7 @@ class System(_Base):
|
|
526
796
|
return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
|
527
797
|
|
528
798
|
def freeze(self, name: str | list[str]):
|
799
|
+
"Freezes the specified link(s), making them immovable (uses `frozen` joint)"
|
529
800
|
if isinstance(name, list):
|
530
801
|
sys = self
|
531
802
|
for n in name:
|
@@ -544,6 +815,7 @@ class System(_Base):
|
|
544
815
|
return _update_sys_if_replace_joint_type(self, logic_freeze)
|
545
816
|
|
546
817
|
def unfreeze(self, name: str, new_joint_type: str):
|
818
|
+
"Unfreezes a frozen link and assigns it a new joint type."
|
547
819
|
assert self.link_types[self.name_to_idx(name)] == "frozen"
|
548
820
|
assert new_joint_type != "frozen"
|
549
821
|
|
@@ -560,7 +832,8 @@ class System(_Base):
|
|
560
832
|
seed: int = 1,
|
561
833
|
warn: bool = True,
|
562
834
|
):
|
563
|
-
"
|
835
|
+
"""Changes the joint type of a specified link.
|
836
|
+
By default damping, stiffness are set to zero."""
|
564
837
|
from ring.algorithms import get_joint_model
|
565
838
|
|
566
839
|
q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
|
@@ -594,6 +867,7 @@ class System(_Base):
|
|
594
867
|
|
595
868
|
@staticmethod
|
596
869
|
def joint_type_simplification(typ: str) -> str:
|
870
|
+
"Returns a simplified name of the given joint type."
|
597
871
|
if typ[:4] == "free":
|
598
872
|
if typ == "free_2d":
|
599
873
|
return "free_2d"
|
@@ -608,23 +882,28 @@ class System(_Base):
|
|
608
882
|
|
609
883
|
@staticmethod
|
610
884
|
def joint_type_is_free_or_cor(typ: str) -> bool:
|
885
|
+
'Checks if a joint type is either "free" or "cor".'
|
611
886
|
return System.joint_type_simplification(typ) in ["free", "cor"]
|
612
887
|
|
613
888
|
@staticmethod
|
614
889
|
def joint_type_is_spherical(typ: str) -> bool:
|
890
|
+
'Checks if a joint type is "spherical".'
|
615
891
|
return System.joint_type_simplification(typ) == "spherical"
|
616
892
|
|
617
893
|
@staticmethod
|
618
894
|
def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
|
895
|
+
'Checks if a joint type is "free", "cor", or "spherical".'
|
619
896
|
return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
|
620
897
|
typ
|
621
898
|
)
|
622
899
|
|
623
900
|
def findall_imus(self, names: bool = True) -> list[str] | list[int]:
|
901
|
+
"Finds all IMU sensors in the system."
|
624
902
|
bodies = [name for name in self.link_names if name[:3] == "imu"]
|
625
903
|
return bodies if names else [self.name_to_idx(n) for n in bodies]
|
626
904
|
|
627
905
|
def findall_segments(self, names: bool = True) -> list[str] | list[int]:
|
906
|
+
"Finds all non-IMU segments in the system."
|
628
907
|
imus = self.findall_imus(names=True)
|
629
908
|
bodies = [name for name in self.link_names if name not in imus]
|
630
909
|
return bodies if names else [self.name_to_idx(n) for n in bodies]
|
@@ -633,10 +912,12 @@ class System(_Base):
|
|
633
912
|
return [self.idx_to_name(i) for i in bodies]
|
634
913
|
|
635
914
|
def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
|
915
|
+
"Returns all bodies directly connected to the world."
|
636
916
|
bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
|
637
917
|
return self._bodies_indices_to_bodies_name(bodies) if names else bodies
|
638
918
|
|
639
919
|
def find_body_to_world(self, name: bool = False) -> int | str:
|
920
|
+
"Returns the root body connected to the world."
|
640
921
|
bodies = self.findall_bodies_to_world(names=name)
|
641
922
|
assert len(bodies) == 1
|
642
923
|
return bodies[0]
|
@@ -644,6 +925,7 @@ class System(_Base):
|
|
644
925
|
def findall_bodies_with_jointtype(
|
645
926
|
self, typ: str, names: bool = False
|
646
927
|
) -> list[int] | list[str]:
|
928
|
+
"Returns all bodies with the specified joint type."
|
647
929
|
bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
|
648
930
|
return self._bodies_indices_to_bodies_name(bodies) if names else bodies
|
649
931
|
|
@@ -781,20 +1063,25 @@ class System(_Base):
|
|
781
1063
|
|
782
1064
|
@staticmethod
|
783
1065
|
def from_xml(path: str, seed: int = 1):
|
1066
|
+
"Loads a system from an XML file."
|
784
1067
|
return ring.io.load_sys_from_xml(path, seed)
|
785
1068
|
|
786
1069
|
@staticmethod
|
787
1070
|
def from_str(xml: str, seed: int = 1):
|
1071
|
+
"Loads a system from an XML string."
|
788
1072
|
return ring.io.load_sys_from_str(xml, seed)
|
789
1073
|
|
790
1074
|
def to_str(self, warn: bool = True) -> str:
|
1075
|
+
"Serializes the system to an XML string."
|
791
1076
|
return ring.io.save_sys_to_str(self, warn=warn)
|
792
1077
|
|
793
1078
|
def to_xml(self, path: str) -> None:
|
1079
|
+
"Saves the system to an XML file."
|
794
1080
|
ring.io.save_sys_to_xml(self, path)
|
795
1081
|
|
796
1082
|
@classmethod
|
797
1083
|
def create(cls, path_or_str: str, seed: int = 1) -> "System":
|
1084
|
+
"Creates a `System` instance from an XML file or string."
|
798
1085
|
path = Path(path_or_str).with_suffix(".xml")
|
799
1086
|
|
800
1087
|
exists = False
|
@@ -814,7 +1101,8 @@ class System(_Base):
|
|
814
1101
|
q: jax.Array,
|
815
1102
|
custom_joints: dict[str, Callable] = {},
|
816
1103
|
) -> jax.Array:
|
817
|
-
"""
|
1104
|
+
"""Converts a coordinate vector to minimal coordinates (`q`), applying
|
1105
|
+
constraints such as quaternion normalization."""
|
818
1106
|
# Does, e.g.
|
819
1107
|
# - normalize quaternions
|
820
1108
|
# - hinge joints in [-pi, pi]
|
@@ -1026,14 +1314,37 @@ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = Fa
|
|
1026
1314
|
|
1027
1315
|
@struct.dataclass
|
1028
1316
|
class State(_Base):
|
1029
|
-
"""The static and dynamic state of a system in minimal and maximal coordinates.
|
1030
|
-
Use `.create()` to create this object.
|
1031
|
-
|
1032
|
-
Args:
|
1033
|
-
q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
|
1034
|
-
qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
|
1035
|
-
x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
|
1036
1317
|
"""
|
1318
|
+
Represents the state of a dynamic system in minimal and maximal coordinates.
|
1319
|
+
|
1320
|
+
The `State` class encapsulates both the configuration (`q`) and velocity (`qd`)
|
1321
|
+
of the system in minimal coordinates, as well as the corresponding transforms (`x`)
|
1322
|
+
in maximal coordinates.
|
1323
|
+
|
1324
|
+
Attributes:
|
1325
|
+
q (jax.Array):
|
1326
|
+
The joint positions (generalized coordinates) of the system. The size
|
1327
|
+
of `q` matches `sys.q_size()`.
|
1328
|
+
qd (jax.Array):
|
1329
|
+
The joint velocities (generalized velocities) of the system. The size
|
1330
|
+
of `qd` matches `sys.qd_size()`.
|
1331
|
+
x (Transform):
|
1332
|
+
The maximal coordinate representation of all system links, expressed as
|
1333
|
+
a `Transform` object.
|
1334
|
+
|
1335
|
+
Methods:
|
1336
|
+
create(sys: System, q: Optional[jax.Array] = None,
|
1337
|
+
qd: Optional[jax.Array] = None, x: Optional[Transform] = None,
|
1338
|
+
key: Optional[jax.Array] = None,
|
1339
|
+
custom_joints: dict[str, Callable] = {}) -> State:
|
1340
|
+
Creates a `State` instance for a given system with optional initial conditions.
|
1341
|
+
|
1342
|
+
Usage:
|
1343
|
+
>>> sys = System.create("model.xml")
|
1344
|
+
>>> state = State.create(sys)
|
1345
|
+
>>> print(state.q.shape) # Should match sys.q_size()
|
1346
|
+
>>> print(state.qd.shape) # Should match sys.qd_size()
|
1347
|
+
""" # noqa: E501
|
1037
1348
|
|
1038
1349
|
q: jax.Array
|
1039
1350
|
qd: jax.Array
|
@@ -1048,19 +1359,37 @@ class State(_Base):
|
|
1048
1359
|
x: Optional[Transform] = None,
|
1049
1360
|
key: Optional[jax.Array] = None,
|
1050
1361
|
custom_joints: dict[str, Callable] = {},
|
1051
|
-
):
|
1052
|
-
"""
|
1362
|
+
) -> "State":
|
1363
|
+
"""
|
1364
|
+
Creates a `State` instance for the given system with optional initial conditions.
|
1365
|
+
|
1366
|
+
If no initial values are provided, joint positions (`q`) and velocities (`qd`)
|
1367
|
+
are initialized to zero, except for free and spherical joints, which have unit quaternions.
|
1053
1368
|
|
1054
1369
|
Args:
|
1055
|
-
sys (System):
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1370
|
+
sys (System):
|
1371
|
+
The system for which to create a state.
|
1372
|
+
q (Optional[jax.Array], default=None):
|
1373
|
+
Initial joint positions. If `None`, defaults to zeros, with unit quaternion initialization
|
1374
|
+
for free and spherical joints.
|
1375
|
+
qd (Optional[jax.Array], default=None):
|
1376
|
+
Initial joint velocities. If `None`, defaults to zeros.
|
1377
|
+
x (Optional[Transform], default=None):
|
1378
|
+
Initial maximal coordinates of the system links. If `None`, defaults to zero transforms.
|
1379
|
+
key (Optional[jax.Array], default=None):
|
1380
|
+
Random key for initializing `q` if no values are provided.
|
1381
|
+
custom_joints (dict[str, Callable], default={}):
|
1382
|
+
Custom joint functions for mapping coordinate vectors to minimal coordinates.
|
1060
1383
|
|
1061
1384
|
Returns:
|
1062
|
-
|
1063
|
-
|
1385
|
+
State: A new instance of the `State` class representing the initialized system state.
|
1386
|
+
|
1387
|
+
Example:
|
1388
|
+
>>> sys = System.create("model.xml")
|
1389
|
+
>>> state = State.create(sys)
|
1390
|
+
>>> print(state.q.shape) # Should match sys.q_size()
|
1391
|
+
>>> print(state.qd.shape) # Should match sys.qd_size()
|
1392
|
+
""" # noqa: E501
|
1064
1393
|
if key is not None:
|
1065
1394
|
assert q is None
|
1066
1395
|
q = jax.random.normal(key, shape=(sys.q_size(),))
|
ring/ml/base.py
CHANGED
@@ -297,6 +297,11 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
|
|
297
297
|
|
298
298
|
if self._quat_normalize:
|
299
299
|
assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
|
300
|
+
|
301
|
+
# for exporting neural networks to ONNX format, you will have to use
|
302
|
+
# the first version, but for neural network training the second version
|
303
|
+
# is required
|
304
|
+
# yhat = yhat / jnp.linalg.norm(yhat, axis=-1, keepdims=True)
|
300
305
|
yhat = ring.maths.safe_normalize(yhat)
|
301
306
|
|
302
307
|
return yhat, state
|
ring/sim2real/sim2real.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1
1
|
from typing import Optional, Tuple
|
2
2
|
|
3
3
|
import jax
|
4
|
+
import tree_utils
|
5
|
+
|
4
6
|
from ring import algebra
|
5
7
|
from ring import base
|
6
8
|
from ring import io
|
7
9
|
from ring import maths
|
8
10
|
from ring.algorithms import generator
|
9
11
|
from ring.algorithms import jcalc
|
10
|
-
import tree_utils
|
11
12
|
|
12
13
|
|
13
14
|
def xs_from_raw(
|
@@ -189,7 +190,14 @@ def delete_to_world_pos_rot(sys: base.System, xs: base.Transform) -> base.Transf
|
|
189
190
|
|
190
191
|
|
191
192
|
def randomize_to_world_pos_rot(
|
192
|
-
key: jax.Array,
|
193
|
+
key: jax.Array,
|
194
|
+
sys: base.System,
|
195
|
+
xs: base.Transform,
|
196
|
+
config: jcalc.MotionConfig,
|
197
|
+
world_joint: str = "free",
|
198
|
+
cor: bool = False,
|
199
|
+
overwrite_q_ref: jax.Array = None,
|
200
|
+
damping=None,
|
193
201
|
) -> base.Transform:
|
194
202
|
"""Replace the transforms of all links that connect to the worldbody
|
195
203
|
by randomize transforms.
|
@@ -210,14 +218,30 @@ def randomize_to_world_pos_rot(
|
|
210
218
|
<x_xy>
|
211
219
|
<options dt="0.01"/>
|
212
220
|
<worldbody>
|
213
|
-
<body name="free" joint="free"
|
221
|
+
<body name="free" joint="free" damping="15.0 15.0 15.0 25.0 25.0 25.0">
|
222
|
+
<geom type="box" mass="1" dim="0.1 0.1 0.1"/>
|
223
|
+
</body>
|
214
224
|
</worldbody>
|
215
225
|
</x_xy>
|
216
226
|
"""
|
217
|
-
|
218
227
|
free_sys = io.load_sys_from_str(free_sys_str)
|
228
|
+
|
229
|
+
dynamic_simulation = True if overwrite_q_ref is not None else False
|
230
|
+
|
231
|
+
if world_joint != "free":
|
232
|
+
if dynamic_simulation:
|
233
|
+
assert damping is not None
|
234
|
+
free_sys = free_sys.change_joint_type("free", world_joint, new_damp=damping)
|
235
|
+
|
219
236
|
_, xs_free = generator.RCMG(
|
220
|
-
free_sys,
|
237
|
+
free_sys,
|
238
|
+
config,
|
239
|
+
finalize_fn=lambda key, q, x, sys: (q, x),
|
240
|
+
cor=cor,
|
241
|
+
dynamic_simulation=dynamic_simulation,
|
242
|
+
dynamic_simulation_kwargs=dict(
|
243
|
+
overwrite_q_ref=(overwrite_q_ref, free_sys.idx_map("q"))
|
244
|
+
),
|
221
245
|
).to_lazy_gen()(key)
|
222
246
|
xs_free = xs_free.take(0, axis=0)
|
223
247
|
xs_free = xs_free.take(free_sys.name_to_idx("free"), axis=1)
|
File without changes
|
File without changes
|