imt-ring 1.6.38__py3-none-any.whl → 1.6.39__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: imt-ring
3
- Version: 1.6.38
3
+ Version: 1.6.39
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,12 +1,12 @@
1
1
  ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=4Yxk6jk-B4UUm_6YYshxmHSHqOg0mhTOxtZP5fFS8nw,35373
3
+ ring/base.py,sha256=zromjIuMpNBoyiwHa9OCyZvAz7jHjXHZIdRt8fN8PoA,50481
4
4
  ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
8
- ring/algorithms/dynamics.py,sha256=dpe-F3Yq4sY2dY6DQW3v7TnPLRdxdkePtdbGPQIrijg,10997
9
- ring/algorithms/jcalc.py,sha256=QafnCKa1mjEl7nL_KuadPJB5ebW31NKnkdcKn2YtSsM,36171
8
+ ring/algorithms/dynamics.py,sha256=NFOZawjrFoS5RgiWOpG1pQCC8l7RBOEZIi9ok6gvf9U,12268
9
+ ring/algorithms/jcalc.py,sha256=l6BXOmXwrZ_AKKRm4gEHq_k2LSUQ4wd--1qL1qNTcKk,46794
10
10
  ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
11
11
  ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
@@ -15,7 +15,7 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=jGQocoNZ5tkiMazBDCv-jD6FNYwebqn0_RgVFse49pg,16890
18
+ ring/algorithms/generator/base.py,sha256=klWYt6TlMluLu0ihGzmmPXBm47DOTpjXJylZVNXHVEk,22419
19
19
  ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
20
20
  ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
@@ -52,7 +52,7 @@ ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,
52
52
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
53
53
  ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
- ring/ml/base.py,sha256=eEpLuXlhFoJ-cpXoSGjctLaYduQhnSVpbv-FEYftNRs,9972
55
+ ring/ml/base.py,sha256=phBUfTpP1Mqt8lvtilSavT7ypkaeaF1oh7nCMjV0dqg,10046
56
56
  ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
57
57
  ring/ml/ml_utils.py,sha256=hu189AnHcmkhkpEPZZ19O0gWz3T-YKpWQW9buqDTMow,10915
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.38.dist-info/METADATA,sha256=9rN1VzsIlGU8eyABz9-pTxj0OTCFOZRilEEzkB4gyvg,4251
90
- imt_ring-1.6.38.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
91
- imt_ring-1.6.38.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.38.dist-info/RECORD,,
89
+ imt_ring-1.6.39.dist-info/METADATA,sha256=v0rBTnCQP-SWJU153byfX31HUCcrWHhCg_EmecLbLf4,4251
90
+ imt_ring-1.6.39.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
91
+ imt_ring-1.6.39.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.39.dist-info/RECORD,,
@@ -303,7 +303,33 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
- "Steps the dynamics. Returns the state of next timestep."
306
+ """
307
+ Advances the system dynamics by a single timestep using semi-implicit Euler integration.
308
+
309
+ This function updates the system's state by integrating the equations of motion
310
+ over a timestep, potentially with multiple substeps for improved numerical stability.
311
+ The method ensures that the system's kinematics are updated before each integration step.
312
+
313
+ Args:
314
+ sys (base.System):
315
+ The system to simulate, containing link information, joint dynamics, and integration parameters.
316
+ state (base.State):
317
+ The current state of the system, including joint positions (`q`), velocities (`qd`), and transforms (`x`).
318
+ taus (Optional[jax.Array], optional):
319
+ The control torques applied to the system joints. If `None`, zero torques are applied.
320
+ Defaults to `None`.
321
+ n_substeps (int, optional):
322
+ The number of integration substeps per timestep to improve numerical accuracy.
323
+ Defaults to `1`.
324
+
325
+ Returns:
326
+ base.State:
327
+ The updated state of the system after integration.
328
+
329
+ Raises:
330
+ AssertionError: If the system's degrees of freedom (`q` and `qd`) do not match expectations.
331
+ AssertionError: If an unsupported integration method is specified in `sys.integration_method`.
332
+ """ # noqa: E501
307
333
  assert sys.q_size() == state.q.size
308
334
  if taus is None:
309
335
  taus = jnp.zeros_like(state.qd)
@@ -53,7 +53,85 @@ class RCMG:
53
53
  cor: bool = False,
54
54
  disable_tqdm: bool = False,
55
55
  ) -> None:
