imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/METADATA +1 -1
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/RECORD +27 -27
- ring/algorithms/custom_joints/suntay.py +1 -1
- ring/algorithms/dynamics.py +27 -1
- ring/algorithms/generator/base.py +82 -2
- ring/algorithms/generator/batch.py +2 -2
- ring/algorithms/generator/finalize_fns.py +1 -1
- ring/algorithms/generator/pd_control.py +1 -1
- ring/algorithms/jcalc.py +198 -0
- ring/algorithms/kinematics.py +2 -1
- ring/algorithms/sensors.py +12 -10
- ring/base.py +356 -27
- ring/io/xml/from_xml.py +1 -1
- ring/ml/base.py +4 -3
- ring/ml/ml_utils.py +3 -3
- ring/ml/ringnet.py +1 -1
- ring/ml/train.py +2 -2
- ring/rendering/mujoco_render.py +11 -7
- ring/rendering/vispy_render.py +5 -4
- ring/sys_composer/inject_sys.py +3 -2
- ring/utils/batchsize.py +3 -3
- ring/utils/dataloader.py +4 -3
- ring/utils/dataloader_torch.py +14 -5
- ring/utils/hdf5.py +1 -1
- ring/utils/normalizer.py +6 -5
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/top_level.txt +0 -0
ring/base.py
CHANGED
@@ -112,17 +112,71 @@ class _Base:
|
|
112
112
|
|
113
113
|
@struct.dataclass
|
114
114
|
class Transform(_Base):
|
115
|
-
"""Represents the Transformation from Plücker A to Plücker B,
|
116
|
-
where B is located relative to A at `pos` in frame A and `rot` is the
|
117
|
-
relative quaternion from A to B.
|
118
|
-
Create using `Transform.create(pos=..., rot=...)
|
119
115
|
"""
|
116
|
+
Represents a spatial transformation between two coordinate frames using Plücker coordinates.
|
117
|
+
|
118
|
+
The `Transform` class defines the relative position and orientation of one frame (`B`)
|
119
|
+
with respect to another frame (`A`). The position (`pos`) is given in the coordinate frame
|
120
|
+
of `A`, and the rotation (`rot`) is expressed as a unit quaternion representing the relative
|
121
|
+
rotation from frame `A` to frame `B`.
|
122
|
+
|
123
|
+
Attributes:
|
124
|
+
pos (Vector):
|
125
|
+
The translation vector (position of `B` relative to `A`) in the coordinate frame of `A`.
|
126
|
+
Shape: `(..., 3)`, where `...` represents optional batch dimensions.
|
127
|
+
rot (Quaternion):
|
128
|
+
The unit quaternion representing the orientation of `B` relative to `A`.
|
129
|
+
Shape: `(..., 4)`, where `...` represents optional batch dimensions.
|
130
|
+
|
131
|
+
Methods:
|
132
|
+
create(pos: Optional[Vector] = None, rot: Optional[Quaternion] = None) -> Transform:
|
133
|
+
Creates a `Transform` instance with optional position and rotation.
|
134
|
+
|
135
|
+
zero(shape: Sequence[int] = ()) -> Transform:
|
136
|
+
Returns a zero transform with a given batch shape.
|
137
|
+
|
138
|
+
as_matrix() -> jax.Array:
|
139
|
+
Returns the 4x4 homogeneous transformation matrix representation of this transform.
|
140
|
+
|
141
|
+
Usage:
|
142
|
+
>>> pos = jnp.array([1.0, 2.0, 3.0])
|
143
|
+
>>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
|
144
|
+
>>> T = Transform.create(pos, rot)
|
145
|
+
>>> print(T.pos) # Output: [1. 2. 3.]
|
146
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
147
|
+
>>> print(T.as_matrix()) # 4x4 transformation matrix
|
148
|
+
""" # noqa: E501
|
120
149
|
|
121
150
|
pos: Vector
|
122
151
|
rot: Quaternion
|
123
152
|
|
124
153
|
@classmethod
|
125
154
|
def create(cls, pos=None, rot=None):
|
155
|
+
"""
|
156
|
+
Creates a `Transform` instance with the specified position and rotation.
|
157
|
+
|
158
|
+
At least one of `pos` or `rot` must be provided. If only `pos` is given, the rotation
|
159
|
+
defaults to the identity quaternion `[1, 0, 0, 0]`. If only `rot` is given, the position
|
160
|
+
defaults to `[0, 0, 0]`.
|
161
|
+
|
162
|
+
Args:
|
163
|
+
pos (Optional[Vector], default=None):
|
164
|
+
The position of frame `B` relative to frame `A`, expressed in frame `A` coordinates.
|
165
|
+
If `None`, defaults to a zero vector of shape `(3,)`.
|
166
|
+
rot (Optional[Quaternion], default=None):
|
167
|
+
The unit quaternion representing the orientation of `B` relative to `A`.
|
168
|
+
If `None`, defaults to the identity quaternion `(1, 0, 0, 0)`.
|
169
|
+
|
170
|
+
Returns:
|
171
|
+
Transform: A new `Transform` instance with the specified position and rotation.
|
172
|
+
|
173
|
+
Example:
|
174
|
+
>>> pos = jnp.array([1.0, 2.0, 3.0])
|
175
|
+
>>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
|
176
|
+
>>> T = Transform.create(pos, rot)
|
177
|
+
>>> print(T.pos) # Output: [1. 2. 3.]
|
178
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
179
|
+
""" # noqa: E501
|
126
180
|
assert not (pos is None and rot is None), "One must be given."
|
127
181
|
shape_rot = rot.shape[:-1] if rot is not None else ()
|
128
182
|
shape_pos = pos.shape[:-1] if pos is not None else ()
|
@@ -139,12 +193,49 @@ class Transform(_Base):
|
|
139
193
|
|
140
194
|
@classmethod
|
141
195
|
def zero(cls, shape=()) -> "Transform":
|
142
|
-
"""
|
196
|
+
"""
|
197
|
+
Returns a zero transform with a given batch shape.
|
198
|
+
|
199
|
+
This creates a transform with position `(0, 0, 0)` and an identity quaternion `(1, 0, 0, 0)`,
|
200
|
+
which represents no translation or rotation.
|
201
|
+
|
202
|
+
Args:
|
203
|
+
shape (Sequence[int], default=()):
|
204
|
+
The batch shape for the transform. Defaults to a scalar transform.
|
205
|
+
|
206
|
+
Returns:
|
207
|
+
Transform: A zero transform with the specified batch shape.
|
208
|
+
|
209
|
+
Example:
|
210
|
+
>>> T = Transform.zero()
|
211
|
+
>>> print(T.pos) # Output: [0. 0. 0.]
|
212
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
213
|
+
""" # noqa: E501
|
143
214
|
pos = jnp.zeros(shape + (3,))
|
144
215
|
rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
|
145
216
|
return Transform(pos, rot)
|
146
217
|
|
147
218
|
def as_matrix(self) -> jax.Array:
|
219
|
+
"""
|
220
|
+
Returns the 4x4 homogeneous transformation matrix representation of this transform.
|
221
|
+
|
222
|
+
The homogeneous transformation matrix is defined as:
|
223
|
+
|
224
|
+
```
|
225
|
+
[ R t ]
|
226
|
+
[ 0 1 ]
|
227
|
+
```
|
228
|
+
|
229
|
+
where `R` is the 3x3 rotation matrix converted from the quaternion and `t` is the
|
230
|
+
3x1 position vector.
|
231
|
+
|
232
|
+
Returns:
|
233
|
+
jax.Array: A `(4, 4)` homogeneous transformation matrix.
|
234
|
+
|
235
|
+
Example:
|
236
|
+
>>> T = Transform.create(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 0.0, 0.0]))
|
237
|
+
>>> print(T.as_matrix()) # Output: 4x4 matrix
|
238
|
+
""" # noqa: E501
|
148
239
|
E = maths.quat_to_3x3(self.rot)
|
149
240
|
return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
|
150
241
|
|
@@ -402,7 +493,175 @@ QD_WIDTHS = {
|
|
402
493
|
|
403
494
|
@struct.dataclass
|
404
495
|
class System(_Base):
|
405
|
-
"
|
496
|
+
"""
|
497
|
+
Represents a robotic system consisting of interconnected links and joints. Create it using `System.create(...)`
|
498
|
+
|
499
|
+
The `System` class models the kinematic and dynamic properties of a multibody
|
500
|
+
system, providing methods for state representation, transformations, joint
|
501
|
+
configuration management, and rendering. It supports both minimal and maximal
|
502
|
+
coordinate representations and can be parsed from or saved to XML files.
|
503
|
+
|
504
|
+
Attributes:
|
505
|
+
link_parents (list[int]):
|
506
|
+
A list specifying the parent index for each link. The root link has a parent index of `-1`.
|
507
|
+
links (Link):
|
508
|
+
A data structure containing information about all links in the system.
|
509
|
+
link_types (list[str]):
|
510
|
+
A list specifying the joint type for each link (e.g., "free", "hinge", "prismatic").
|
511
|
+
link_damping (jax.Array):
|
512
|
+
Joint damping coefficients for each link.
|
513
|
+
link_armature (jax.Array):
|
514
|
+
Armature inertia values for each joint.
|
515
|
+
link_spring_stiffness (jax.Array):
|
516
|
+
Stiffness values for joint springs.
|
517
|
+
link_spring_zeropoint (jax.Array):
|
518
|
+
Rest position of joint springs.
|
519
|
+
dt (float):
|
520
|
+
Simulation time step size.
|
521
|
+
geoms (list[Geometry]):
|
522
|
+
List of geometries associated with the system.
|
523
|
+
gravity (jax.Array):
|
524
|
+
Gravity vector applied to the system (default: `[0, 0, -9.81]`).
|
525
|
+
integration_method (str):
|
526
|
+
Integration method for simulation (default: "semi_implicit_euler").
|
527
|
+
mass_mat_iters (int):
|
528
|
+
Number of iterations for mass matrix calculations.
|
529
|
+
link_names (list[str]):
|
530
|
+
Names of the links in the system.
|
531
|
+
model_name (Optional[str]):
|
532
|
+
Name of the system model (if available).
|
533
|
+
omc (list[MaxCoordOMC | None]):
|
534
|
+
List of optional Maximal Coordinate representations.
|
535
|
+
|
536
|
+
Methods:
|
537
|
+
num_links() -> int:
|
538
|
+
Returns the number of links in the system.
|
539
|
+
|
540
|
+
q_size() -> int:
|
541
|
+
Returns the total number of generalized coordinates (`q`) in the system.
|
542
|
+
|
543
|
+
qd_size() -> int:
|
544
|
+
Returns the total number of generalized velocities (`qd`) in the system.
|
545
|
+
|
546
|
+
name_to_idx(name: str) -> int:
|
547
|
+
Returns the index of a link given its name.
|
548
|
+
|
549
|
+
idx_to_name(idx: int, allow_world: bool = False) -> str:
|
550
|
+
Returns the name of a link given its index. If `allow_world` is `True`,
|
551
|
+
returns `"world"` for index `-1`.
|
552
|
+
|
553
|
+
idx_map(type: str) -> dict:
|
554
|
+
Returns a dictionary mapping link names to their indices for a specified type
|
555
|
+
(`"l"`, `"q"`, or `"d"`).
|
556
|
+
|
557
|
+
parent_name(name: str) -> str:
|
558
|
+
Returns the name of the parent link for a given link.
|
559
|
+
|
560
|
+
change_model_name(new_name: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
|
561
|
+
Changes the name of the system model.
|
562
|
+
|
563
|
+
change_link_name(old_name: str, new_name: str) -> "System":
|
564
|
+
Renames a specific link.
|
565
|
+
|
566
|
+
add_prefix_suffix(prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
|
567
|
+
Adds both a prefix and suffix to all link names.
|
568
|
+
|
569
|
+
freeze(name: str | list[str]) -> "System":
|
570
|
+
Freezes the specified link(s), making them immovable.
|
571
|
+
|
572
|
+
unfreeze(name: str, new_joint_type: str) -> "System":
|
573
|
+
Unfreezes a frozen link and assigns it a new joint type.
|
574
|
+
|
575
|
+
change_joint_type(name: str, new_joint_type: str, **kwargs) -> "System":
|
576
|
+
Changes the joint type of a specified link.
|
577
|
+
|
578
|
+
joint_type_simplification(typ: str) -> str:
|
579
|
+
Returns a simplified representation of the given joint type.
|
580
|
+
|
581
|
+
joint_type_is_free_or_cor(typ: str) -> bool:
|
582
|
+
Checks if a joint type is either "free" or "cor".
|
583
|
+
|
584
|
+
joint_type_is_spherical(typ: str) -> bool:
|
585
|
+
Checks if a joint type is "spherical".
|
586
|
+
|
587
|
+
joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
|
588
|
+
Checks if a joint type is "free", "cor", or "spherical".
|
589
|
+
|
590
|
+
findall_imus(names: bool = True) -> list[str] | list[int]:
|
591
|
+
Finds all IMU sensors in the system.
|
592
|
+
|
593
|
+
findall_segments(names: bool = True) -> list[str] | list[int]:
|
594
|
+
Finds all non-IMU segments in the system.
|
595
|
+
|
596
|
+
findall_bodies_to_world(names: bool = False) -> list[int] | list[str]:
|
597
|
+
Returns all bodies directly connected to the world.
|
598
|
+
|
599
|
+
find_body_to_world(name: bool = False) -> int | str:
|
600
|
+
Returns the root body connected to the world.
|
601
|
+
|
602
|
+
findall_bodies_with_jointtype(typ: str, names: bool = False) -> list[int] | list[str]:
|
603
|
+
Returns all bodies with the specified joint type.
|
604
|
+
|
605
|
+
children(name: str, names: bool = False) -> list[int] | list[str]:
|
606
|
+
Returns the direct children of a given body.
|
607
|
+
|
608
|
+
findall_bodies_subsystem(name: str, names: bool = False) -> list[int] | list[str]:
|
609
|
+
Finds all bodies in the subsystem rooted at a given link.
|
610
|
+
|
611
|
+
scan(f: Callable, in_types: str, *args, reverse: bool = False):
|
612
|
+
Iterates over system elements while applying a function.
|
613
|
+
|
614
|
+
parse() -> "System":
|
615
|
+
Parses the system, performing consistency checks and computing spatial inertia tensors.
|
616
|
+
|
617
|
+
render(xs: Optional[Transform | list[Transform]] = None, **kwargs) -> list[np.ndarray]:
|
618
|
+
Renders frames of the system using maximal coordinates.
|
619
|
+
|
620
|
+
render_prediction(xs: Transform | list[Transform], yhat: dict | jax.Array | np.ndarray, **kwargs):
|
621
|
+
Renders a predicted state transformation.
|
622
|
+
|
623
|
+
delete_system(link_name: str | list[str], strict: bool = True):
|
624
|
+
Removes a subsystem from the system.
|
625
|
+
|
626
|
+
make_sys_noimu(imu_link_names: Optional[list[str]] = None):
|
627
|
+
Returns a version of the system without IMU sensors.
|
628
|
+
|
629
|
+
inject_system(other_system: "System", at_body: Optional[str] = None):
|
630
|
+
Merges another system into this one.
|
631
|
+
|
632
|
+
morph_system(new_parents: Optional[list[int | str]] = None, new_anchor: Optional[int | str] = None):
|
633
|
+
Reorders the system’s link hierarchy.
|
634
|
+
|
635
|
+
from_xml(path: str, seed: int = 1) -> "System":
|
636
|
+
Loads a system from an XML file.
|
637
|
+
|
638
|
+
from_str(xml: str, seed: int = 1) -> "System":
|
639
|
+
Loads a system from an XML string.
|
640
|
+
|
641
|
+
to_str(warn: bool = True) -> str:
|
642
|
+
Serializes the system to an XML string.
|
643
|
+
|
644
|
+
to_xml(path: str) -> None:
|
645
|
+
Saves the system as an XML file.
|
646
|
+
|
647
|
+
create(path_or_str: str, seed: int = 1) -> "System":
|
648
|
+
Creates a `System` instance from an XML file or string.
|
649
|
+
|
650
|
+
coordinate_vector_to_q(q: jax.Array, custom_joints: dict[str, Callable] = {}) -> jax.Array:
|
651
|
+
Converts a coordinate vector to minimal coordinates (`q`), applying
|
652
|
+
constraints such as quaternion normalization.
|
653
|
+
|
654
|
+
Raises:
|
655
|
+
AssertionError: If the system structure is invalid (e.g., duplicate link names, incorrect parent-child relationships).
|
656
|
+
InvalidSystemError: If an operation results in an inconsistent system state.
|
657
|
+
|
658
|
+
Notes:
|
659
|
+
- The system must be parsed before use to ensure consistency.
|
660
|
+
- The system supports batch operations using JAX for efficient computations.
|
661
|
+
- Joint types include revolute ("rx", "ry", "rz"), prismatic ("px", "py", "pz"), spherical, free, and more.
|
662
|
+
- Inertial properties of links are computed automatically from associated geometries.
|
663
|
+
""" # noqa: E501
|
664
|
+
|
406
665
|
link_parents: list[int] = struct.field(False)
|
407
666
|
links: Link
|
408
667
|
link_types: list[str] = struct.field(False)
|
@@ -429,25 +688,32 @@ class System(_Base):
|
|
429
688
|
omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
|
430
689
|
|
431
690
|
def num_links(self) -> int:
|
691
|
+
"Returns the number of links in the system."
|
432
692
|
return len(self.link_parents)
|
433
693
|
|
434
694
|
def q_size(self) -> int:
|
695
|
+
"Returns the total number of generalized coordinates (`q`) in the system."
|
435
696
|
return sum([Q_WIDTHS[typ] for typ in self.link_types])
|
436
697
|
|
437
698
|
def qd_size(self) -> int:
|
699
|
+
"Returns the total number of generalized velocities (`qd`) in the system."
|
438
700
|
return sum([QD_WIDTHS[typ] for typ in self.link_types])
|
439
701
|
|
440
702
|
def name_to_idx(self, name: str) -> int:
|
703
|
+
"Returns the index of a link given its name."
|
441
704
|
return self.link_names.index(name)
|
442
705
|
|
443
706
|
def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
|
707
|
+
"""Returns the name of a link given its index. If `allow_world` is `True`,
|
708
|
+
returns `"world"` for index `-1`."""
|
444
709
|
if allow_world and idx == -1:
|
445
710
|
return "world"
|
446
711
|
assert idx >= 0, "Worldbody index has no name."
|
447
712
|
return self.link_names[idx]
|
448
713
|
|
449
714
|
def idx_map(self, type: str) -> dict:
|
450
|
-
"
|
715
|
+
"""Returns a dictionary mapping link names to their indices for a specified type
|
716
|
+
(`"l"`, `"q"`, or `"d"`)."""
|
451
717
|
dict_int_slices = {}
|
452
718
|
|
453
719
|
def f(_, idx_map, name: str, link_idx: int):
|
@@ -458,10 +724,11 @@ class System(_Base):
|
|
458
724
|
return dict_int_slices
|
459
725
|
|
460
726
|
def parent_name(self, name: str) -> str:
|
727
|
+
"Returns the name of the parent link for a given link."
|
461
728
|
return self.idx_to_name(self.link_parents[self.name_to_idx(name)])
|
462
729
|
|
463
730
|
def add_prefix(self, prefix: str = "") -> "System":
|
464
|
-
return self.
|
731
|
+
return self.add_prefix_suffix(prefix=prefix)
|
465
732
|
|
466
733
|
def change_model_name(
|
467
734
|
self,
|
@@ -469,6 +736,7 @@ class System(_Base):
|
|
469
736
|
prefix: Optional[str] = None,
|
470
737
|
suffix: Optional[str] = None,
|
471
738
|
) -> "System":
|
739
|
+
"Changes the name of the system model."
|
472
740
|
if prefix is None:
|
473
741
|
prefix = ""
|
474
742
|
if suffix is None:
|
@@ -479,6 +747,7 @@ class System(_Base):
|
|
479
747
|
return self.replace(model_name=name)
|
480
748
|
|
481
749
|
def change_link_name(self, old_name: str, new_name: str) -> "System":
|
750
|
+
"Renames a specific link."
|
482
751
|
old_idx = self.name_to_idx(old_name)
|
483
752
|
new_link_names = self.link_names.copy()
|
484
753
|
new_link_names[old_idx] = new_name
|
@@ -487,6 +756,7 @@ class System(_Base):
|
|
487
756
|
def add_prefix_suffix(
|
488
757
|
self, prefix: Optional[str] = None, suffix: Optional[str] = None
|
489
758
|
) -> "System":
|
759
|
+
"Adds either or, or both a prefix and suffix to all link names."
|
490
760
|
if prefix is None:
|
491
761
|
prefix = ""
|
492
762
|
if suffix is None:
|
@@ -526,6 +796,7 @@ class System(_Base):
|
|
526
796
|
return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
|
527
797
|
|
528
798
|
def freeze(self, name: str | list[str]):
|
799
|
+
"Freezes the specified link(s), making them immovable (uses `frozen` joint)"
|
529
800
|
if isinstance(name, list):
|
530
801
|
sys = self
|
531
802
|
for n in name:
|
@@ -544,6 +815,7 @@ class System(_Base):
|
|
544
815
|
return _update_sys_if_replace_joint_type(self, logic_freeze)
|
545
816
|
|
546
817
|
def unfreeze(self, name: str, new_joint_type: str):
|
818
|
+
"Unfreezes a frozen link and assigns it a new joint type."
|
547
819
|
assert self.link_types[self.name_to_idx(name)] == "frozen"
|
548
820
|
assert new_joint_type != "frozen"
|
549
821
|
|
@@ -560,7 +832,8 @@ class System(_Base):
|
|
560
832
|
seed: int = 1,
|
561
833
|
warn: bool = True,
|
562
834
|
):
|
563
|
-
"
|
835
|
+
"""Changes the joint type of a specified link.
|
836
|
+
By default damping, stiffness are set to zero."""
|
564
837
|
from ring.algorithms import get_joint_model
|
565
838
|
|
566
839
|
q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
|
@@ -594,6 +867,7 @@ class System(_Base):
|
|
594
867
|
|
595
868
|
@staticmethod
|
596
869
|
def joint_type_simplification(typ: str) -> str:
|
870
|
+
"Returns a simplified name of the given joint type."
|
597
871
|
if typ[:4] == "free":
|
598
872
|
if typ == "free_2d":
|
599
873
|
return "free_2d"
|
@@ -608,23 +882,28 @@ class System(_Base):
|
|
608
882
|
|
609
883
|
@staticmethod
|
610
884
|
def joint_type_is_free_or_cor(typ: str) -> bool:
|
885
|
+
'Checks if a joint type is either "free" or "cor".'
|
611
886
|
return System.joint_type_simplification(typ) in ["free", "cor"]
|
612
887
|
|
613
888
|
@staticmethod
|
614
889
|
def joint_type_is_spherical(typ: str) -> bool:
|
890
|
+
'Checks if a joint type is "spherical".'
|
615
891
|
return System.joint_type_simplification(typ) == "spherical"
|
616
892
|
|
617
893
|
@staticmethod
|
618
894
|
def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
|
895
|
+
'Checks if a joint type is "free", "cor", or "spherical".'
|
619
896
|
return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
|
620
897
|
typ
|
621
898
|
)
|
622
899
|
|
623
900
|
def findall_imus(self, names: bool = True) -> list[str] | list[int]:
|
901
|
+
"Finds all IMU sensors in the system."
|
624
902
|
bodies = [name for name in self.link_names if name[:3] == "imu"]
|
625
903
|
return bodies if names else [self.name_to_idx(n) for n in bodies]
|
626
904
|
|
627
905
|
def findall_segments(self, names: bool = True) -> list[str] | list[int]:
|
906
|
+
"Finds all non-IMU segments in the system."
|
628
907
|
imus = self.findall_imus(names=True)
|
629
908
|
bodies = [name for name in self.link_names if name not in imus]
|
630
909
|
return bodies if names else [self.name_to_idx(n) for n in bodies]
|
@@ -633,10 +912,12 @@ class System(_Base):
|
|
633
912
|
return [self.idx_to_name(i) for i in bodies]
|
634
913
|
|
635
914
|
def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
|
915
|
+
"Returns all bodies directly connected to the world."
|
636
916
|
bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
|
637
917
|
return self._bodies_indices_to_bodies_name(bodies) if names else bodies
|
638
918
|
|
639
919
|
def find_body_to_world(self, name: bool = False) -> int | str:
|
920
|
+
"Returns the root body connected to the world."
|
640
921
|
bodies = self.findall_bodies_to_world(names=name)
|
641
922
|
assert len(bodies) == 1
|
642
923
|
return bodies[0]
|
@@ -644,6 +925,7 @@ class System(_Base):
|
|
644
925
|
def findall_bodies_with_jointtype(
|
645
926
|
self, typ: str, names: bool = False
|
646
927
|
) -> list[int] | list[str]:
|
928
|
+
"Returns all bodies with the specified joint type."
|
647
929
|
bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
|
648
930
|
return self._bodies_indices_to_bodies_name(bodies) if names else bodies
|
649
931
|
|
@@ -781,20 +1063,25 @@ class System(_Base):
|
|
781
1063
|
|
782
1064
|
@staticmethod
|
783
1065
|
def from_xml(path: str, seed: int = 1):
|
1066
|
+
"Loads a system from an XML file."
|
784
1067
|
return ring.io.load_sys_from_xml(path, seed)
|
785
1068
|
|
786
1069
|
@staticmethod
|
787
1070
|
def from_str(xml: str, seed: int = 1):
|
1071
|
+
"Loads a system from an XML string."
|
788
1072
|
return ring.io.load_sys_from_str(xml, seed)
|
789
1073
|
|
790
1074
|
def to_str(self, warn: bool = True) -> str:
|
1075
|
+
"Serializes the system to an XML string."
|
791
1076
|
return ring.io.save_sys_to_str(self, warn=warn)
|
792
1077
|
|
793
1078
|
def to_xml(self, path: str) -> None:
|
1079
|
+
"Saves the system to an XML file."
|
794
1080
|
ring.io.save_sys_to_xml(self, path)
|
795
1081
|
|
796
1082
|
@classmethod
|
797
1083
|
def create(cls, path_or_str: str, seed: int = 1) -> "System":
|
1084
|
+
"Creates a `System` instance from an XML file or string."
|
798
1085
|
path = Path(path_or_str).with_suffix(".xml")
|
799
1086
|
|
800
1087
|
exists = False
|
@@ -807,14 +1094,15 @@ class System(_Base):
|
|
807
1094
|
if exists:
|
808
1095
|
return cls.from_xml(path, seed=seed)
|
809
1096
|
else:
|
810
|
-
return cls.from_str(path_or_str)
|
1097
|
+
return cls.from_str(path_or_str, seed=seed)
|
811
1098
|
|
812
1099
|
def coordinate_vector_to_q(
|
813
1100
|
self,
|
814
1101
|
q: jax.Array,
|
815
1102
|
custom_joints: dict[str, Callable] = {},
|
816
1103
|
) -> jax.Array:
|
817
|
-
"""
|
1104
|
+
"""Converts a coordinate vector to minimal coordinates (`q`), applying
|
1105
|
+
constraints such as quaternion normalization."""
|
818
1106
|
# Does, e.g.
|
819
1107
|
# - normalize quaternions
|
820
1108
|
# - hinge joints in [-pi, pi]
|
@@ -1026,14 +1314,37 @@ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = Fa
|
|
1026
1314
|
|
1027
1315
|
@struct.dataclass
|
1028
1316
|
class State(_Base):
|
1029
|
-
"""The static and dynamic state of a system in minimal and maximal coordinates.
|
1030
|
-
Use `.create()` to create this object.
|
1031
|
-
|
1032
|
-
Args:
|
1033
|
-
q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
|
1034
|
-
qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
|
1035
|
-
x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
|
1036
1317
|
"""
|
1318
|
+
Represents the state of a dynamic system in minimal and maximal coordinates.
|
1319
|
+
|
1320
|
+
The `State` class encapsulates both the configuration (`q`) and velocity (`qd`)
|
1321
|
+
of the system in minimal coordinates, as well as the corresponding transforms (`x`)
|
1322
|
+
in maximal coordinates.
|
1323
|
+
|
1324
|
+
Attributes:
|
1325
|
+
q (jax.Array):
|
1326
|
+
The joint positions (generalized coordinates) of the system. The size
|
1327
|
+
of `q` matches `sys.q_size()`.
|
1328
|
+
qd (jax.Array):
|
1329
|
+
The joint velocities (generalized velocities) of the system. The size
|
1330
|
+
of `qd` matches `sys.qd_size()`.
|
1331
|
+
x (Transform):
|
1332
|
+
The maximal coordinate representation of all system links, expressed as
|
1333
|
+
a `Transform` object.
|
1334
|
+
|
1335
|
+
Methods:
|
1336
|
+
create(sys: System, q: Optional[jax.Array] = None,
|
1337
|
+
qd: Optional[jax.Array] = None, x: Optional[Transform] = None,
|
1338
|
+
key: Optional[jax.Array] = None,
|
1339
|
+
custom_joints: dict[str, Callable] = {}) -> State:
|
1340
|
+
Creates a `State` instance for a given system with optional initial conditions.
|
1341
|
+
|
1342
|
+
Usage:
|
1343
|
+
>>> sys = System.create("model.xml")
|
1344
|
+
>>> state = State.create(sys)
|
1345
|
+
>>> print(state.q.shape) # Should match sys.q_size()
|
1346
|
+
>>> print(state.qd.shape) # Should match sys.qd_size()
|
1347
|
+
""" # noqa: E501
|
1037
1348
|
|
1038
1349
|
q: jax.Array
|
1039
1350
|
qd: jax.Array
|
@@ -1048,19 +1359,37 @@ class State(_Base):
|
|
1048
1359
|
x: Optional[Transform] = None,
|
1049
1360
|
key: Optional[jax.Array] = None,
|
1050
1361
|
custom_joints: dict[str, Callable] = {},
|
1051
|
-
):
|
1052
|
-
"""
|
1362
|
+
) -> "State":
|
1363
|
+
"""
|
1364
|
+
Creates a `State` instance for the given system with optional initial conditions.
|
1365
|
+
|
1366
|
+
If no initial values are provided, joint positions (`q`) and velocities (`qd`)
|
1367
|
+
are initialized to zero, except for free and spherical joints, which have unit quaternions.
|
1053
1368
|
|
1054
1369
|
Args:
|
1055
|
-
sys (System):
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1370
|
+
sys (System):
|
1371
|
+
The system for which to create a state.
|
1372
|
+
q (Optional[jax.Array], default=None):
|
1373
|
+
Initial joint positions. If `None`, defaults to zeros, with unit quaternion initialization
|
1374
|
+
for free and spherical joints.
|
1375
|
+
qd (Optional[jax.Array], default=None):
|
1376
|
+
Initial joint velocities. If `None`, defaults to zeros.
|
1377
|
+
x (Optional[Transform], default=None):
|
1378
|
+
Initial maximal coordinates of the system links. If `None`, defaults to zero transforms.
|
1379
|
+
key (Optional[jax.Array], default=None):
|
1380
|
+
Random key for initializing `q` if no values are provided.
|
1381
|
+
custom_joints (dict[str, Callable], default={}):
|
1382
|
+
Custom joint functions for mapping coordinate vectors to minimal coordinates.
|
1060
1383
|
|
1061
1384
|
Returns:
|
1062
|
-
|
1063
|
-
|
1385
|
+
State: A new instance of the `State` class representing the initialized system state.
|
1386
|
+
|
1387
|
+
Example:
|
1388
|
+
>>> sys = System.create("model.xml")
|
1389
|
+
>>> state = State.create(sys)
|
1390
|
+
>>> print(state.q.shape) # Should match sys.q_size()
|
1391
|
+
>>> print(state.qd.shape) # Should match sys.qd_size()
|
1392
|
+
""" # noqa: E501
|
1064
1393
|
if key is not None:
|
1065
1394
|
assert q is None
|
1066
1395
|
q = jax.random.normal(key, shape=(sys.q_size(),))
|
ring/io/xml/from_xml.py
CHANGED
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
|
|
252
252
|
|
253
253
|
# numpy -> jax
|
254
254
|
# we load using numpy in order to have float64 precision
|
255
|
-
sys = jax.
|
255
|
+
sys = jax.tree.map(jax.numpy.asarray, sys)
|
256
256
|
|
257
257
|
sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
|
258
258
|
|
ring/ml/base.py
CHANGED
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
|
|
13
13
|
def _to_3d(tree):
|
14
14
|
if tree is None:
|
15
15
|
return None
|
16
|
-
return jax.
|
16
|
+
return jax.tree.map(lambda arr: arr[None], tree)
|
17
17
|
|
18
18
|
|
19
19
|
def _to_2d(tree, i: int = 0):
|
20
20
|
if tree is None:
|
21
21
|
return None
|
22
|
-
return jax.
|
22
|
+
return jax.tree.map(lambda arr: arr[i], tree)
|
23
23
|
|
24
24
|
|
25
25
|
class AbstractFilter(ABC):
|
@@ -297,7 +297,8 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
|
|
297
297
|
|
298
298
|
if self._quat_normalize:
|
299
299
|
assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
|
300
|
-
yhat = ring.maths.safe_normalize(yhat)
|
300
|
+
# yhat = ring.maths.safe_normalize(yhat)
|
301
|
+
yhat = yhat / jnp.linalg.norm(yhat, axis=-1, keepdims=True)
|
301
302
|
|
302
303
|
return yhat, state
|
303
304
|
|
ring/ml/ml_utils.py
CHANGED
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
|
|
161
161
|
metrices: NestedDict, filter_nan_inf: bool = True
|
162
162
|
):
|
163
163
|
metrices = _flatten_dict(metrices)
|
164
|
-
metrices = jax.
|
164
|
+
metrices = jax.tree.map(_to_float_if_not_string, metrices)
|
165
165
|
|
166
166
|
if not filter_nan_inf:
|
167
167
|
return metrices
|
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
216
216
|
from jax.experimental import jax2tf
|
217
217
|
import tensorflow as tf
|
218
218
|
|
219
|
-
signature = jax.
|
219
|
+
signature = jax.tree.map(
|
220
220
|
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
221
221
|
)
|
222
222
|
|
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
241
241
|
if validate:
|
242
242
|
output_jax = jax_func(*input)
|
243
243
|
output_tf = tf.saved_model.load(path)(*input)
|
244
|
-
jax.
|
244
|
+
jax.tree.map(
|
245
245
|
lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
|
246
246
|
output_jax,
|
247
247
|
output_tf,
|
ring/ml/ringnet.py
CHANGED
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
|
|
248
248
|
params, state = self.forward_lam_factory(lam=lam).init(key, X)
|
249
249
|
|
250
250
|
if bs is not None:
|
251
|
-
state = jax.
|
251
|
+
state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
|
252
252
|
|
253
253
|
return params, state
|
254
254
|
|
ring/ml/train.py
CHANGED
@@ -50,7 +50,7 @@ def _build_step_fn(
|
|
50
50
|
# this vmap maps along batch-axis, not time-axis
|
51
51
|
# time-axis is handled by `metric_fn`
|
52
52
|
pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
|
53
|
-
error_tree = jax.
|
53
|
+
error_tree = jax.tree.map(pipe, y, yhat)
|
54
54
|
return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
|
55
55
|
|
56
56
|
@partial(
|
@@ -274,7 +274,7 @@ def _build_eval_fn(
|
|
274
274
|
), f"The metric identitifier {metric_name} is not unique"
|
275
275
|
|
276
276
|
pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
|
277
|
-
values.update({metric_name: jax.
|
277
|
+
values.update({metric_name: jax.tree.map(pipe, y, yhat)})
|
278
278
|
|
279
279
|
return values
|
280
280
|
|