imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
ring/base.py CHANGED
@@ -112,17 +112,71 @@ class _Base:
112
112
 
113
113
  @struct.dataclass
114
114
  class Transform(_Base):
115
- """Represents the Transformation from Plücker A to Plücker B,
116
- where B is located relative to A at `pos` in frame A and `rot` is the
117
- relative quaternion from A to B.
118
- Create using `Transform.create(pos=..., rot=...)
119
115
  """
116
+ Represents a spatial transformation between two coordinate frames using Plücker coordinates.
117
+
118
+ The `Transform` class defines the relative position and orientation of one frame (`B`)
119
+ with respect to another frame (`A`). The position (`pos`) is given in the coordinate frame
120
+ of `A`, and the rotation (`rot`) is expressed as a unit quaternion representing the relative
121
+ rotation from frame `A` to frame `B`.
122
+
123
+ Attributes:
124
+ pos (Vector):
125
+ The translation vector (position of `B` relative to `A`) in the coordinate frame of `A`.
126
+ Shape: `(..., 3)`, where `...` represents optional batch dimensions.
127
+ rot (Quaternion):
128
+ The unit quaternion representing the orientation of `B` relative to `A`.
129
+ Shape: `(..., 4)`, where `...` represents optional batch dimensions.
130
+
131
+ Methods:
132
+ create(pos: Optional[Vector] = None, rot: Optional[Quaternion] = None) -> Transform:
133
+ Creates a `Transform` instance with optional position and rotation.
134
+
135
+ zero(shape: Sequence[int] = ()) -> Transform:
136
+ Returns a zero transform with a given batch shape.
137
+
138
+ as_matrix() -> jax.Array:
139
+ Returns the spatial (Plücker) transformation matrix representation of this transform (built from `spatial.quadrants` and `spatial.xlt`, not a 4x4 homogeneous matrix).
140
+
141
+ Usage:
142
+ >>> pos = jnp.array([1.0, 2.0, 3.0])
143
+ >>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
144
+ >>> T = Transform.create(pos, rot)
145
+ >>> print(T.pos) # Output: [1. 2. 3.]
146
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
147
+ >>> print(T.as_matrix()) # spatial transformation matrix
148
+ """ # noqa: E501
120
149
 
121
150
  pos: Vector
122
151
  rot: Quaternion
123
152
 
124
153
  @classmethod
125
154
  def create(cls, pos=None, rot=None):
155
+ """
156
+ Creates a `Transform` instance with the specified position and rotation.
157
+
158
+ At least one of `pos` or `rot` must be provided. If only `pos` is given, the rotation
159
+ defaults to the identity quaternion `[1, 0, 0, 0]`. If only `rot` is given, the position
160
+ defaults to `[0, 0, 0]`.
161
+
162
+ Args:
163
+ pos (Optional[Vector], default=None):
164
+ The position of frame `B` relative to frame `A`, expressed in frame `A` coordinates.
165
+ If `None`, defaults to a zero vector of shape `(3,)`.
166
+ rot (Optional[Quaternion], default=None):
167
+ The unit quaternion representing the orientation of `B` relative to `A`.
168
+ If `None`, defaults to the identity quaternion `(1, 0, 0, 0)`.
169
+
170
+ Returns:
171
+ Transform: A new `Transform` instance with the specified position and rotation.
172
+
173
+ Example:
174
+ >>> pos = jnp.array([1.0, 2.0, 3.0])
175
+ >>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
176
+ >>> T = Transform.create(pos, rot)
177
+ >>> print(T.pos) # Output: [1. 2. 3.]
178
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
179
+ """ # noqa: E501
126
180
  assert not (pos is None and rot is None), "One must be given."
127
181
  shape_rot = rot.shape[:-1] if rot is not None else ()
128
182
  shape_pos = pos.shape[:-1] if pos is not None else ()
@@ -139,12 +193,49 @@ class Transform(_Base):
139
193
 
140
194
  @classmethod
141
195
  def zero(cls, shape=()) -> "Transform":
142
- """Returns a zero transform with a batch shape."""
196
+ """
197
+ Returns a zero transform with a given batch shape.
198
+
199
+ This creates a transform with position `(0, 0, 0)` and an identity quaternion `(1, 0, 0, 0)`,
200
+ which represents no translation or rotation.
201
+
202
+ Args:
203
+ shape (Sequence[int], default=()):
204
+ The batch shape for the transform. Defaults to a scalar transform.
205
+
206
+ Returns:
207
+ Transform: A zero transform with the specified batch shape.
208
+
209
+ Example:
210
+ >>> T = Transform.zero()
211
+ >>> print(T.pos) # Output: [0. 0. 0.]
212
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
213
+ """ # noqa: E501
143
214
  pos = jnp.zeros(shape + (3,))
144
215
  rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
145
216
  return Transform(pos, rot)
146
217
 
147
218
  def as_matrix(self) -> jax.Array:
219
+ """
220
+ Returns the spatial (Plücker) transformation matrix representation of this transform.
221
+
222
+ The matrix is the product of a rotation part and a translation part:
223
+
224
+ ```
225
+ E = quat_to_3x3(rot)
226
+ X = quadrants(aa=E, bb=E) @ xlt(pos)
227
+ ```
228
+
229
+ where `E` is the 3x3 rotation matrix converted from the quaternion and `xlt(pos)` is the
230
+ translation by the position vector.
231
+
232
+ Returns:
233
+ jax.Array: The transformation matrix, presumably of shape `(6, 6)` and not a `(4, 4)` homogeneous matrix (TODO confirm against `spatial.quadrants` and `spatial.xlt`).
234
+
235
+ Example:
236
+ >>> T = Transform.create(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 0.0, 0.0]))
237
+ >>> print(T.as_matrix()) # Output: spatial transformation matrix
238
+ """ # noqa: E501
148
239
  E = maths.quat_to_3x3(self.rot)
149
240
  return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
150
241
 
@@ -402,7 +493,175 @@ QD_WIDTHS = {
402
493
 
403
494
  @struct.dataclass
404
495
  class System(_Base):
405
- "System object. Create using `System.create(path_xml)`"
496
+ """
497
+ Represents a robotic system consisting of interconnected links and joints. Create it using `System.create(...)`.
498
+
499
+ The `System` class models the kinematic and dynamic properties of a multibody
500
+ system, providing methods for state representation, transformations, joint
501
+ configuration management, and rendering. It supports both minimal and maximal
502
+ coordinate representations and can be parsed from or saved to XML files.
503
+
504
+ Attributes:
505
+ link_parents (list[int]):
506
+ A list specifying the parent index for each link. The root link has a parent index of `-1`.
507
+ links (Link):
508
+ A data structure containing information about all links in the system.
509
+ link_types (list[str]):
510
+ A list specifying the joint type for each link (e.g., "free", "hinge", "prismatic").
511
+ link_damping (jax.Array):
512
+ Joint damping coefficients for each link.
513
+ link_armature (jax.Array):
514
+ Armature inertia values for each joint.
515
+ link_spring_stiffness (jax.Array):
516
+ Stiffness values for joint springs.
517
+ link_spring_zeropoint (jax.Array):
518
+ Rest position of joint springs.
519
+ dt (float):
520
+ Simulation time step size.
521
+ geoms (list[Geometry]):
522
+ List of geometries associated with the system.
523
+ gravity (jax.Array):
524
+ Gravity vector applied to the system (default: `[0, 0, -9.81]`).
525
+ integration_method (str):
526
+ Integration method for simulation (default: "semi_implicit_euler").
527
+ mass_mat_iters (int):
528
+ Number of iterations for mass matrix calculations.
529
+ link_names (list[str]):
530
+ Names of the links in the system.
531
+ model_name (Optional[str]):
532
+ Name of the system model (if available).
533
+ omc (list[MaxCoordOMC | None]):
534
+ List of optional Maximal Coordinate representations.
535
+
536
+ Methods:
537
+ num_links() -> int:
538
+ Returns the number of links in the system.
539
+
540
+ q_size() -> int:
541
+ Returns the total number of generalized coordinates (`q`) in the system.
542
+
543
+ qd_size() -> int:
544
+ Returns the total number of generalized velocities (`qd`) in the system.
545
+
546
+ name_to_idx(name: str) -> int:
547
+ Returns the index of a link given its name.
548
+
549
+ idx_to_name(idx: int, allow_world: bool = False) -> str:
550
+ Returns the name of a link given its index. If `allow_world` is `True`,
551
+ returns `"world"` for index `-1`.
552
+
553
+ idx_map(type: str) -> dict:
554
+ Returns a dictionary mapping link names to their indices for a specified type
555
+ (`"l"`, `"q"`, or `"d"`).
556
+
557
+ parent_name(name: str) -> str:
558
+ Returns the name of the parent link for a given link.
559
+
560
+ change_model_name(new_name: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
561
+ Changes the name of the system model.
562
+
563
+ change_link_name(old_name: str, new_name: str) -> "System":
564
+ Renames a specific link.
565
+
566
+ add_prefix_suffix(prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
567
+ Adds a prefix, a suffix, or both to all link names.
568
+
569
+ freeze(name: str | list[str]) -> "System":
570
+ Freezes the specified link(s), making them immovable.
571
+
572
+ unfreeze(name: str, new_joint_type: str) -> "System":
573
+ Unfreezes a frozen link and assigns it a new joint type.
574
+
575
+ change_joint_type(name: str, new_joint_type: str, **kwargs) -> "System":
576
+ Changes the joint type of a specified link.
577
+
578
+ joint_type_simplification(typ: str) -> str:
579
+ Returns a simplified representation of the given joint type.
580
+
581
+ joint_type_is_free_or_cor(typ: str) -> bool:
582
+ Checks if a joint type is either "free" or "cor".
583
+
584
+ joint_type_is_spherical(typ: str) -> bool:
585
+ Checks if a joint type is "spherical".
586
+
587
+ joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
588
+ Checks if a joint type is "free", "cor", or "spherical".
589
+
590
+ findall_imus(names: bool = True) -> list[str] | list[int]:
591
+ Finds all IMU sensors in the system.
592
+
593
+ findall_segments(names: bool = True) -> list[str] | list[int]:
594
+ Finds all non-IMU segments in the system.
595
+
596
+ findall_bodies_to_world(names: bool = False) -> list[int] | list[str]:
597
+ Returns all bodies directly connected to the world.
598
+
599
+ find_body_to_world(name: bool = False) -> int | str:
600
+ Returns the root body connected to the world.
601
+
602
+ findall_bodies_with_jointtype(typ: str, names: bool = False) -> list[int] | list[str]:
603
+ Returns all bodies with the specified joint type.
604
+
605
+ children(name: str, names: bool = False) -> list[int] | list[str]:
606
+ Returns the direct children of a given body.
607
+
608
+ findall_bodies_subsystem(name: str, names: bool = False) -> list[int] | list[str]:
609
+ Finds all bodies in the subsystem rooted at a given link.
610
+
611
+ scan(f: Callable, in_types: str, *args, reverse: bool = False):
612
+ Iterates over system elements while applying a function.
613
+
614
+ parse() -> "System":
615
+ Parses the system, performing consistency checks and computing spatial inertia tensors.
616
+
617
+ render(xs: Optional[Transform | list[Transform]] = None, **kwargs) -> list[np.ndarray]:
618
+ Renders frames of the system using maximal coordinates.
619
+
620
+ render_prediction(xs: Transform | list[Transform], yhat: dict | jax.Array | np.ndarray, **kwargs):
621
+ Renders a predicted state transformation.
622
+
623
+ delete_system(link_name: str | list[str], strict: bool = True):
624
+ Removes a subsystem from the system.
625
+
626
+ make_sys_noimu(imu_link_names: Optional[list[str]] = None):
627
+ Returns a version of the system without IMU sensors.
628
+
629
+ inject_system(other_system: "System", at_body: Optional[str] = None):
630
+ Merges another system into this one.
631
+
632
+ morph_system(new_parents: Optional[list[int | str]] = None, new_anchor: Optional[int | str] = None):
633
+ Reorders the system’s link hierarchy.
634
+
635
+ from_xml(path: str, seed: int = 1) -> "System":
636
+ Loads a system from an XML file.
637
+
638
+ from_str(xml: str, seed: int = 1) -> "System":
639
+ Loads a system from an XML string.
640
+
641
+ to_str(warn: bool = True) -> str:
642
+ Serializes the system to an XML string.
643
+
644
+ to_xml(path: str) -> None:
645
+ Saves the system as an XML file.
646
+
647
+ create(path_or_str: str, seed: int = 1) -> "System":
648
+ Creates a `System` instance from an XML file or string.
649
+
650
+ coordinate_vector_to_q(q: jax.Array, custom_joints: dict[str, Callable] = {}) -> jax.Array:
651
+ Converts a coordinate vector to minimal coordinates (`q`), applying
652
+ constraints such as quaternion normalization.
653
+
654
+ Raises:
655
+ AssertionError: If the system structure is invalid (e.g., duplicate link names, incorrect parent-child relationships).
656
+ InvalidSystemError: If an operation results in an inconsistent system state.
657
+
658
+ Notes:
659
+ - The system must be parsed before use to ensure consistency.
660
+ - The system supports batch operations using JAX for efficient computations.
661
+ - Joint types include revolute ("rx", "ry", "rz"), prismatic ("px", "py", "pz"), spherical, free, and more.
662
+ - Inertial properties of links are computed automatically from associated geometries.
663
+ """ # noqa: E501
664
+
406
665
  link_parents: list[int] = struct.field(False)
407
666
  links: Link
408
667
  link_types: list[str] = struct.field(False)
@@ -429,25 +688,32 @@ class System(_Base):
429
688
  omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
430
689
 
431
690
  def num_links(self) -> int:
691
+ "Returns the number of links in the system."
432
692
  return len(self.link_parents)
433
693
 
434
694
  def q_size(self) -> int:
695
+ "Returns the total number of generalized coordinates (`q`) in the system."
435
696
  return sum([Q_WIDTHS[typ] for typ in self.link_types])
436
697
 
437
698
  def qd_size(self) -> int:
699
+ "Returns the total number of generalized velocities (`qd`) in the system."
438
700
  return sum([QD_WIDTHS[typ] for typ in self.link_types])
439
701
 
440
702
  def name_to_idx(self, name: str) -> int:
703
+ "Returns the index of a link given its name."
441
704
  return self.link_names.index(name)
442
705
 
443
706
  def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
707
+ """Returns the name of a link given its index. If `allow_world` is `True`,
708
+ returns `"world"` for index `-1`."""
444
709
  if allow_world and idx == -1:
445
710
  return "world"
446
711
  assert idx >= 0, "Worldbody index has no name."
447
712
  return self.link_names[idx]
448
713
 
449
714
  def idx_map(self, type: str) -> dict:
450
- "type: is either `l` or `q` or `d`"
715
+ """Returns a dictionary mapping link names to their indices for a specified type
716
+ (`"l"`, `"q"`, or `"d"`)."""
451
717
  dict_int_slices = {}
452
718
 
453
719
  def f(_, idx_map, name: str, link_idx: int):
@@ -458,10 +724,11 @@ class System(_Base):
458
724
  return dict_int_slices
459
725
 
460
726
  def parent_name(self, name: str) -> str:
727
+ "Returns the name of the parent link for a given link."
461
728
  return self.idx_to_name(self.link_parents[self.name_to_idx(name)])
462
729
 
463
730
  def add_prefix(self, prefix: str = "") -> "System":
464
- return self.replace(link_names=[prefix + name for name in self.link_names])
731
+ return self.add_prefix_suffix(prefix=prefix)
465
732
 
466
733
  def change_model_name(
467
734
  self,
@@ -469,6 +736,7 @@ class System(_Base):
469
736
  prefix: Optional[str] = None,
470
737
  suffix: Optional[str] = None,
471
738
  ) -> "System":
739
+ "Changes the name of the system model."
472
740
  if prefix is None:
473
741
  prefix = ""
474
742
  if suffix is None:
@@ -479,6 +747,7 @@ class System(_Base):
479
747
  return self.replace(model_name=name)
480
748
 
481
749
  def change_link_name(self, old_name: str, new_name: str) -> "System":
750
+ "Renames a specific link."
482
751
  old_idx = self.name_to_idx(old_name)
483
752
  new_link_names = self.link_names.copy()
484
753
  new_link_names[old_idx] = new_name
@@ -487,6 +756,7 @@ class System(_Base):
487
756
  def add_prefix_suffix(
488
757
  self, prefix: Optional[str] = None, suffix: Optional[str] = None
489
758
  ) -> "System":
759
+ "Adds a prefix, a suffix, or both to all link names."
490
760
  if prefix is None:
491
761
  prefix = ""
492
762
  if suffix is None:
@@ -526,6 +796,7 @@ class System(_Base):
526
796
  return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
527
797
 
528
798
  def freeze(self, name: str | list[str]):
799
+ "Freezes the specified link(s), making them immovable (uses `frozen` joint)"
529
800
  if isinstance(name, list):
530
801
  sys = self
531
802
  for n in name:
@@ -544,6 +815,7 @@ class System(_Base):
544
815
  return _update_sys_if_replace_joint_type(self, logic_freeze)
545
816
 
546
817
  def unfreeze(self, name: str, new_joint_type: str):
818
+ "Unfreezes a frozen link and assigns it a new joint type."
547
819
  assert self.link_types[self.name_to_idx(name)] == "frozen"
548
820
  assert new_joint_type != "frozen"
549
821
 
@@ -560,7 +832,8 @@ class System(_Base):
560
832
  seed: int = 1,
561
833
  warn: bool = True,
562
834
  ):
563
- "By default damping, stiffness are set to zero."
835
+ """Changes the joint type of a specified link.
836
+ By default damping, stiffness are set to zero."""
564
837
  from ring.algorithms import get_joint_model
565
838
 
566
839
  q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
@@ -594,6 +867,7 @@ class System(_Base):
594
867
 
595
868
  @staticmethod
596
869
  def joint_type_simplification(typ: str) -> str:
870
+ "Returns a simplified name of the given joint type."
597
871
  if typ[:4] == "free":
598
872
  if typ == "free_2d":
599
873
  return "free_2d"
@@ -608,23 +882,28 @@ class System(_Base):
608
882
 
609
883
  @staticmethod
610
884
  def joint_type_is_free_or_cor(typ: str) -> bool:
885
+ 'Checks if a joint type is either "free" or "cor".'
611
886
  return System.joint_type_simplification(typ) in ["free", "cor"]
612
887
 
613
888
  @staticmethod
614
889
  def joint_type_is_spherical(typ: str) -> bool:
890
+ 'Checks if a joint type is "spherical".'
615
891
  return System.joint_type_simplification(typ) == "spherical"
616
892
 
617
893
  @staticmethod
618
894
  def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
895
+ 'Checks if a joint type is "free", "cor", or "spherical".'
619
896
  return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
620
897
  typ
621
898
  )
622
899
 
623
900
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
901
+ "Finds all IMU links in the system (links whose name starts with `imu`)."
624
902
  bodies = [name for name in self.link_names if name[:3] == "imu"]
625
903
  return bodies if names else [self.name_to_idx(n) for n in bodies]
626
904
 
627
905
  def findall_segments(self, names: bool = True) -> list[str] | list[int]:
906
+ "Finds all non-IMU segments in the system."
628
907
  imus = self.findall_imus(names=True)
629
908
  bodies = [name for name in self.link_names if name not in imus]
630
909
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -633,10 +912,12 @@ class System(_Base):
633
912
  return [self.idx_to_name(i) for i in bodies]
634
913
 
635
914
  def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
915
+ "Returns all bodies directly connected to the world."
636
916
  bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
637
917
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
638
918
 
639
919
  def find_body_to_world(self, name: bool = False) -> int | str:
920
+ "Returns the single body connected to the world; asserts that exactly one such body exists."
640
921
  bodies = self.findall_bodies_to_world(names=name)
641
922
  assert len(bodies) == 1
642
923
  return bodies[0]
@@ -644,6 +925,7 @@ class System(_Base):
644
925
  def findall_bodies_with_jointtype(
645
926
  self, typ: str, names: bool = False
646
927
  ) -> list[int] | list[str]:
928
+ "Returns all bodies with the specified joint type."
647
929
  bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
648
930
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
649
931
 
@@ -781,20 +1063,25 @@ class System(_Base):
781
1063
 
782
1064
  @staticmethod
783
1065
  def from_xml(path: str, seed: int = 1):
1066
+ "Loads a system from an XML file."
784
1067
  return ring.io.load_sys_from_xml(path, seed)
785
1068
 
786
1069
  @staticmethod
787
1070
  def from_str(xml: str, seed: int = 1):
1071
+ "Loads a system from an XML string."
788
1072
  return ring.io.load_sys_from_str(xml, seed)
789
1073
 
790
1074
  def to_str(self, warn: bool = True) -> str:
1075
+ "Serializes the system to an XML string."
791
1076
  return ring.io.save_sys_to_str(self, warn=warn)
792
1077
 
793
1078
  def to_xml(self, path: str) -> None:
1079
+ "Saves the system to an XML file."
794
1080
  ring.io.save_sys_to_xml(self, path)
795
1081
 
796
1082
  @classmethod
797
1083
  def create(cls, path_or_str: str, seed: int = 1) -> "System":
1084
+ "Creates a `System` instance from an XML file or string."
798
1085
  path = Path(path_or_str).with_suffix(".xml")
799
1086
 
800
1087
  exists = False
@@ -807,14 +1094,15 @@ class System(_Base):
807
1094
  if exists:
808
1095
  return cls.from_xml(path, seed=seed)
809
1096
  else:
810
- return cls.from_str(path_or_str)
1097
+ return cls.from_str(path_or_str, seed=seed)
811
1098
 
812
1099
  def coordinate_vector_to_q(
813
1100
  self,
814
1101
  q: jax.Array,
815
1102
  custom_joints: dict[str, Callable] = {},
816
1103
  ) -> jax.Array:
817
- """Map a coordinate vector `q` to the minimal coordinates vector of the sys"""
1104
+ """Converts a coordinate vector to minimal coordinates (`q`), applying
1105
+ constraints such as quaternion normalization."""
818
1106
  # Does, e.g.
819
1107
  # - normalize quaternions
820
1108
  # - hinge joints in [-pi, pi]
@@ -1026,14 +1314,37 @@ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = Fa
1026
1314
 
1027
1315
  @struct.dataclass
1028
1316
  class State(_Base):
1029
- """The static and dynamic state of a system in minimal and maximal coordinates.
1030
- Use `.create()` to create this object.
1031
-
1032
- Args:
1033
- q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
1034
- qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
1035
- x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
1036
1317
  """
1318
+ Represents the state of a dynamic system in minimal and maximal coordinates.
1319
+
1320
+ The `State` class encapsulates both the configuration (`q`) and velocity (`qd`)
1321
+ of the system in minimal coordinates, as well as the corresponding transforms (`x`)
1322
+ in maximal coordinates.
1323
+
1324
+ Attributes:
1325
+ q (jax.Array):
1326
+ The joint positions (generalized coordinates) of the system. The size
1327
+ of `q` matches `sys.q_size()`.
1328
+ qd (jax.Array):
1329
+ The joint velocities (generalized velocities) of the system. The size
1330
+ of `qd` matches `sys.qd_size()`.
1331
+ x (Transform):
1332
+ The maximal coordinate representation of all system links, expressed as
1333
+ a `Transform` object.
1334
+
1335
+ Methods:
1336
+ create(sys: System, q: Optional[jax.Array] = None,
1337
+ qd: Optional[jax.Array] = None, x: Optional[Transform] = None,
1338
+ key: Optional[jax.Array] = None,
1339
+ custom_joints: dict[str, Callable] = {}) -> State:
1340
+ Creates a `State` instance for a given system with optional initial conditions.
1341
+
1342
+ Usage:
1343
+ >>> sys = System.create("model.xml")
1344
+ >>> state = State.create(sys)
1345
+ >>> print(state.q.shape) # Should match sys.q_size()
1346
+ >>> print(state.qd.shape) # Should match sys.qd_size()
1347
+ """ # noqa: E501
1037
1348
 
1038
1349
  q: jax.Array
1039
1350
  qd: jax.Array
@@ -1048,19 +1359,37 @@ class State(_Base):
1048
1359
  x: Optional[Transform] = None,
1049
1360
  key: Optional[jax.Array] = None,
1050
1361
  custom_joints: dict[str, Callable] = {},
1051
- ):
1052
- """Create state of system.
1362
+ ) -> "State":
1363
+ """
1364
+ Creates a `State` instance for the given system with optional initial conditions.
1365
+
1366
+ If no initial values are provided, joint positions (`q`) and velocities (`qd`)
1367
+ are initialized to zero, except for free and spherical joints, which have unit quaternions.
1053
1368
 
1054
1369
  Args:
1055
- sys (System): The system for which to create a state.
1056
- q (jax.Array, optional): The joint values of the system. Defaults to None.
1057
- Which then defaults to zeros.
1058
- qd (jax.Array, optional): The joint velocities of the system.
1059
- Defaults to None. Which then defaults to zeros.
1370
+ sys (System):
1371
+ The system for which to create a state.
1372
+ q (Optional[jax.Array], default=None):
1373
+ Initial joint positions. If `None`, defaults to zeros, with unit quaternion initialization
1374
+ for free and spherical joints.
1375
+ qd (Optional[jax.Array], default=None):
1376
+ Initial joint velocities. If `None`, defaults to zeros.
1377
+ x (Optional[Transform], default=None):
1378
+ Initial maximal coordinates of the system links. If `None`, defaults to zero transforms.
1379
+ key (Optional[jax.Array], default=None):
1380
+ Random key used to sample `q` from a standard normal distribution; `q` must be `None` when `key` is given.
1381
+ custom_joints (dict[str, Callable], default={}):
1382
+ Custom joint functions for mapping coordinate vectors to minimal coordinates.
1060
1383
 
1061
1384
  Returns:
1062
- (State): Create State object.
1063
- """
1385
+ State: A new instance of the `State` class representing the initialized system state.
1386
+
1387
+ Example:
1388
+ >>> sys = System.create("model.xml")
1389
+ >>> state = State.create(sys)
1390
+ >>> print(state.q.shape) # Should match sys.q_size()
1391
+ >>> print(state.qd.shape) # Should match sys.qd_size()
1392
+ """ # noqa: E501
1064
1393
  if key is not None:
1065
1394
  assert q is None
1066
1395
  q = jax.random.normal(key, shape=(sys.q_size(),))
ring/io/xml/from_xml.py CHANGED
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
252
252
 
253
253
  # numpy -> jax
254
254
  # we load using numpy in order to have float64 precision
255
- sys = jax.tree_map(jax.numpy.asarray, sys)
255
+ sys = jax.tree.map(jax.numpy.asarray, sys)
256
256
 
257
257
  sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
258
258
 
ring/ml/base.py CHANGED
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
13
13
  def _to_3d(tree):
14
14
  if tree is None:
15
15
  return None
16
- return jax.tree_map(lambda arr: arr[None], tree)
16
+ return jax.tree.map(lambda arr: arr[None], tree)
17
17
 
18
18
 
19
19
  def _to_2d(tree, i: int = 0):
20
20
  if tree is None:
21
21
  return None
22
- return jax.tree_map(lambda arr: arr[i], tree)
22
+ return jax.tree.map(lambda arr: arr[i], tree)
23
23
 
24
24
 
25
25
  class AbstractFilter(ABC):
@@ -297,7 +297,8 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
297
297
 
298
298
  if self._quat_normalize:
299
299
  assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
300
- yhat = ring.maths.safe_normalize(yhat)
300
+ # yhat = ring.maths.safe_normalize(yhat)
301
+ yhat = yhat / jnp.linalg.norm(yhat, axis=-1, keepdims=True)
301
302
 
302
303
  return yhat, state
303
304
 
ring/ml/ml_utils.py CHANGED
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
161
161
  metrices: NestedDict, filter_nan_inf: bool = True
162
162
  ):