56
- "Random Chain Motion Generator"
56
+ """
57
+ Initializes the Random Chain Motion Generator (RCMG).
58
+
59
+ The RCMG generates synthetic joint motion sequences for kinematic and dynamic
60
+ systems based on predefined motion configurations. It allows for system
61
+ randomization, augmentation with IMU and joint axis data, and optional
62
+ dynamic simulation.
63
+
64
+ Args:
65
+ sys (base.System | list[base.System]):
66
+ The system(s) for which motion should be generated.
67
+ config (jcalc.MotionConfig | list[jcalc.MotionConfig], optional):
68
+ Motion configuration(s) defining velocity limits, interpolation methods,
69
+ and range constraints. Defaults to `jcalc.MotionConfig()`.
70
+ setup_fn (Optional[types.SETUP_FN], optional):
71
+ A function to modify the system before motion generation. Defaults to `None`.
72
+ finalize_fn (Optional[types.FINALIZE_FN], optional):
73
+ A function to modify outputs after motion generation. Defaults to `None`.
74
+ add_X_imus (bool, optional):
75
+ Whether to add IMU sensor data to the output. Defaults to `False`.
76
+ add_X_imus_kwargs (dict, optional):
77
+ Additional keyword arguments for IMU data processing. Defaults to `{}`.
78
+ add_X_jointaxes (bool, optional):
79
+ Whether to add joint axis data to the output. Defaults to `False`.
80
+ add_X_jointaxes_kwargs (dict, optional):
81
+ Additional keyword arguments for joint axis data processing. Defaults to `{}`.
82
+ add_y_relpose (bool, optional):
83
+ Whether to add relative pose targets to the output. Defaults to `False`.
84
+ add_y_rootincl (bool, optional):
85
+ Whether to add root inclination targets to the output. Defaults to `False`.
86
+ add_y_rootincl_kwargs (dict, optional):
87
+ Additional keyword arguments for root inclination processing. Defaults to `{}`.
88
+ add_y_rootfull (bool, optional):
89
+ Whether to add full root state targets to the output. Defaults to `False`.
90
+ add_y_rootfull_kwargs (dict, optional):
91
+ Additional keyword arguments for full root state processing. Defaults to `{}`.
92
+ sys_ml (Optional[base.System], optional):
93
+ System that defines the graph and naming structure of the `X` and `y` outputs. Defaults to `None` which then uses the first provided system.
94
+ randomize_positions (bool, optional):
95
+ Whether to randomize positions based on `pos_min` and `pos_max`. Defaults to `False`.
96
+ randomize_motion_artifacts (bool, optional):
97
+ Whether to randomize the IMU motion artifact simulation. This randomizes the spring stiffness and spring damping parameters of the passive free joint that is added between the nonrigid and rigid IMU. Defaults to `False`.
98
+ randomize_joint_params (bool, optional):
99
+ Whether to randomize joint parameters by calling `JointModel.init_joint_params` before every sequence generation. Defaults to `False`.
100
+ randomize_hz (bool, optional):
101
+ Whether to randomize the sampling frequency of the generated data. Defaults to `False`.
102
+ randomize_hz_kwargs (dict, optional):
103
+ Additional keyword arguments for sampling frequency randomization. Defaults to `{}`.
104
+ imu_motion_artifacts (bool, optional):
105
+ Whether to simulate nonrigid IMU motion artifacts. Defaults to `False`.
106
+ imu_motion_artifacts_kwargs (dict, optional):
107
+ Additional keyword arguments for IMU motion artifact simulation. Defaults to `{}`.
108
+ dynamic_simulation (bool, optional):
109
+ Whether to use a physics-based simulation to generate motion instead of purely
110
+ kinematic methods. Defaults to `False`.
111
+ dynamic_simulation_kwargs (dict, optional):
112
+ Additional keyword arguments for dynamic simulation. Defaults to `{}`.
113
+ output_transform (Optional[Callable], optional):
114
+ A function to transform the generated output data. Defaults to `None`.
115
+ keep_output_extras (bool, optional):
116
+ Whether to keep additional output metadata. Defaults to `False`.
117
+ use_link_number_in_Xy (bool, optional):
118
+ Whether to replace joint names with numerical indices in the output. Defaults to `False`.
119
+ cor (bool, optional):
120
+ Whether to replace free joints with center-of-rotation (COR) 9D free joint. Defaults to `False`.
121
+ disable_tqdm (bool, optional):
122
+ Whether to disable progress bars during generation. Defaults to `False`.
123
+
124
+ Raises:
125
+ AssertionError: If any of the provided `MotionConfig` instances are infeasible.
126
+
127
+ Notes:
128
+ - This class supports batch generation, lazy and eager data loading, and
129
+ motion augmentation.
130
+ - If `randomize_hz=True`, the time step (`dt`) varies according to the specified
131
+ sampling rates.
132
+ - When `cor=True`, free joints are replaced with center-of-rotation models,
133
+ affecting joint motion behavior.
134
+ """ # noqa: E501
57
135
 
58
136
  # add some default values
59
137
  randomize_hz_kwargs_defaults = dict(add_dt=True)
@@ -139,6 +217,7 @@ class RCMG:
139
217
  def to_lazy_gen(
140
218
  self, sizes: int | list[int] = 1, jit: bool = True
141
219
  ) -> types.BatchedGenerator:
220
+ "Returns a generator `X, y = gen(key)` that lazily generates batched sequences."
142
221
  return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
143
222
 
144
223
  @staticmethod
@@ -201,7 +280,7 @@ class RCMG:
201
280
  ),
202
281
  verbose: bool = True,
203
282
  ):
204
-
283
+ "Stores unbatched sequences as numpy arrays into folder."
205
284
  i = 0
206
285
 
207
286
  def callback(data: list[PyTree[np.ndarray]]) -> None:
@@ -237,6 +316,7 @@ class RCMG:
237
316
  shuffle: bool = True,
238
317
  transform=None,
239
318
  ) -> types.BatchedGenerator:
319
+ "Returns a generator `X, y = gen(key)` that returns precomputed batched sequences." # noqa: E501
240
320
  data = self.to_list(sizes, seed)
241
321
  assert len(data) >= batchsize
242
322
  return self.eager_gen_from_list(data, batchsize, shuffle, transform)
ring/algorithms/jcalc.py CHANGED
@@ -19,6 +19,87 @@ from ring.algorithms._random import TimeDependentFloat
19
19
 
20
20
  @dataclass
21
21
  class MotionConfig:
22
+ """
23
+ Configuration for joint motion generation in kinematic and dynamic simulations.
24
+
25
+ This class defines the constraints and parameters for generating random joint motions,
26
+ including angular and positional velocity limits, interpolation methods, and range
27
+ restrictions for various joint types.
28
+
29
+ Attributes:
30
+ T (float): Total duration of the motion sequence (in seconds).
31
+ t_min (float): Minimum time interval between two generated joint states.
32
+ t_max (float | TimeDependentFloat): Maximum time interval between two generated joint states.
33
+
34
+ dang_min (float | TimeDependentFloat): Minimum angular velocity (rad/s).
35
+ dang_max (float | TimeDependentFloat): Maximum angular velocity (rad/s).
36
+ dang_min_free_spherical (float | TimeDependentFloat): Minimum angular velocity for free and spherical joints.
37
+ dang_max_free_spherical (float | TimeDependentFloat): Maximum angular velocity for free and spherical joints.
38
+
39
+ delta_ang_min (float | TimeDependentFloat): Minimum allowed change in joint angle (radians).
40
+ delta_ang_max (float | TimeDependentFloat): Maximum allowed change in joint angle (radians).
41
+ delta_ang_min_free_spherical (float | TimeDependentFloat): Minimum allowed change in angle for free/spherical joints.
42
+ delta_ang_max_free_spherical (float | TimeDependentFloat): Maximum allowed change in angle for free/spherical joints.
43
+
44
+ dpos_min (float | TimeDependentFloat): Minimum translational velocity.
45
+ dpos_max (float | TimeDependentFloat): Maximum translational velocity.
46
+ pos_min (float | TimeDependentFloat): Minimum position constraint.
47
+ pos_max (float | TimeDependentFloat): Maximum position constraint.
48
+
49
+ pos_min_p3d_x (float | TimeDependentFloat): Minimum position in x-direction for P3D joints.
50
+ pos_max_p3d_x (float | TimeDependentFloat): Maximum position in x-direction for P3D joints.
51
+ pos_min_p3d_y (float | TimeDependentFloat): Minimum position in y-direction for P3D joints.
52
+ pos_max_p3d_y (float | TimeDependentFloat): Maximum position in y-direction for P3D joints.
53
+ pos_min_p3d_z (float | TimeDependentFloat): Minimum position in z-direction for P3D joints.
54
+ pos_max_p3d_z (float | TimeDependentFloat): Maximum position in z-direction for P3D joints.
55
+
56
+ cdf_bins_min (int): Minimum number of bins for cumulative distribution function (CDF)-based random sampling.
57
+ cdf_bins_max (Optional[int]): Maximum number of bins for CDF-based sampling.
58
+
59
+ randomized_interpolation_angle (bool): Whether to use randomized interpolation for angular motion.
60
+ randomized_interpolation_position (bool): Whether to use randomized interpolation for positional motion.
61
+ interpolation_method (str): Interpolation method to be used (default: "cosine").
62
+
63
+ range_of_motion_hinge (bool): Whether to enforce range-of-motion constraints on hinge joints.
64
+ range_of_motion_hinge_method (str): Method used for range-of-motion enforcement (e.g., "uniform", "sigmoid").
65
+
66
+ rom_halfsize (float | TimeDependentFloat): Half-size of the range of motion restriction.
67
+
68
+ ang0_min (float): Minimum initial joint angle.
69
+ ang0_max (float): Maximum initial joint angle.
70
+ pos0_min (float): Minimum initial joint position.
71
+ pos0_max (float): Maximum initial joint position.
72
+
73
+ cor_t_min (float): Minimum time step for center-of-rotation (COR) joints.
74
+ cor_t_max (float | TimeDependentFloat): Maximum time step for COR joints.
75
+ cor_dpos_min (float | TimeDependentFloat): Minimum velocity for COR translation.
76
+ cor_dpos_max (float | TimeDependentFloat): Maximum velocity for COR translation.
77
+ cor_pos_min (float | TimeDependentFloat): Minimum position for COR translation.
78
+ cor_pos_max (float | TimeDependentFloat): Maximum position for COR translation.
79
+ cor_pos0_min (float): Initial minimum position for COR translation.
80
+ cor_pos0_max (float): Initial maximum position for COR translation.
81
+
82
+ joint_type_specific_overwrites (dict[str, dict[str, Any]]):
83
+ A dictionary mapping joint types to specific motion configuration overrides.
84
+
85
+ Methods:
86
+ is_feasible:
87
+ Checks if the motion configuration satisfies all constraints.
88
+
89
+ to_nomotion_config:
90
+ Returns a new `MotionConfig` where all velocities and angle changes are set to zero.
91
+
92
+ overwrite_for_joint_type:
93
+ Applies specific configuration changes for a given joint type.
94
+ Note: These changes affect all instances of `MotionConfig` for this joint type.
95
+
96
+ overwrite_for_subsystem:
97
+ Modifies the motion configuration for all joints in a subsystem rooted at `link_name`.
98
+
99
+ from_register:
100
+ Retrieves a predefined `MotionConfig` from the global registry.
101
+ """ # noqa: E501
102
+
22
103
  T: float = 60.0 # length of random motion
23
104
  t_min: float = 0.05 # min time between two generated angles
24
105
  t_max: float | TimeDependentFloat = 0.30 # max time ..
@@ -412,6 +493,30 @@ def _find_interval(t: jax.Array, boundaries: jax.Array):
412
493
  def join_motionconfigs(
413
494
  configs: list[MotionConfig], boundaries: list[float]
414
495
  ) -> MotionConfig:
496
+ """
497
+ Joins multiple `MotionConfig` objects in time, transitioning between them at specified boundaries.
498
+
499
+ This function takes a list of `MotionConfig` instances and a corresponding list of boundary times,
500
+ and constructs a new `MotionConfig` that varies in time according to the provided segments.
501
+
502
+ Args:
503
+ configs (list[MotionConfig]): A list of `MotionConfig` objects to be joined.
504
+ boundaries (list[float]): A list of time values where transitions between `configs` occur.
505
+ Must have one element less than `configs`, as each boundary defines the transition point
506
+ between two consecutive configurations.
507
+
508
+ Returns:
509
+ MotionConfig: A new `MotionConfig` object where time-dependent fields transition based on the
510
+ specified boundaries.
511
+
512
+ Raises:
513
+ AssertionError: If the number of boundaries does not match `len(configs) - 1`.
514
+ AssertionError: If time-independent fields have differing values across `configs`.
515
+
516
+ Notes:
517
+ - Only fields that are time-dependent (`float | TimeDependentFloat`) will change over time.
518
+ - Time-independent fields must be the same in all `configs`, or an error is raised.
519
+ """ # noqa: E501
415
520
  # to avoid a circular import due to `ring.utils.randomize_sys` importing `jcalc`
416
521
  from ring.utils import tree_equal
417
522
 
@@ -517,6 +622,55 @@ INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
517
622
 
518
623
  @dataclass
519
624
  class JointModel:
