imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/METADATA +1 -1
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/RECORD +27 -27
- ring/algorithms/custom_joints/suntay.py +1 -1
- ring/algorithms/dynamics.py +27 -1
- ring/algorithms/generator/base.py +82 -2
- ring/algorithms/generator/batch.py +2 -2
- ring/algorithms/generator/finalize_fns.py +1 -1
- ring/algorithms/generator/pd_control.py +1 -1
- ring/algorithms/jcalc.py +198 -0
- ring/algorithms/kinematics.py +2 -1
- ring/algorithms/sensors.py +12 -10
- ring/base.py +356 -27
- ring/io/xml/from_xml.py +1 -1
- ring/ml/base.py +4 -3
- ring/ml/ml_utils.py +3 -3
- ring/ml/ringnet.py +1 -1
- ring/ml/train.py +2 -2
- ring/rendering/mujoco_render.py +11 -7
- ring/rendering/vispy_render.py +5 -4
- ring/sys_composer/inject_sys.py +3 -2
- ring/utils/batchsize.py +3 -3
- ring/utils/dataloader.py +4 -3
- ring/utils/dataloader_torch.py +14 -5
- ring/utils/hdf5.py +1 -1
- ring/utils/normalizer.py +6 -5
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/WHEEL +0 -0
- {imt_ring-1.6.37.dist-info → imt_ring-1.6.39.dist-info}/top_level.txt +0 -0
ring/base.py
CHANGED
@@ -112,17 +112,71 @@ class _Base:
|
|
112
112
|
|
113
113
|
@struct.dataclass
|
114
114
|
class Transform(_Base):
|
115
|
-
"""Represents the Transformation from Plücker A to Plücker B,
|
116
|
-
where B is located relative to A at `pos` in frame A and `rot` is the
|
117
|
-
relative quaternion from A to B.
|
118
|
-
Create using `Transform.create(pos=..., rot=...)
|
119
115
|
"""
|
116
|
+
Represents a spatial transformation between two coordinate frames using Plücker coordinates.
|
117
|
+
|
118
|
+
The `Transform` class defines the relative position and orientation of one frame (`B`)
|
119
|
+
with respect to another frame (`A`). The position (`pos`) is given in the coordinate frame
|
120
|
+
of `A`, and the rotation (`rot`) is expressed as a unit quaternion representing the relative
|
121
|
+
rotation from frame `A` to frame `B`.
|
122
|
+
|
123
|
+
Attributes:
|
124
|
+
pos (Vector):
|
125
|
+
The translation vector (position of `B` relative to `A`) in the coordinate frame of `A`.
|
126
|
+
Shape: `(..., 3)`, where `...` represents optional batch dimensions.
|
127
|
+
rot (Quaternion):
|
128
|
+
The unit quaternion representing the orientation of `B` relative to `A`.
|
129
|
+
Shape: `(..., 4)`, where `...` represents optional batch dimensions.
|
130
|
+
|
131
|
+
Methods:
|
132
|
+
create(pos: Optional[Vector] = None, rot: Optional[Quaternion] = None) -> Transform:
|
133
|
+
Creates a `Transform` instance with optional position and rotation.
|
134
|
+
|
135
|
+
zero(shape: Sequence[int] = ()) -> Transform:
|
136
|
+
Returns a zero transform with a given batch shape.
|
137
|
+
|
138
|
+
as_matrix() -> jax.Array:
|
139
|
+
Returns the 4x4 homogeneous transformation matrix representation of this transform.
|
140
|
+
|
141
|
+
Usage:
|
142
|
+
>>> pos = jnp.array([1.0, 2.0, 3.0])
|
143
|
+
>>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
|
144
|
+
>>> T = Transform.create(pos, rot)
|
145
|
+
>>> print(T.pos) # Output: [1. 2. 3.]
|
146
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
147
|
+
>>> print(T.as_matrix()) # 4x4 transformation matrix
|
148
|
+
""" # noqa: E501
|
120
149
|
|
121
150
|
pos: Vector
|
122
151
|
rot: Quaternion
|
123
152
|
|
124
153
|
@classmethod
|
125
154
|
def create(cls, pos=None, rot=None):
|
155
|
+
"""
|
156
|
+
Creates a `Transform` instance with the specified position and rotation.
|
157
|
+
|
158
|
+
At least one of `pos` or `rot` must be provided. If only `pos` is given, the rotation
|
159
|
+
defaults to the identity quaternion `[1, 0, 0, 0]`. If only `rot` is given, the position
|
160
|
+
defaults to `[0, 0, 0]`.
|
161
|
+
|
162
|
+
Args:
|
163
|
+
pos (Optional[Vector], default=None):
|
164
|
+
The position of frame `B` relative to frame `A`, expressed in frame `A` coordinates.
|
165
|
+
If `None`, defaults to a zero vector of shape `(3,)`.
|
166
|
+
rot (Optional[Quaternion], default=None):
|
167
|
+
The unit quaternion representing the orientation of `B` relative to `A`.
|
168
|
+
If `None`, defaults to the identity quaternion `(1, 0, 0, 0)`.
|
169
|
+
|
170
|
+
Returns:
|
171
|
+
Transform: A new `Transform` instance with the specified position and rotation.
|
172
|
+
|
173
|
+
Example:
|
174
|
+
>>> pos = jnp.array([1.0, 2.0, 3.0])
|
175
|
+
>>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
|
176
|
+
>>> T = Transform.create(pos, rot)
|
177
|
+
>>> print(T.pos) # Output: [1. 2. 3.]
|
178
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
179
|
+
""" # noqa: E501
|
126
180
|
assert not (pos is None and rot is None), "One must be given."
|
127
181
|
shape_rot = rot.shape[:-1] if rot is not None else ()
|
128
182
|
shape_pos = pos.shape[:-1] if pos is not None else ()
|
@@ -139,12 +193,49 @@ class Transform(_Base):
|
|
139
193
|
|
140
194
|
@classmethod
|
141
195
|
def zero(cls, shape=()) -> "Transform":
|
142
|
-
"""
|
196
|
+
"""
|
197
|
+
Returns a zero transform with a given batch shape.
|
198
|
+
|
199
|
+
This creates a transform with position `(0, 0, 0)` and an identity quaternion `(1, 0, 0, 0)`,
|
200
|
+
which represents no translation or rotation.
|
201
|
+
|
202
|
+
Args:
|
203
|
+
shape (Sequence[int], default=()):
|
204
|
+
The batch shape for the transform. Defaults to a scalar transform.
|
205
|
+
|
206
|
+
Returns:
|
207
|
+
Transform: A zero transform with the specified batch shape.
|
208
|
+
|
209
|
+
Example:
|
210
|
+
>>> T = Transform.zero()
|
211
|
+
>>> print(T.pos) # Output: [0. 0. 0.]
|
212
|
+
>>> print(T.rot) # Output: [1. 0. 0. 0.]
|
213
|
+
""" # noqa: E501
|
143
214
|
pos = jnp.zeros(shape + (3,))
|
144
215
|
rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
|
145
216
|
return Transform(pos, rot)
|
146
217
|
|
147
218
|
def as_matrix(self) -> jax.Array:
|
219
|
+
"""
|
220
|
+
Returns the 4x4 homogeneous transformation matrix representation of this transform.
|
221
|
+
|
222
|
+
The homogeneous transformation matrix is defined as:
|
223
|
+
|
224
|
+
```
|
225
|
+
[ R t ]
|
226
|
+
[ 0 1 ]
|
227
|
+
```
|
228
|
+
|
229
|
+
where `R` is the 3x3 rotation matrix converted from the quaternion and `t` is the
|
230
|
+
3x1 position vector.
|
231
|
+
|
232
|
+
Returns:
|
233
|
+
jax.Array: A `(4, 4)` homogeneous transformation matrix.
|
234
|
+
|
235
|
+
Example:
|
236
|
+
>>> T = Transform.create(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 0.0, 0.0]))
|
237
|
+
>>> print(T.as_matrix()) # Output: 4x4 matrix
|
238
|
+
""" # noqa: E501
|
148
239
|
E = maths.quat_to_3x3(self.rot)
|
149
240
|
return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
|
150
241
|
|
@@ -402,7 +493,175 @@ QD_WIDTHS = {
|
|
402
493
|
|
403
494
|
@struct.dataclass
|
404
495
|
class System(_Base):
|
405
|
-
"
|
496
|
+
"""
|
497
|
+
Represents a robotic system consisting of interconnected links and joints. Create it using `System.create(...)`.
|
498
|
+
|
499
|
+
The `System` class models the kinematic and dynamic properties of a multibody
|
500
|
+
system, providing methods for state representation, transformations, joint
|
501
|
+
configuration management, and rendering. It supports both minimal and maximal
|
502
|
+
coordinate representations and can be parsed from or saved to XML files.
|
503
|
+
|
504
|
+
Attributes:
|
505
|
+
link_parents (list[int]):
|
506
|
+
A list specifying the parent index for each link. The root link has a parent index of `-1`.
|
507
|
+
links (Link):
|
508
|
+
A data structure containing information about all links in the system.
|
509
|
+
link_types (list[str]):
|
510
|
+
A list specifying the joint type for each link (e.g., "free", "hinge", "prismatic").
|
511
|
+
link_damping (jax.Array):
|
512
|
+
Joint damping coefficients for each link.
|
513
|
+
link_armature (jax.Array):
|
514
|
+
Armature inertia values for each joint.
|
515
|
+
link_spring_stiffness (jax.Array):
|
516
|
+
Stiffness values for joint springs.
|
517
|
+
link_spring_zeropoint (jax.Array):
|
518
|
+
Rest position of joint springs.
|
519
|
+
dt (float):
|
520
|
+
Simulation time step size.
|
521
|
+
geoms (list[Geometry]):
|
522
|
+
List of geometries associated with the system.
|
523
|
+
gravity (jax.Array):
|
524
|
+
Gravity vector applied to the system (default: `[0, 0, -9.81]`).
|
525
|
+
integration_method (str):
|
526
|
+
Integration method for simulation (default: "semi_implicit_euler").
|
527
|
+
mass_mat_iters (int):
|
528
|
+
Number of iterations for mass matrix calculations.
|
529
|
+
link_names (list[str]):
|
530
|
+
Names of the links in the system.
|
531
|
+
model_name (Optional[str]):
|
532
|
+
Name of the system model (if available).
|
533
|
+
omc (list[MaxCoordOMC | None]):
|
534
|
+
List of optional Maximal Coordinate representations.
|
535
|
+
|
536
|
+
Methods:
|
537
|
+
num_links() -> int:
|
538
|
+
Returns the number of links in the system.
|
539
|
+
|
540
|
+
q_size() -> int:
|
541
|
+
Returns the total number of generalized coordinates (`q`) in the system.
|
542
|
+
|
543
|
+
qd_size() -> int:
|
544
|
+
Returns the total number of generalized velocities (`qd`) in the system.
|
545
|
+
|
546
|
+
name_to_idx(name: str) -> int:
|
547
|
+
Returns the index of a link given its name.
|
548
|
+
|
549
|
+
idx_to_name(idx: int, allow_world: bool = False) -> str:
|
550
|
+
Returns the name of a link given its index. If `allow_world` is `True`,
|
551
|
+
returns `"world"` for index `-1`.
|
552
|
+
|
553
|
+
idx_map(type: str) -> dict:
|
554
|
+
Returns a dictionary mapping link names to their indices for a specified type
|
555
|
+
(`"l"`, `"q"`, or `"d"`).
|
556
|
+
|
557
|
+
parent_name(name: str) -> str:
|
558
|
+
Returns the name of the parent link for a given link.
|
559
|
+
|
560
|
+
change_model_name(new_name: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
|
561
|
+
Changes the name of the system model.
|
562
|
+
|
563
|
+
change_link_name(old_name: str, new_name: str) -> "System":
|
564
|
+
Renames a specific link.
|
565
|
+
|
566
|
+
add_prefix_suffix(prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
|
567
|
+
Adds a prefix, a suffix, or both to all link names.
|
568
|
+
|
569
|
+
freeze(name: str | list[str]) -> "System":
|
570
|
+
Freezes the specified link(s), making them immovable.
|
571
|
+
|
572
|
+
unfreeze(name: str, new_joint_type: str) -> "System":
|
573
|
+
Unfreezes a frozen link and assigns it a new joint type.
|
574
|
+
|
575
|
+
change_joint_type(name: str, new_joint_type: str, **kwargs) -> "System":
|
576
|
+
Changes the joint type of a specified link.
|
577
|
+
|
578
|
+
joint_type_simplification(typ: str) -> str:
|
579
|
+
Returns a simplified representation of the given joint type.
|
580
|
+
|
581
|
+
joint_type_is_free_or_cor(typ: str) -> bool:
|
582
|
+
Checks if a joint type is either "free" or "cor".
|
583
|
+
|
584
|
+
joint_type_is_spherical(typ: str) -> bool:
|
585
|
+
Checks if a joint type is "spherical".
|
586
|
+
|
587
|
+
joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
|
588
|
+
Checks if a joint type is "free", "cor", or "spherical".
|
589
|
+
|
590
|
+
findall_imus(names: bool = True) -> list[str] | list[int]:
|
591
|
+
Finds all IMU sensors in the system.
|
592
|
+
|
593
|
+
findall_segments(names: bool = True) -> list[str] | list[int]:
|
594
|
+
Finds all non-IMU segments in the system.
|
595
|
+
|
596
|
+
findall_bodies_to_world(names: bool = False) -> list[int] | list[str]:
|
597
|
+
Returns all bodies directly connected to the world.
|
598
|
+
|
599
|
+
find_body_to_world(name: bool = False) -> int | str:
|
600
|
+
Returns the root body connected to the world.
|
601
|
+
|
602
|
+
findall_bodies_with_jointtype(typ: str, names: bool = False) -> list[int] | list[str]:
|
603
|
+
Returns all bodies with the specified joint type.
|
604
|
+
|
605
|
+
children(name: str, names: bool = False) -> list[int] | list[str]:
|
606
|
+
Returns the direct children of a given body.
|
607
|
+
|
608
|
+
findall_bodies_subsystem(name: str, names: bool = False) -> list[int] | list[str]:
|
609
|
+
Finds all bodies in the subsystem rooted at a given link.
|
610
|
+
|
611
|
+
scan(f: Callable, in_types: str, *args, reverse: bool = False):
|
612
|
+
Iterates over system elements while applying a function.
|
613
|
+
|
614
|
+
parse() -> "System":
|
615
|
+
Parses the system, performing consistency checks and computing spatial inertia tensors.
|
616
|
+
|
617
|
+
render(xs: Optional[Transform | list[Transform]] = None, **kwargs) -> list[np.ndarray]:
|
618
|
+
Renders frames of the system using maximal coordinates.
|
619
|
+
|
620
|
+
render_prediction(xs: Transform | list[Transform], yhat: dict | jax.Array | np.ndarray, **kwargs):
|
621
|
+
Renders a predicted state transformation.
|
622
|
+
|
623
|
+
delete_system(link_name: str | list[str], strict: bool = True):
|
624
|
+
Removes a subsystem from the system.
|
625
|
+
|
626
|
+
make_sys_noimu(imu_link_names: Optional[list[str]] = None):
|
627
|
+
Returns a version of the system without IMU sensors.
|
628
|
+
|
629
|
+
inject_system(other_system: "System", at_body: Optional[str] = None):
|
630
|
+
Merges another system into this one.
|
631
|
+
|
632
|
+
morph_system(new_parents: Optional[list[int | str]] = None, new_anchor: Optional[int | str] = None):
|
633
|
+
Reorders the system’s link hierarchy.
|
634
|
+
|
635
|
+
from_xml(path: str, seed: int = 1) -> "System":
|
636
|
+
Loads a system from an XML file.
|
637
|
+
|
638
|
+
from_str(xml: str, seed: int = 1) -> "System":
|
639
|
+
Loads a system from an XML string.
|
640
|
+
|
641
|
+
to_str(warn: bool = True) -> str:
|
642
|
+
Serializes the system to an XML string.
|
643
|
+
|
644
|
+
to_xml(path: str) -> None:
|
645
|
+
Saves the system as an XML file.
|
646
|
+
|
647
|
+
create(path_or_str: str, seed: int = 1) -> "System":
|
648
|
+
Creates a `System` instance from an XML file or string.
|
649
|
+
|
650
|
+
coordinate_vector_to_q(q: jax.Array, custom_joints: dict[str, Callable] = {}) -> jax.Array:
|
651
|
+
Converts a coordinate vector to minimal coordinates (`q`), applying
|
652
|
+
constraints such as quaternion normalization.
|
653
|
+
|
654
|
+
Raises:
|
655
|
+
AssertionError: If the system structure is invalid (e.g., duplicate link names, incorrect parent-child relationships).
|
656
|
+
InvalidSystemError: If an operation results in an inconsistent system state.
|
657
|
+
|
658
|
+
Notes:
|
659
|
+
- The system must be parsed before use to ensure consistency.
|
660
|
+
- The system supports batch operations using JAX for efficient computations.
|
661
|
+
- Joint types include revolute ("rx", "ry", "rz"), prismatic ("px", "py", "pz"), spherical, free, and more.
|
662
|
+
- Inertial properties of links are computed automatically from associated geometries.
|
663
|
+
""" # noqa: E501
|
664
|
+
|
406
665
|
link_parents: list[int] = struct.field(False)
|
407
666
|
links: Link
|
408
667
|
link_types: list[str] = struct.field(False)
|
@@ -429,25 +688,32 @@ class System(_Base):
|
|
429
688
|
omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
|
430
689
|
|
431
690
|
def num_links(self) -> int:
|
691
|
+
"Returns the number of links in the system."
|
432
692
|
return len(self.link_parents)
|
433
693
|
|
434
694
|
def q_size(self) -> int:
|
695
|
+
"Returns the total number of generalized coordinates (`q`) in the system."
|
435
696
|
return sum([Q_WIDTHS[typ] for typ in self.link_types])
|
436
697
|
|
437
698
|
def qd_size(self) -> int:
|
699
|
+
"Returns the total number of generalized velocities (`qd`) in the system."
|
438
700
|
return sum([QD_WIDTHS[typ] for typ in self.link_types])
|
439
701
|
|
440
702
|
def name_to_idx(self, name: str) -> int:
|
703
|
+
"Returns the index of a link given its name."
|
441
704
|
return self.link_names.index(name)
|
442
705
|
|
443
706
|
def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
|
707
|
+
"""Returns the name of a link given its index. If `allow_world` is `True`,
|
708
|
+
returns `"world"` for index `-1`."""
|
444
709
|
if allow_world and idx == -1:
|
445
710
|
return "world"
|
446
711
|
assert idx >= 0, "Worldbody index has no name."
|
447
712
|
return self.link_names[idx]
|
448
713
|
|
449
714
|
def idx_map(self, type: str) -> dict:
|
450
|
-
"
|
715
|
+
"""Returns a dictionary mapping link names to their indices for a specified type
|
716
|
+
(`"l"`, `"q"`, or `"d"`)."""
|
451
717
|
dict_int_slices = {}
|
452
718
|
|
453
719
|
def f(_, idx_map, name: str, link_idx: int):
|
@@ -458,10 +724,11 @@ class System(_Base):
|
|
458
724
|
return dict_int_slices
|
459
725
|
|
460
726
|
def parent_name(self, name: str) -> str:
|
727
|
+
"Returns the name of the parent link for a given link."
|
461
728
|
return self.idx_to_name(self.link_parents[self.name_to_idx(name)])
|
462
729
|
|
463
730
|
def add_prefix(self, prefix: str = "") -> "System":
|
464
|
-
return self.
|
731
|
+
return self.add_prefix_suffix(prefix=prefix)
|
465
732
|
|
466
733
|
def change_model_name(
|
467
734
|
self,
|
@@ -469,6 +736,7 @@ class System(_Base):
|
|
469
736
|
prefix: Optional[str] = None,
|
470
737
|
suffix: Optional[str] = None,
|
471
738
|
) -> "System":
|
739
|
+
"Changes the name of the system model."
|
472
740
|
if prefix is None:
|
473
741
|
prefix = ""
|
474
742
|
if suffix is None:
|
@@ -479,6 +747,7 @@ class System(_Base):
|
|
479
747
|
return self.replace(model_name=name)
|
480
748
|
|
481
749
|
def change_link_name(self, old_name: str, new_name: str) -> "System":
|
750
|
+
"Renames a specific link."
|
482
751
|
old_idx = self.name_to_idx(old_name)
|
483
752
|
new_link_names = self.link_names.copy()
|
484
753
|
new_link_names[old_idx] = new_name
|
@@ -487,6 +756,7 @@ class System(_Base):
|
|
487
756
|
def add_prefix_suffix(
|
488
757
|
self, prefix: Optional[str] = None, suffix: Optional[str] = None
|
489
758
|
) -> "System":
|
759
|
+
"Adds a prefix, a suffix, or both to all link names."
|
490
760
|
if prefix is None:
|
491
761
|
prefix = ""
|
492
762
|
if suffix is None:
|
@@ -526,6 +796,7 @@ class System(_Base):
|
|
526
796
|
return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
|
527
797
|
|
528
798
|
def freeze(self, name: str | list[str]):
|
799
|
+
"Freezes the specified link(s), making them immovable (uses the `frozen` joint type)."
|
529
800
|
if isinstance(name, list):
|
530
801
|
sys = self
|
531
802
|
for n in name:
|
@@ -544,6 +815,7 @@ class System(_Base):
|
|
544
815
|
return _update_sys_if_replace_joint_type(self, logic_freeze)
|
545
816
|
|
546
817
|
def unfreeze(self, name: str, new_joint_type: str):
|
818
|
+
"Unfreezes a frozen link and assigns it a new joint type."
|
547
819
|
assert self.link_types[self.name_to_idx(name)] == "frozen"
|
548
820
|
assert new_joint_type != "frozen"
|
549
821
|
|
@@ -560,7 +832,8 @@ class System(_Base):
|
|
560
832
|
seed: int = 1,
|
561
833
|
warn: bool = True,
|
562
834
|
):
|
563
|
-
"
|
835
|
+
"""Changes the joint type of a specified link.
|
836
|
+
By default, damping and stiffness are set to zero."""
|
564
837
|
from ring.algorithms import get_joint_model
|
565
838
|
|
566
839
|
q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
|
@@ -594,6 +867,7 @@ class System(_Base):
|
|
594
867
|
|
595
868
|
@staticmethod
|
596
869
|
def joint_type_simplification(typ: str) -> str:
|
870
|
+
"Returns a simplified name of the given joint type."
|
597
871
|
if typ[:4] == "free":
|
598
872
|
if typ == "free_2d":
|
599
873
|
return "free_2d"
|
@@ -608,23 +882,28 @@ class System(_Base):
|
|
608
882
|
|
609
883
|
@staticmethod
|
610
884
|
def joint_type_is_free_or_cor(typ: str) -> bool:
|
885
|
+
'Checks if a joint type is either "free" or "cor".'
|
611
886
|
return System.joint_type_simplification(typ) in ["free", "cor"]
|
612
887
|
|
613
888
|
@staticmethod
|
614
889
|
def joint_type_is_spherical(typ: str) -> bool:
|
890
|
+
'Checks if a joint type is "spherical".'
|
615
891
|
return System.joint_type_simplification(typ) == "spherical"
|
616
892
|
|
617
893
|
@staticmethod
|
618
894
|
def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
|
895
|
+
'Checks if a joint type is "free", "cor", or "spherical".'
|
619
896
|
return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
|
620
897
|
typ
|
621
898
|
)
|
622
899
|
|
623
900
|
def findall_imus(self, names: bool = True) -> list[str] | list[int]:
|
901
|
+
"Finds all IMU sensors in the system."
|
624
902
|
bodies = [name for name in self.link_names if name[:3] == "imu"]
|
625
903
|
return bodies if names else [self.name_to_idx(n) for n in bodies]
|
626
904
|
|
627
905
|
def findall_segments(self, names: bool = True) -> list[str] | list[int]:
|
906
|
+
"Finds all non-IMU segments in the system."
|
628
907
|
imus = self.findall_imus(names=True)
|
629
908
|
bodies = [name for name in self.link_names if name not in imus]
|
630
909
|
return bodies if names else [self.name_to_idx(n) for n in bodies]
|
@@ -633,10 +912,12 @@ class System(_Base):
|
|
633
912
|
return [self.idx_to_name(i) for i in bodies]
|
634
913
|
|
635
914
|
def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
|
915
|
+
"Returns all bodies directly connected to the world."
|
636
916
|
bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
|
637
917
|
return self._bodies_indices_to_bodies_name(bodies) if names else bodies
|
638
918
|
|
639
919
|
def find_body_to_world(self, name: bool = False) -> int | str:
|
920
|
+
"Returns the root body connected to the world."
|
640
921
|
bodies = self.findall_bodies_to_world(names=name)
|
641
922
|
assert len(bodies) == 1
|
642
923
|
return bodies[0]
|
@@ -644,6 +925,7 @@ class System(_Base):
|
|
644
925
|
def findall_bodies_with_jointtype(
|
645
926
|
self, typ: str, names: bool = False
|
646
927
|
) -> list[int] | list[str]:
|
928
|
+
"Returns all bodies with the specified joint type."
|
647
929
|
bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
|
648
930
|
return self._bodies_indices_to_bodies_name(bodies) if names else bodies
|
649
931
|
|
@@ -781,20 +1063,25 @@ class System(_Base):
|
|
781
1063
|
|
782
1064
|
@staticmethod
|
783
1065
|
def from_xml(path: str, seed: int = 1):
|
1066
|
+
"Loads a system from an XML file."
|
784
1067
|
return ring.io.load_sys_from_xml(path, seed)
|
785
1068
|
|
786
1069
|
@staticmethod
|
787
1070
|
def from_str(xml: str, seed: int = 1):
|
1071
|
+
"Loads a system from an XML string."
|
788
1072
|
return ring.io.load_sys_from_str(xml, seed)
|
789
1073
|
|
790
1074
|
def to_str(self, warn: bool = True) -> str:
|
1075
|
+
"Serializes the system to an XML string."
|
791
1076
|
return ring.io.save_sys_to_str(self, warn=warn)
|
792
1077
|
|
793
1078
|
def to_xml(self, path: str) -> None:
|
1079
|
+
"Saves the system to an XML file."
|
794
1080
|
ring.io.save_sys_to_xml(self, path)
|
795
1081
|
|
796
1082
|
@classmethod
|
797
1083
|
def create(cls, path_or_str: str, seed: int = 1) -> "System":
|
1084
|
+
"Creates a `System` instance from an XML file or string."
|
798
1085
|
path = Path(path_or_str).with_suffix(".xml")
|
799
1086
|
|
800
1087
|
exists = False
|
@@ -807,14 +1094,15 @@ class System(_Base):
|
|
807
1094
|
if exists:
|
808
1095
|
return cls.from_xml(path, seed=seed)
|
809
1096
|
else:
|
810
|
-
return cls.from_str(path_or_str)
|
1097
|
+
return cls.from_str(path_or_str, seed=seed)
|
811
1098
|
|
812
1099
|
def coordinate_vector_to_q(
|
813
1100
|
self,
|
814
1101
|
q: jax.Array,
|
815
1102
|
custom_joints: dict[str, Callable] = {},
|
816
1103
|
) -> jax.Array:
|
817
|
-
"""
|
1104
|
+
"""Converts a coordinate vector to minimal coordinates (`q`), applying
|
1105
|
+
constraints such as quaternion normalization."""
|
818
1106
|
# Does, e.g.
|
819
1107
|
# - normalize quaternions
|
820
1108
|
# - hinge joints in [-pi, pi]
|
@@ -1026,14 +1314,37 @@ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = Fa
|
|
1026
1314
|
|
1027
1315
|
@struct.dataclass
|
1028
1316
|
class State(_Base):
|
1029
|
-
"""The static and dynamic state of a system in minimal and maximal coordinates.
|
1030
|
-
Use `.create()` to create this object.
|
1031
|
-
|
1032
|
-
Args:
|
1033
|
-
q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
|
1034
|
-
qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
|
1035
|
-
x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
|
1036
1317
|
"""
|
1318
|
+
Represents the state of a dynamic system in minimal and maximal coordinates.
|
1319
|
+
|
1320
|
+
The `State` class encapsulates both the configuration (`q`) and velocity (`qd`)
|
1321
|
+
of the system in minimal coordinates, as well as the corresponding transforms (`x`)
|
1322
|
+
in maximal coordinates.
|
1323
|
+
|
1324
|
+
Attributes:
|
1325
|
+
q (jax.Array):
|
1326
|
+
The joint positions (generalized coordinates) of the system. The size
|
1327
|
+
of `q` matches `sys.q_size()`.
|
1328
|
+
qd (jax.Array):
|
1329
|
+
The joint velocities (generalized velocities) of the system. The size
|
1330
|
+
of `qd` matches `sys.qd_size()`.
|
1331
|
+
x (Transform):
|
1332
|
+
The maximal coordinate representation of all system links, expressed as
|
1333
|
+
a `Transform` object.
|
1334
|
+
|
1335
|
+
Methods:
|
1336
|
+
create(sys: System, q: Optional[jax.Array] = None,
|
1337
|
+
qd: Optional[jax.Array] = None, x: Optional[Transform] = None,
|
1338
|
+
key: Optional[jax.Array] = None,
|
1339
|
+
custom_joints: dict[str, Callable] = {}) -> State:
|
1340
|
+
Creates a `State` instance for a given system with optional initial conditions.
|
1341
|
+
|
1342
|
+
Usage:
|
1343
|
+
>>> sys = System.create("model.xml")
|
1344
|
+
>>> state = State.create(sys)
|
1345
|
+
>>> print(state.q.shape) # Should match sys.q_size()
|
1346
|
+
>>> print(state.qd.shape) # Should match sys.qd_size()
|
1347
|
+
""" # noqa: E501
|
1037
1348
|
|
1038
1349
|
q: jax.Array
|
1039
1350
|
qd: jax.Array
|
@@ -1048,19 +1359,37 @@ class State(_Base):
|
|
1048
1359
|
x: Optional[Transform] = None,
|
1049
1360
|
key: Optional[jax.Array] = None,
|
1050
1361
|
custom_joints: dict[str, Callable] = {},
|
1051
|
-
):
|
1052
|
-
"""
|
1362
|
+
) -> "State":
|
1363
|
+
"""
|
1364
|
+
Creates a `State` instance for the given system with optional initial conditions.
|
1365
|
+
|
1366
|
+
If no initial values are provided, joint positions (`q`) and velocities (`qd`)
|
1367
|
+
are initialized to zero, except for free and spherical joints, which have unit quaternions.
|
1053
1368
|
|
1054
1369
|
Args:
|
1055
|
-
sys (System):
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1370
|
+
sys (System):
|
1371
|
+
The system for which to create a state.
|
1372
|
+
q (Optional[jax.Array], default=None):
|
1373
|
+
Initial joint positions. If `None`, defaults to zeros, with unit quaternion initialization
|
1374
|
+
for free and spherical joints.
|
1375
|
+
qd (Optional[jax.Array], default=None):
|
1376
|
+
Initial joint velocities. If `None`, defaults to zeros.
|
1377
|
+
x (Optional[Transform], default=None):
|
1378
|
+
Initial maximal coordinates of the system links. If `None`, defaults to zero transforms.
|
1379
|
+
key (Optional[jax.Array], default=None):
|
1380
|
+
Random key for initializing `q` if no values are provided.
|
1381
|
+
custom_joints (dict[str, Callable], default={}):
|
1382
|
+
Custom joint functions for mapping coordinate vectors to minimal coordinates.
|
1060
1383
|
|
1061
1384
|
Returns:
|
1062
|
-
|
1063
|
-
|
1385
|
+
State: A new instance of the `State` class representing the initialized system state.
|
1386
|
+
|
1387
|
+
Example:
|
1388
|
+
>>> sys = System.create("model.xml")
|
1389
|
+
>>> state = State.create(sys)
|
1390
|
+
>>> print(state.q.shape) # Should match sys.q_size()
|
1391
|
+
>>> print(state.qd.shape) # Should match sys.qd_size()
|
1392
|
+
""" # noqa: E501
|
1064
1393
|
if key is not None:
|
1065
1394
|
assert q is None
|
1066
1395
|
q = jax.random.normal(key, shape=(sys.q_size(),))
|
ring/io/xml/from_xml.py
CHANGED
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
|
|
252
252
|
|
253
253
|
# numpy -> jax
|
254
254
|
# we load using numpy in order to have float64 precision
|
255
|
-
sys = jax.
|
255
|
+
sys = jax.tree.map(jax.numpy.asarray, sys)
|
256
256
|
|
257
257
|
sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
|
258
258
|
|
ring/ml/base.py
CHANGED
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
|
|
13
13
|
def _to_3d(tree):
|
14
14
|
if tree is None:
|
15
15
|
return None
|
16
|
-
return jax.
|
16
|
+
return jax.tree.map(lambda arr: arr[None], tree)
|
17
17
|
|
18
18
|
|
19
19
|
def _to_2d(tree, i: int = 0):
|
20
20
|
if tree is None:
|
21
21
|
return None
|
22
|
-
return jax.
|
22
|
+
return jax.tree.map(lambda arr: arr[i], tree)
|
23
23
|
|
24
24
|
|
25
25
|
class AbstractFilter(ABC):
|
@@ -297,7 +297,8 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
|
|
297
297
|
|
298
298
|
if self._quat_normalize:
|
299
299
|
assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
|
300
|
-
yhat = ring.maths.safe_normalize(yhat)
|
300
|
+
# yhat = ring.maths.safe_normalize(yhat)
|
301
|
+
yhat = yhat / jnp.linalg.norm(yhat, axis=-1, keepdims=True)
|
301
302
|
|
302
303
|
return yhat, state
|
303
304
|
|
ring/ml/ml_utils.py
CHANGED
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
|
|
161
161
|
metrices: NestedDict, filter_nan_inf: bool = True
|
162
162
|
):
|
163
163
|
metrices = _flatten_dict(metrices)
|
164
|
-
metrices = jax.
|
164
|
+
metrices = jax.tree.map(_to_float_if_not_string, metrices)
|
165
165
|
|
166
166
|
if not filter_nan_inf:
|
167
167
|
return metrices
|
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
216
216
|
from jax.experimental import jax2tf
|
217
217
|
import tensorflow as tf
|
218
218
|
|
219
|
-
signature = jax.
|
219
|
+
signature = jax.tree.map(
|
220
220
|
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
221
221
|
)
|
222
222
|
|
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
241
241
|
if validate:
|
242
242
|
output_jax = jax_func(*input)
|
243
243
|
output_tf = tf.saved_model.load(path)(*input)
|
244
|
-
jax.
|
244
|
+
jax.tree.map(
|
245
245
|
lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
|
246
246
|
output_jax,
|
247
247
|
output_tf,
|
ring/ml/ringnet.py
CHANGED
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
|
|
248
248
|
params, state = self.forward_lam_factory(lam=lam).init(key, X)
|
249
249
|
|
250
250
|
if bs is not None:
|
251
|
-
state = jax.
|
251
|
+
state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
|
252
252
|
|
253
253
|
return params, state
|
254
254
|
|
ring/ml/train.py
CHANGED
@@ -50,7 +50,7 @@ def _build_step_fn(
|
|
50
50
|
# this vmap maps along batch-axis, not time-axis
|
51
51
|
# time-axis is handled by `metric_fn`
|
52
52
|
pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
|
53
|
-
error_tree = jax.
|
53
|
+
error_tree = jax.tree.map(pipe, y, yhat)
|
54
54
|
return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
|
55
55
|
|
56
56
|
@partial(
|
@@ -274,7 +274,7 @@ def _build_eval_fn(
|
|
274
274
|
), f"The metric identitifier {metric_name} is not unique"
|
275
275
|
|
276
276
|
pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
|
277
|
-
values.update({metric_name: jax.
|
277
|
+
values.update({metric_name: jax.tree.map(pipe, y, yhat)})
|
278
278
|
|
279
279
|
return values
|
280
280
|
|