163
163
  metrices = _flatten_dict(metrices)
164
- metrices = jax.tree_map(_to_float_if_not_string, metrices)
164
+ metrices = jax.tree.map(_to_float_if_not_string, metrices)
165
165
 
166
166
  if not filter_nan_inf:
167
167
  return metrices
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
216
216
  from jax.experimental import jax2tf
217
217
  import tensorflow as tf
218
218
 
219
- signature = jax.tree_map(
219
+ signature = jax.tree.map(
220
220
  lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
221
221
  )
222
222
 
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
241
241
  if validate:
242
242
  output_jax = jax_func(*input)
243
243
  output_tf = tf.saved_model.load(path)(*input)
244
- jax.tree_map(
244
+ jax.tree.map(
245
245
  lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
246
246
  output_jax,
247
247
  output_tf,
ring/ml/ringnet.py CHANGED
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
248
248
  params, state = self.forward_lam_factory(lam=lam).init(key, X)
249
249
 
250
250
  if bs is not None:
251
- state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
251
+ state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
252
252
 
253
253
  return params, state
254
254
 
ring/ml/train.py CHANGED
@@ -50,7 +50,7 @@ def _build_step_fn(
50
50
  # this vmap maps along batch-axis, not time-axis
51
51
  # time-axis is handled by `metric_fn`
52
52
  pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
53
- error_tree = jax.tree_map(pipe, y, yhat)
53
+ error_tree = jax.tree.map(pipe, y, yhat)
54
54
  return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
55
55
 
56
56
  @partial(
@@ -274,7 +274,7 @@ def _build_eval_fn(
274
274
  ), f"The metric identitifier {metric_name} is not unique"
275
275
 
276
276
  pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
277
- values.update({metric_name: jax.tree_map(pipe, y, yhat)})
277
+ values.update({metric_name: jax.tree.map(pipe, y, yhat)})
278
278
 
279
279
  return values
280
280