625
+ """
626
+ Represents the kinematic and dynamic properties of a joint type.
627
+
628
+ A `JointModel` defines the mathematical functions required to compute joint
629
+ transformations, motion, control terms, and inverse kinematics. It is used to
630
+ describe the behavior of various joint types, including revolute, prismatic,
631
+ spherical, and free joints.
632
+
633
+ Attributes:
634
+ transform (Callable[[jax.Array, jax.Array], base.Transform]):
635
+ Computes the transformation (position and orientation) of the joint
636
+ given the joint state `q` and joint parameters.
637
+
638
+ motion (list[base.Motion | Callable[[jax.Array], base.Motion]]):
639
+ Defines the joint motion model. It can be a list of `Motion` objects
640
+ or callables that return `Motion` based on joint parameters.
641
+
642
+ rcmg_draw_fn (Optional[DRAW_FN]):
643
+ Function used to generate a reference motion trajectory for the joint
644
+ using the Random Chain Motion Generator (RCMG).
645
+
646
+ p_control_term (Optional[P_CONTROL_TERM]):
647
+ Function that computes the proportional control term for the joint.
648
+
649
+ qd_from_q (Optional[QD_FROM_Q]):
650
+ Function to compute joint velocity (`qd`) from joint positions (`q`).
651
+
652
+ coordinate_vector_to_q (Optional[COORDINATE_VECTOR_TO_Q]):
653
+ Function that maps a coordinate vector to a valid joint state `q`,
654
+ ensuring constraints (e.g., wrapping angles or normalizing quaternions).
655
+
656
+ inv_kin (Optional[INV_KIN]):
657
+ Function that computes the inverse kinematics for the joint, mapping
658
+ a desired transform to joint coordinates `q`.
659
+
660
+ init_joint_params (Optional[INIT_JOINT_PARAMS]):
661
+ Function that initializes joint-specific parameters.
662
+
663
+ utilities (Optional[dict[str, Any]]):
664
+ Additional utility functions or metadata related to the joint model.
665
+
666
+ Notes:
667
+ - The `transform` function is essential for computing the joint's spatial
668
+ transformation based on its generalized coordinates.
669
+ - The `motion` attribute holds one entry per velocity coordinate (`len(motion) == len(qd)`).
670
+ - The `rcmg_draw_fn` is used for RCMG motion generation.
671
+ - The `coordinate_vector_to_q` is critical for maintaining valid joint states.
672
+ """ # noqa: E501
673
+
520
674
  # (q, params) -> Transform
521
675
  transform: Callable[[jax.Array, jax.Array], base.Transform]
522
676
  # len(motion) == len(qd)
@@ -1079,6 +1233,50 @@ def register_new_joint_type(
1079
1233
  qd_width: Optional[int] = None,
1080
1234
  overwrite: bool = False,
1081
1235
  ):
1236
+ """
1237
+ Registers a new joint type with its corresponding `JointModel` and kinematic properties.
1238
+
1239
+ This function allows the addition of custom joint types to the system by associating
1240
+ them with a `JointModel`, specifying their state and velocity dimensions, and optionally
1241
+ overwriting existing joint definitions.
1242
+
1243
+ Args:
1244
+ joint_type (str):
1245
+ Name of the new joint type to register.
1246
+ joint_model (JointModel):
1247
+ The `JointModel` instance defining the kinematic and dynamic properties of the joint.
1248
+ q_width (int):
1249
+ Number of generalized coordinates (degrees of freedom) required to represent the joint.
1250
+ qd_width (Optional[int], default=None):
1251
+ Number of velocity coordinates associated with the joint. Defaults to `q_width`.
1252
+ overwrite (bool, default=False):
1253
+ If `True`, allows overwriting an existing joint type. Otherwise, raises an error if
1254
+ the joint type already exists.
1255
+
1256
+ Raises:
1257
+ AssertionError:
1258
+ - If `joint_type` is `"default"` (reserved name).
1259
+ - If `joint_type` already exists and `overwrite=False`.
1260
+ - If `qd_width` is not provided and does not default to `q_width`.
1261
+ - If `joint_model.motion` length does not match `qd_width`.
1262
+
1263
+ Notes:
1264
+ - The function updates global dictionaries that store joint properties, including:
1265
+ - `_joint_types`: Maps joint type names to `JointModel` instances.
1266
+ - `base.Q_WIDTHS`: Stores the number of state coordinates for each joint type.
1267
+ - `base.QD_WIDTHS`: Stores the number of velocity coordinates for each joint type.
1268
+ - If `overwrite=True`, existing entries are removed before adding the new joint type.
1269
+ - Ensures consistency between motion definitions and velocity coordinate dimensions.
1270
+
1271
+ Example:
1272
+ ```python
1273
+ new_joint = JointModel(
1274
+ transform=my_transform_fn,
1275
+ motion=[base.Motion.create(ang=jnp.array([1, 0, 0]))],
1276
+ )
1277
+ register_new_joint_type("custom_hinge", new_joint, q_width=1)
1278
+ ```
1279
+ """ # noqa: E501
1082
1280
  # this name is used
1083
1281
  assert joint_type != "default", "Please use another name."
1084
1282
 
ring/base.py CHANGED
@@ -112,17 +112,71 @@ class _Base:
112
112
 
113
113
  @struct.dataclass
114
114
  class Transform(_Base):
115
- """Represents the Transformation from Plücker A to Plücker B,
116
- where B is located relative to A at `pos` in frame A and `rot` is the
117
- relative quaternion from A to B.
118
- Create using `Transform.create(pos=..., rot=...)
119
115
  """
116
+ Represents a spatial transformation between two coordinate frames using Plücker coordinates.
117
+
118
+ The `Transform` class defines the relative position and orientation of one frame (`B`)
119
+ with respect to another frame (`A`). The position (`pos`) is given in the coordinate frame
120
+ of `A`, and the rotation (`rot`) is expressed as a unit quaternion representing the relative
121
+ rotation from frame `A` to frame `B`.
122
+
123
+ Attributes:
124
+ pos (Vector):
125
+ The translation vector (position of `B` relative to `A`) in the coordinate frame of `A`.
126
+ Shape: `(..., 3)`, where `...` represents optional batch dimensions.
127
+ rot (Quaternion):
128
+ The unit quaternion representing the orientation of `B` relative to `A`.
129
+ Shape: `(..., 4)`, where `...` represents optional batch dimensions.
130
+
131
+ Methods:
132
+ create(pos: Optional[Vector] = None, rot: Optional[Quaternion] = None) -> Transform:
133
+ Creates a `Transform` instance with optional position and rotation.
134
+
135
+ zero(shape: Sequence[int] = ()) -> Transform:
136
+ Returns a zero transform with a given batch shape.
137
+
138
+ as_matrix() -> jax.Array:
139
+ Returns the 6x6 Plücker (spatial) transformation matrix representation of this transform.
140
+
141
+ Usage:
142
+ >>> pos = jnp.array([1.0, 2.0, 3.0])
143
+ >>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
144
+ >>> T = Transform.create(pos, rot)
145
+ >>> print(T.pos) # Output: [1. 2. 3.]
146
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
147
+ >>> print(T.as_matrix()) # 6x6 Plücker transformation matrix
148
+ """ # noqa: E501
120
149
 
121
150
  pos: Vector
122
151
  rot: Quaternion
123
152
 
124
153
  @classmethod
125
154
  def create(cls, pos=None, rot=None):
155
+ """
156
+ Creates a `Transform` instance with the specified position and rotation.
157
+
158
+ At least one of `pos` or `rot` must be provided. If only `pos` is given, the rotation
159
+ defaults to the identity quaternion `[1, 0, 0, 0]`. If only `rot` is given, the position
160
+ defaults to `[0, 0, 0]`.
161
+
162
+ Args:
163
+ pos (Optional[Vector], default=None):
164
+ The position of frame `B` relative to frame `A`, expressed in frame `A` coordinates.
165
+ If `None`, defaults to a zero vector of shape `(3,)`.
166
+ rot (Optional[Quaternion], default=None):
167
+ The unit quaternion representing the orientation of `B` relative to `A`.
168
+ If `None`, defaults to the identity quaternion `(1, 0, 0, 0)`.
169
+
170
+ Returns:
171
+ Transform: A new `Transform` instance with the specified position and rotation.
172
+
173
+ Example:
174
+ >>> pos = jnp.array([1.0, 2.0, 3.0])
175
+ >>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
176
+ >>> T = Transform.create(pos, rot)
177
+ >>> print(T.pos) # Output: [1. 2. 3.]
178
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
179
+ """ # noqa: E501
126
180
  assert not (pos is None and rot is None), "One must be given."
127
181
  shape_rot = rot.shape[:-1] if rot is not None else ()
128
182
  shape_pos = pos.shape[:-1] if pos is not None else ()
@@ -139,12 +193,49 @@ class Transform(_Base):
139
193
 
140
194
  @classmethod
141
195
  def zero(cls, shape=()) -> "Transform":
142
- """Returns a zero transform with a batch shape."""
196
+ """
197
+ Returns a zero transform with a given batch shape.
198
+
199
+ This creates a transform with position `(0, 0, 0)` and an identity quaternion `(1, 0, 0, 0)`,
200
+ which represents no translation or rotation.
201
+
202
+ Args:
203
+ shape (Sequence[int], default=()):
204
+ The batch shape for the transform. Defaults to a scalar transform.
205
+
206
+ Returns:
207
+ Transform: A zero transform with the specified batch shape.
208
+
209
+ Example:
210
+ >>> T = Transform.zero()
211
+ >>> print(T.pos) # Output: [0. 0. 0.]
212
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
213
+ """ # noqa: E501
143
214
  pos = jnp.zeros(shape + (3,))
144
215
  rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
145
216
  return Transform(pos, rot)
146
217
 
147
218
  def as_matrix(self) -> jax.Array:
219
+ """
220
+ Returns the 6x6 Plücker (spatial) transformation matrix representation of this transform.
221
+
222
+ The matrix is computed as:
223
+
224
+ ```
225
+ X = [ E 0 ] @ xlt(pos)
226
+     [ 0 E ]
227
+ ```
228
+
229
+ where `E` is the 3x3 rotation matrix converted from the quaternion (`maths.quat_to_3x3`) and
230
+ `xlt(pos)` is the 6x6 spatial translation transform (`spatial.xlt`) of the position vector.
231
+
232
+ Returns:
233
+ jax.Array: A `(6, 6)` Plücker transformation matrix.
234
+
235
+ Example:
236
+ >>> T = Transform.create(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 0.0, 0.0]))
237
+ >>> print(T.as_matrix()) # Output: 6x6 matrix
238
+ """ # noqa: E501
148
239
  E = maths.quat_to_3x3(self.rot)
149
240
  return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
150
241
 
@@ -402,7 +493,175 @@ QD_WIDTHS = {
402
493
 
403
494
  @struct.dataclass
404
495
  class System(_Base):
405
- "System object. Create using `System.create(path_xml)`"
496
+ """
497
+ Represents a robotic system consisting of interconnected links and joints. Create it using `System.create(...)`
498
+
499
+ The `System` class models the kinematic and dynamic properties of a multibody
500
+ system, providing methods for state representation, transformations, joint
501
+ configuration management, and rendering. It supports both minimal and maximal
502
+ coordinate representations and can be parsed from or saved to XML files.
503
+
504
+ Attributes:
505
+ link_parents (list[int]):
506
+ A list specifying the parent index for each link. Links attached directly to the world have a parent index of `-1`.
507
+ links (Link):
508
+ A data structure containing information about all links in the system.
509
+ link_types (list[str]):
510
+ A list specifying the joint type for each link (e.g., "free", "rx", "py", "spherical").
511
+ link_damping (jax.Array):
512
+ Joint damping coefficients for each link.
513
+ link_armature (jax.Array):
514
+ Armature inertia values for each joint.
515
+ link_spring_stiffness (jax.Array):
516
+ Stiffness values for joint springs.
517
+ link_spring_zeropoint (jax.Array):
518
+ Rest position of joint springs.
519
+ dt (float):
520
+ Simulation time step size.
521
+ geoms (list[Geometry]):
522
+ List of geometries associated with the system.
523
+ gravity (jax.Array):
524
+ Gravity vector applied to the system (default: `[0, 0, -9.81]`).
525
+ integration_method (str):
526
+ Integration method for simulation (default: "semi_implicit_euler").
527
+ mass_mat_iters (int):
528
+ Number of iterations for mass matrix calculations.
529
+ link_names (list[str]):
530
+ Names of the links in the system.
531
+ model_name (Optional[str]):
532
+ Name of the system model (if available).
533
+ omc (list[MaxCoordOMC | None]):
534
+ List of optional Maximal Coordinate representations.
535
+
536
+ Methods:
537
+ num_links() -> int:
538
+ Returns the number of links in the system.
539
+
540
+ q_size() -> int:
541
+ Returns the total number of generalized coordinates (`q`) in the system.
542
+
543
+ qd_size() -> int:
544
+ Returns the total number of generalized velocities (`qd`) in the system.
545
+
546
+ name_to_idx(name: str) -> int:
547
+ Returns the index of a link given its name.
548
+
549
+ idx_to_name(idx: int, allow_world: bool = False) -> str:
550
+ Returns the name of a link given its index. If `allow_world` is `True`,
551
+ returns `"world"` for index `-1`.
552
+
553
+ idx_map(type: str) -> dict:
554
+ Returns a dictionary mapping link names to their indices for a specified type
555
+ (`"l"`, `"q"`, or `"d"`).
556
+
557
+ parent_name(name: str) -> str:
558
+ Returns the name of the parent link for a given link.
559
+
560
+ change_model_name(new_name: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
561
+ Changes the name of the system model.
562
+
563
+ change_link_name(old_name: str, new_name: str) -> "System":
564
+ Renames a specific link.
565
+
566
+ add_prefix_suffix(prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
567
+ Adds a prefix, a suffix, or both to all link names.
568
+
569
+ freeze(name: str | list[str]) -> "System":
570
+ Freezes the specified link(s), making them immovable.
571
+
572
+ unfreeze(name: str, new_joint_type: str) -> "System":
573
+ Unfreezes a frozen link and assigns it a new joint type.
574
+
575
+ change_joint_type(name: str, new_joint_type: str, **kwargs) -> "System":
576
+ Changes the joint type of a specified link.
577
+
578
+ joint_type_simplification(typ: str) -> str:
579
+ Returns a simplified representation of the given joint type.
580
+
581
+ joint_type_is_free_or_cor(typ: str) -> bool:
582
+ Checks if a joint type is either "free" or "cor".
583
+
584
+ joint_type_is_spherical(typ: str) -> bool:
585
+ Checks if a joint type is "spherical".
586
+
587
+ joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
588
+ Checks if a joint type is "free", "cor", or "spherical".
589
+
590
+ findall_imus(names: bool = True) -> list[str] | list[int]:
591
+ Finds all IMU sensors in the system.
592
+
593
+ findall_segments(names: bool = True) -> list[str] | list[int]:
594
+ Finds all non-IMU segments in the system.
595
+
596
+ findall_bodies_to_world(names: bool = False) -> list[int] | list[str]:
597
+ Returns all bodies directly connected to the world.
598
+
599
+ find_body_to_world(name: bool = False) -> int | str:
600
+ Returns the root body connected to the world.
601
+
602
+ findall_bodies_with_jointtype(typ: str, names: bool = False) -> list[int] | list[str]:
603
+ Returns all bodies with the specified joint type.
604
+
605
+ children(name: str, names: bool = False) -> list[int] | list[str]:
606
+ Returns the direct children of a given body.
607
+
608
+ findall_bodies_subsystem(name: str, names: bool = False) -> list[int] | list[str]:
609
+ Finds all bodies in the subsystem rooted at a given link.
610
+
611
+ scan(f: Callable, in_types: str, *args, reverse: bool = False):
612
+ Iterates over system elements while applying a function.
613
+
614
+ parse() -> "System":
615
+ Parses the system, performing consistency checks and computing spatial inertia tensors.
616
+
617
+ render(xs: Optional[Transform | list[Transform]] = None, **kwargs) -> list[np.ndarray]:
618
+ Renders frames of the system using maximal coordinates.
619
+
620
+ render_prediction(xs: Transform | list[Transform], yhat: dict | jax.Array | np.ndarray, **kwargs):
621
+ Renders a predicted state transformation.
622
+
623
+ delete_system(link_name: str | list[str], strict: bool = True):
624
+ Removes a subsystem from the system.
625
+
626
+ make_sys_noimu(imu_link_names: Optional[list[str]] = None):
627
+ Returns a version of the system without IMU sensors.
628
+
629
+ inject_system(other_system: "System", at_body: Optional[str] = None):
630
+ Merges another system into this one.
631
+
632
+ morph_system(new_parents: Optional[list[int | str]] = None, new_anchor: Optional[int | str] = None):
633
+ Reorders the system’s link hierarchy.
634
+
635
+ from_xml(path: str, seed: int = 1) -> "System":
636
+ Loads a system from an XML file.
637
+
638
+ from_str(xml: str, seed: int = 1) -> "System":
639
+ Loads a system from an XML string.
640
+
641
+ to_str(warn: bool = True) -> str:
642
+ Serializes the system to an XML string.
643
+
644
+ to_xml(path: str) -> None:
645
+ Saves the system as an XML file.
646
+
647
+ create(path_or_str: str, seed: int = 1) -> "System":
648
+ Creates a `System` instance from an XML file or string.
649
+
650
+ coordinate_vector_to_q(q: jax.Array, custom_joints: dict[str, Callable] = {}) -> jax.Array:
651
+ Converts a coordinate vector to minimal coordinates (`q`), applying
652
+ constraints such as quaternion normalization.
653
+
654
+ Raises:
655
+ AssertionError: If the system structure is invalid (e.g., duplicate link names, incorrect parent-child relationships).
656
+ InvalidSystemError: If an operation results in an inconsistent system state.
657
+
658
+ Notes:
659
+ - The system must be parsed before use to ensure consistency.
660
+ - The system supports batch operations using JAX for efficient computations.
661
+ - Joint types include revolute ("rx", "ry", "rz"), prismatic ("px", "py", "pz"), spherical, free, and more.
662
+ - Inertial properties of links are computed automatically from associated geometries.
663
+ """ # noqa: E501
664
+
406
665
  link_parents: list[int] = struct.field(False)
407
666
  links: Link
408
667
  link_types: list[str] = struct.field(False)
@@ -429,25 +688,32 @@ class System(_Base):
429
688
  omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
430
689
 
431
690
  def num_links(self) -> int:
691
+ "Returns the number of links in the system."
432
692
  return len(self.link_parents)
433
693
 
434
694
  def q_size(self) -> int:
695
+ "Returns the total number of generalized coordinates (`q`) in the system."
435
696
  return sum([Q_WIDTHS[typ] for typ in self.link_types])
436
697
 
437
698
  def qd_size(self) -> int:
699
+ "Returns the total number of generalized velocities (`qd`) in the system."
438
700
  return sum([QD_WIDTHS[typ] for typ in self.link_types])
439
701
 
440
702
  def name_to_idx(self, name: str) -> int:
703
+ "Returns the index of a link given its name."
441
704
  return self.link_names.index(name)
442
705
 
443
706
  def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
707
+ """Returns the name of a link given its index. If `allow_world` is `True`,
708
+ returns `"world"` for index `-1`."""
444
709
  if allow_world and idx == -1:
445
710
  return "world"
446
711
  assert idx >= 0, "Worldbody index has no name."
447
712
  return self.link_names[idx]
448
713
 
449
714
  def idx_map(self, type: str) -> dict:
450
- "type: is either `l` or `q` or `d`"
715
+ """Returns a dictionary mapping link names to their indices for a specified type
716
+ (`"l"`, `"q"`, or `"d"`)."""
451
717
  dict_int_slices = {}
452
718
 
453
719
  def f(_, idx_map, name: str, link_idx: int):
@@ -458,10 +724,11 @@ class System(_Base):
458
724
  return dict_int_slices
459
725
 
460
726
  def parent_name(self, name: str) -> str:
727
+ "Returns the name of the parent link for a given link."
461
728
  return self.idx_to_name(self.link_parents[self.name_to_idx(name)])
462
729
 
463
730
  def add_prefix(self, prefix: str = "") -> "System":
464
- return self.replace(link_names=[prefix + name for name in self.link_names])
731
+ return self.add_prefix_suffix(prefix=prefix)
465
732
 
466
733
  def change_model_name(
467
734
  self,
@@ -469,6 +736,7 @@ class System(_Base):
469
736
  prefix: Optional[str] = None,
470
737
  suffix: Optional[str] = None,
471
738
  ) -> "System":
739
+ "Changes the name of the system model."
472
740
  if prefix is None:
473
741
  prefix = ""
474
742
  if suffix is None:
@@ -479,6 +747,7 @@ class System(_Base):
479
747
  return self.replace(model_name=name)
480
748
 
481
749
  def change_link_name(self, old_name: str, new_name: str) -> "System":
750
+ "Renames a specific link."
482
751
  old_idx = self.name_to_idx(old_name)
483
752
  new_link_names = self.link_names.copy()
484
753
  new_link_names[old_idx] = new_name
@@ -487,6 +756,7 @@ class System(_Base):
487
756
  def add_prefix_suffix(
488
757
  self, prefix: Optional[str] = None, suffix: Optional[str] = None
489
758
  ) -> "System":
759
+ "Adds either or, or both a prefix and suffix to all link names."
490
760
  if prefix is None:
491
761
  prefix = ""
492
762
  if suffix is None:
@@ -526,6 +796,7 @@ class System(_Base):
526
796
  return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
527
797
 
528
798
  def freeze(self, name: str | list[str]):
799
+ "Freezes the specified link(s), making them immovable (uses `frozen` joint)"
529
800
  if isinstance(name, list):
530
801
  sys = self
531
802
  for n in name:
@@ -544,6 +815,7 @@ class System(_Base):
544
815
  return _update_sys_if_replace_joint_type(self, logic_freeze)
545
816
 
546
817
  def unfreeze(self, name: str, new_joint_type: str):
818
+ "Unfreezes a frozen link and assigns it a new joint type."
547
819
  assert self.link_types[self.name_to_idx(name)] == "frozen"
548
820
  assert new_joint_type != "frozen"
549
821
 
@@ -560,7 +832,8 @@ class System(_Base):
560
832
  seed: int = 1,
561
833
  warn: bool = True,
562
834
  ):
563
- "By default damping, stiffness are set to zero."
835
+ """Changes the joint type of a specified link.
836
+ By default damping, stiffness are set to zero."""
564
837
  from ring.algorithms import get_joint_model
565
838
 
566
839
  q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
@@ -594,6 +867,7 @@ class System(_Base):
594
867
 
595
868
  @staticmethod
596
869
  def joint_type_simplification(typ: str) -> str:
870
+ "Returns a simplified name of the given joint type."
597
871
  if typ[:4] == "free":
598
872
  if typ == "free_2d":
599
873
  return "free_2d"
@@ -608,23 +882,28 @@ class System(_Base):
608
882
 
609
883
  @staticmethod
610
884
  def joint_type_is_free_or_cor(typ: str) -> bool:
885
+ 'Checks if a joint type is either "free" or "cor".'
611
886
  return System.joint_type_simplification(typ) in ["free", "cor"]
612
887
 
613
888
  @staticmethod
614
889
  def joint_type_is_spherical(typ: str) -> bool:
890
+ 'Checks if a joint type is "spherical".'
615
891
  return System.joint_type_simplification(typ) == "spherical"
616
892
 
617
893
  @staticmethod
618
894
  def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
895
+ 'Checks if a joint type is "free", "cor", or "spherical".'
619
896
  return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
620
897
  typ
621
898
  )
622
899
 
623
900
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
901
+ "Finds all IMU sensors in the system."
624
902
  bodies = [name for name in self.link_names if name[:3] == "imu"]
625
903
  return bodies if names else [self.name_to_idx(n) for n in bodies]
626
904
 
627
905
  def findall_segments(self, names: bool = True) -> list[str] | list[int]:
906
+ "Finds all non-IMU segments in the system."
628
907
  imus = self.findall_imus(names=True)
629
908
  bodies = [name for name in self.link_names if name not in imus]
630
909
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -633,10 +912,12 @@ class System(_Base):
633
912
  return [self.idx_to_name(i) for i in bodies]
634
913
 
635
914
  def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
915
+ "Returns all bodies directly connected to the world."
636
916
  bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
637
917
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
638
918
 
639
919
  def find_body_to_world(self, name: bool = False) -> int | str:
920
+ "Returns the root body connected to the world."
640
921
  bodies = self.findall_bodies_to_world(names=name)
641
922
  assert len(bodies) == 1
642
923
  return bodies[0]
@@ -644,6 +925,7 @@ class System(_Base):
644
925
  def findall_bodies_with_jointtype(
645
926
  self, typ: str, names: bool = False
646
927
  ) -> list[int] | list[str]:
928
+ "Returns all bodies with the specified joint type."
647
929
  bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
648
930
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
649
931
 
@@ -781,20 +1063,25 @@ class System(_Base):
781
1063
 
782
1064
  @staticmethod
783
1065
  def from_xml(path: str, seed: int = 1):
1066
+ "Loads a system from an XML file."
784
1067
  return ring.io.load_sys_from_xml(path, seed)
785
1068
 
786
1069
  @staticmethod
787
1070
  def from_str(xml: str, seed: int = 1):
1071
+ "Loads a system from an XML string."
788
1072
  return ring.io.load_sys_from_str(xml, seed)
789
1073
 
790
1074
  def to_str(self, warn: bool = True) -> str:
1075
+ "Serializes the system to an XML string."
791
1076
  return ring.io.save_sys_to_str(self, warn=warn)
792
1077
 
793
1078
  def to_xml(self, path: str) -> None:
1079
+ "Saves the system to an XML file."
794
1080
  ring.io.save_sys_to_xml(self, path)
795
1081
 
796
1082
  @classmethod
797
1083
  def create(cls, path_or_str: str, seed: int = 1) -> "System":
1084
+ "Creates a `System` instance from an XML file or string."
798
1085
  path = Path(path_or_str).with_suffix(".xml")
799
1086
 
800
1087
  exists = False
@@ -814,7 +1101,8 @@ class System(_Base):
814
1101
  q: jax.Array,
815
1102
  custom_joints: dict[str, Callable] = {},
816
1103
  ) -> jax.Array:
817
- """Map a coordinate vector `q` to the minimal coordinates vector of the sys"""
1104
+ """Converts a coordinate vector to minimal coordinates (`q`), applying
1105
+ constraints such as quaternion normalization."""
818
1106
  # Does, e.g.
819
1107
  # - normalize quaternions
820
1108
  # - hinge joints in [-pi, pi]
@@ -1026,14 +1314,37 @@ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = Fa
1026
1314
 
1027
1315
  @struct.dataclass
1028
1316
  class State(_Base):
1029
- """The static and dynamic state of a system in minimal and maximal coordinates.
1030
- Use `.create()` to create this object.
1031
-
1032
- Args:
1033
- q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
1034
- qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
1035
- x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
1036
1317
  """
1318
+ Represents the state of a dynamic system in minimal and maximal coordinates.
1319
+
1320
+ The `State` class encapsulates both the configuration (`q`) and velocity (`qd`)
1321
+ of the system in minimal coordinates, as well as the corresponding transforms (`x`)
1322
+ in maximal coordinates.
1323
+
1324
+ Attributes:
1325
+ q (jax.Array):
1326
+ The joint positions (generalized coordinates) of the system. The size
1327
+ of `q` matches `sys.q_size()`.
1328
+ qd (jax.Array):
1329
+ The joint velocities (generalized velocities) of the system. The size
1330
+ of `qd` matches `sys.qd_size()`.
1331
+ x (Transform):
1332
+ The maximal coordinate representation of all system links, expressed as
1333
+ a `Transform` object.
1334
+
1335
+ Methods:
1336
+ create(sys: System, q: Optional[jax.Array] = None,
1337
+ qd: Optional[jax.Array] = None, x: Optional[Transform] = None,
1338
+ key: Optional[jax.Array] = None,
1339
+ custom_joints: dict[str, Callable] = {}) -> State:
1340
+ Creates a `State` instance for a given system with optional initial conditions.
1341
+
1342
+ Usage:
1343
+ >>> sys = System.create("model.xml")
1344
+ >>> state = State.create(sys)
1345
+ >>> print(state.q.shape) # Should match sys.q_size()
1346
+ >>> print(state.qd.shape) # Should match sys.qd_size()
1347
+ """ # noqa: E501
1037
1348
 
1038
1349
  q: jax.Array
1039
1350
  qd: jax.Array
@@ -1048,19 +1359,37 @@ class State(_Base):
1048
1359
  x: Optional[Transform] = None,
1049
1360
  key: Optional[jax.Array] = None,
1050
1361
  custom_joints: dict[str, Callable] = {},
1051
- ):
1052
- """Create state of system.
1362
+ ) -> "State":
1363
+ """
1364
+ Creates a `State` instance for the given system with optional initial conditions.
1365
+
1366
+ If no initial values are provided, joint positions (`q`) and velocities (`qd`)
1367
+ are initialized to zero, except for free and spherical joints, which have unit quaternions.
1053
1368
 
1054
1369
  Args:
1055
- sys (System): The system for which to create a state.
1056
- q (jax.Array, optional): The joint values of the system. Defaults to None.
1057
- Which then defaults to zeros.
1058
- qd (jax.Array, optional): The joint velocities of the system.
1059
- Defaults to None. Which then defaults to zeros.
1370
+ sys (System):
1371
+ The system for which to create a state.
1372
+ q (Optional[jax.Array], default=None):
1373
+ Initial joint positions. If `None`, defaults to zeros, with unit quaternion initialization
1374
+ for free and spherical joints.
1375
+ qd (Optional[jax.Array], default=None):
1376
+ Initial joint velocities. If `None`, defaults to zeros.
1377
+ x (Optional[Transform], default=None):
1378
+ Initial maximal coordinates of the system links. If `None`, defaults to zero transforms.
1379
+ key (Optional[jax.Array], default=None):
1380
+ Random key for initializing `q` if no values are provided.
1381
+ custom_joints (dict[str, Callable], default={}):
1382
+ Custom joint functions for mapping coordinate vectors to minimal coordinates.
1060
1383
 
1061
1384
  Returns:
1062
- (State): Create State object.
1063
- """
1385
+ State: A new instance of the `State` class representing the initialized system state.
1386
+
1387
+ Example:
1388
+ >>> sys = System.create("model.xml")
1389
+ >>> state = State.create(sys)
1390
+ >>> print(state.q.shape) # Should match sys.q_size()
1391
+ >>> print(state.qd.shape) # Should match sys.qd_size()
1392
+ """ # noqa: E501
1064
1393
  if key is not None:
1065
1394
  assert q is None
1066
1395
  q = jax.random.normal(key, shape=(sys.q_size(),))
ring/ml/base.py CHANGED
@@ -297,7 +297,8 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
297
297
 
298
298
  if self._quat_normalize:
299
299
  assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
300
- yhat = ring.maths.safe_normalize(yhat)
300
+ # yhat = ring.maths.safe_normalize(yhat)
301
+ yhat = yhat / jnp.linalg.norm(yhat, axis=-1, keepdims=True)
301
302
 
302
303
  return yhat, state
303
304