imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ring/base.py CHANGED
@@ -112,17 +112,71 @@ class _Base:
112
112
 
113
113
  @struct.dataclass
114
114
  class Transform(_Base):
115
- """Represents the Transformation from Plücker A to Plücker B,
116
- where B is located relative to A at `pos` in frame A and `rot` is the
117
- relative quaternion from A to B.
118
- Create using `Transform.create(pos=..., rot=...)
119
115
  """
116
+ Represents a spatial transformation between two coordinate frames using Plücker coordinates.
117
+
118
+ The `Transform` class defines the relative position and orientation of one frame (`B`)
119
+ with respect to another frame (`A`). The position (`pos`) is given in the coordinate frame
120
+ of `A`, and the rotation (`rot`) is expressed as a unit quaternion representing the relative
121
+ rotation from frame `A` to frame `B`.
122
+
123
+ Attributes:
124
+ pos (Vector):
125
+ The translation vector (position of `B` relative to `A`) in the coordinate frame of `A`.
126
+ Shape: `(..., 3)`, where `...` represents optional batch dimensions.
127
+ rot (Quaternion):
128
+ The unit quaternion representing the orientation of `B` relative to `A`.
129
+ Shape: `(..., 4)`, where `...` represents optional batch dimensions.
130
+
131
+ Methods:
132
+ create(pos: Optional[Vector] = None, rot: Optional[Quaternion] = None) -> Transform:
133
+ Creates a `Transform` instance with optional position and rotation.
134
+
135
+ zero(shape: Sequence[int] = ()) -> Transform:
136
+ Returns a zero transform with a given batch shape.
137
+
138
+ as_matrix() -> jax.Array:
139
+ Returns the 4x4 homogeneous transformation matrix representation of this transform.
140
+
141
+ Usage:
142
+ >>> pos = jnp.array([1.0, 2.0, 3.0])
143
+ >>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
144
+ >>> T = Transform.create(pos, rot)
145
+ >>> print(T.pos) # Output: [1. 2. 3.]
146
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
147
+ >>> print(T.as_matrix()) # 4x4 transformation matrix
148
+ """ # noqa: E501
120
149
 
121
150
  pos: Vector
122
151
  rot: Quaternion
123
152
 
124
153
  @classmethod
125
154
  def create(cls, pos=None, rot=None):
155
+ """
156
+ Creates a `Transform` instance with the specified position and rotation.
157
+
158
+ At least one of `pos` or `rot` must be provided. If only `pos` is given, the rotation
159
+ defaults to the identity quaternion `[1, 0, 0, 0]`. If only `rot` is given, the position
160
+ defaults to `[0, 0, 0]`.
161
+
162
+ Args:
163
+ pos (Optional[Vector], default=None):
164
+ The position of frame `B` relative to frame `A`, expressed in frame `A` coordinates.
165
+ If `None`, defaults to a zero vector of shape `(3,)`.
166
+ rot (Optional[Quaternion], default=None):
167
+ The unit quaternion representing the orientation of `B` relative to `A`.
168
+ If `None`, defaults to the identity quaternion `(1, 0, 0, 0)`.
169
+
170
+ Returns:
171
+ Transform: A new `Transform` instance with the specified position and rotation.
172
+
173
+ Example:
174
+ >>> pos = jnp.array([1.0, 2.0, 3.0])
175
+ >>> rot = jnp.array([1.0, 0.0, 0.0, 0.0]) # Identity quaternion
176
+ >>> T = Transform.create(pos, rot)
177
+ >>> print(T.pos) # Output: [1. 2. 3.]
178
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
179
+ """ # noqa: E501
126
180
  assert not (pos is None and rot is None), "One must be given."
127
181
  shape_rot = rot.shape[:-1] if rot is not None else ()
128
182
  shape_pos = pos.shape[:-1] if pos is not None else ()
@@ -139,12 +193,49 @@ class Transform(_Base):
139
193
 
140
194
  @classmethod
141
195
  def zero(cls, shape=()) -> "Transform":
142
- """Returns a zero transform with a batch shape."""
196
+ """
197
+ Returns a zero transform with a given batch shape.
198
+
199
+ This creates a transform with position `(0, 0, 0)` and an identity quaternion `(1, 0, 0, 0)`,
200
+ which represents no translation or rotation.
201
+
202
+ Args:
203
+ shape (Sequence[int], default=()):
204
+ The batch shape for the transform. Defaults to a scalar transform.
205
+
206
+ Returns:
207
+ Transform: A zero transform with the specified batch shape.
208
+
209
+ Example:
210
+ >>> T = Transform.zero()
211
+ >>> print(T.pos) # Output: [0. 0. 0.]
212
+ >>> print(T.rot) # Output: [1. 0. 0. 0.]
213
+ """ # noqa: E501
143
214
  pos = jnp.zeros(shape + (3,))
144
215
  rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
145
216
  return Transform(pos, rot)
146
217
 
147
218
  def as_matrix(self) -> jax.Array:
219
+ """
220
+ Returns the 4x4 homogeneous transformation matrix representation of this transform.
221
+
222
+ The homogeneous transformation matrix is defined as:
223
+
224
+ ```
225
+ [ R t ]
226
+ [ 0 1 ]
227
+ ```
228
+
229
+ where `R` is the 3x3 rotation matrix converted from the quaternion and `t` is the
230
+ 3x1 position vector.
231
+
232
+ Returns:
233
+ jax.Array: A `(4, 4)` homogeneous transformation matrix.
234
+
235
+ Example:
236
+ >>> T = Transform.create(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.0, 0.0, 0.0]))
237
+ >>> print(T.as_matrix()) # Output: 4x4 matrix
238
+ """ # noqa: E501
148
239
  E = maths.quat_to_3x3(self.rot)
149
240
  return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
150
241
 
@@ -402,7 +493,175 @@ QD_WIDTHS = {
402
493
 
403
494
  @struct.dataclass
404
495
  class System(_Base):
405
- "System object. Create using `System.create(path_xml)`"
496
+ """
497
+ Represents a robotic system consisting of interconnected links and joints. Create it using `System.create(...)`
498
+
499
+ The `System` class models the kinematic and dynamic properties of a multibody
500
+ system, providing methods for state representation, transformations, joint
501
+ configuration management, and rendering. It supports both minimal and maximal
502
+ coordinate representations and can be parsed from or saved to XML files.
503
+
504
+ Attributes:
505
+ link_parents (list[int]):
506
+ A list specifying the parent index for each link. The root link has a parent index of `-1`.
507
+ links (Link):
508
+ A data structure containing information about all links in the system.
509
+ link_types (list[str]):
510
+ A list specifying the joint type for each link (e.g., "free", "hinge", "prismatic").
511
+ link_damping (jax.Array):
512
+ Joint damping coefficients for each link.
513
+ link_armature (jax.Array):
514
+ Armature inertia values for each joint.
515
+ link_spring_stiffness (jax.Array):
516
+ Stiffness values for joint springs.
517
+ link_spring_zeropoint (jax.Array):
518
+ Rest position of joint springs.
519
+ dt (float):
520
+ Simulation time step size.
521
+ geoms (list[Geometry]):
522
+ List of geometries associated with the system.
523
+ gravity (jax.Array):
524
+ Gravity vector applied to the system (default: `[0, 0, -9.81]`).
525
+ integration_method (str):
526
+ Integration method for simulation (default: "semi_implicit_euler").
527
+ mass_mat_iters (int):
528
+ Number of iterations for mass matrix calculations.
529
+ link_names (list[str]):
530
+ Names of the links in the system.
531
+ model_name (Optional[str]):
532
+ Name of the system model (if available).
533
+ omc (list[MaxCoordOMC | None]):
534
+ List of optional Maximal Coordinate representations.
535
+
536
+ Methods:
537
+ num_links() -> int:
538
+ Returns the number of links in the system.
539
+
540
+ q_size() -> int:
541
+ Returns the total number of generalized coordinates (`q`) in the system.
542
+
543
+ qd_size() -> int:
544
+ Returns the total number of generalized velocities (`qd`) in the system.
545
+
546
+ name_to_idx(name: str) -> int:
547
+ Returns the index of a link given its name.
548
+
549
+ idx_to_name(idx: int, allow_world: bool = False) -> str:
550
+ Returns the name of a link given its index. If `allow_world` is `True`,
551
+ returns `"world"` for index `-1`.
552
+
553
+ idx_map(type: str) -> dict:
554
+ Returns a dictionary mapping link names to their indices for a specified type
555
+ (`"l"`, `"q"`, or `"d"`).
556
+
557
+ parent_name(name: str) -> str:
558
+ Returns the name of the parent link for a given link.
559
+
560
+ change_model_name(new_name: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
561
+ Changes the name of the system model.
562
+
563
+ change_link_name(old_name: str, new_name: str) -> "System":
564
+ Renames a specific link.
565
+
566
+ add_prefix_suffix(prefix: Optional[str] = None, suffix: Optional[str] = None) -> "System":
567
+ Adds both a prefix and suffix to all link names.
568
+
569
+ freeze(name: str | list[str]) -> "System":
570
+ Freezes the specified link(s), making them immovable.
571
+
572
+ unfreeze(name: str, new_joint_type: str) -> "System":
573
+ Unfreezes a frozen link and assigns it a new joint type.
574
+
575
+ change_joint_type(name: str, new_joint_type: str, **kwargs) -> "System":
576
+ Changes the joint type of a specified link.
577
+
578
+ joint_type_simplification(typ: str) -> str:
579
+ Returns a simplified representation of the given joint type.
580
+
581
+ joint_type_is_free_or_cor(typ: str) -> bool:
582
+ Checks if a joint type is either "free" or "cor".
583
+
584
+ joint_type_is_spherical(typ: str) -> bool:
585
+ Checks if a joint type is "spherical".
586
+
587
+ joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
588
+ Checks if a joint type is "free", "cor", or "spherical".
589
+
590
+ findall_imus(names: bool = True) -> list[str] | list[int]:
591
+ Finds all IMU sensors in the system.
592
+
593
+ findall_segments(names: bool = True) -> list[str] | list[int]:
594
+ Finds all non-IMU segments in the system.
595
+
596
+ findall_bodies_to_world(names: bool = False) -> list[int] | list[str]:
597
+ Returns all bodies directly connected to the world.
598
+
599
+ find_body_to_world(name: bool = False) -> int | str:
600
+ Returns the root body connected to the world.
601
+
602
+ findall_bodies_with_jointtype(typ: str, names: bool = False) -> list[int] | list[str]:
603
+ Returns all bodies with the specified joint type.
604
+
605
+ children(name: str, names: bool = False) -> list[int] | list[str]:
606
+ Returns the direct children of a given body.
607
+
608
+ findall_bodies_subsystem(name: str, names: bool = False) -> list[int] | list[str]:
609
+ Finds all bodies in the subsystem rooted at a given link.
610
+
611
+ scan(f: Callable, in_types: str, *args, reverse: bool = False):
612
+ Iterates over system elements while applying a function.
613
+
614
+ parse() -> "System":
615
+ Parses the system, performing consistency checks and computing spatial inertia tensors.
616
+
617
+ render(xs: Optional[Transform | list[Transform]] = None, **kwargs) -> list[np.ndarray]:
618
+ Renders frames of the system using maximal coordinates.
619
+
620
+ render_prediction(xs: Transform | list[Transform], yhat: dict | jax.Array | np.ndarray, **kwargs):
621
+ Renders a predicted state transformation.
622
+
623
+ delete_system(link_name: str | list[str], strict: bool = True):
624
+ Removes a subsystem from the system.
625
+
626
+ make_sys_noimu(imu_link_names: Optional[list[str]] = None):
627
+ Returns a version of the system without IMU sensors.
628
+
629
+ inject_system(other_system: "System", at_body: Optional[str] = None):
630
+ Merges another system into this one.
631
+
632
+ morph_system(new_parents: Optional[list[int | str]] = None, new_anchor: Optional[int | str] = None):
633
+ Reorders the system’s link hierarchy.
634
+
635
+ from_xml(path: str, seed: int = 1) -> "System":
636
+ Loads a system from an XML file.
637
+
638
+ from_str(xml: str, seed: int = 1) -> "System":
639
+ Loads a system from an XML string.
640
+
641
+ to_str(warn: bool = True) -> str:
642
+ Serializes the system to an XML string.
643
+
644
+ to_xml(path: str) -> None:
645
+ Saves the system as an XML file.
646
+
647
+ create(path_or_str: str, seed: int = 1) -> "System":
648
+ Creates a `System` instance from an XML file or string.
649
+
650
+ coordinate_vector_to_q(q: jax.Array, custom_joints: dict[str, Callable] = {}) -> jax.Array:
651
+ Converts a coordinate vector to minimal coordinates (`q`), applying
652
+ constraints such as quaternion normalization.
653
+
654
+ Raises:
655
+ AssertionError: If the system structure is invalid (e.g., duplicate link names, incorrect parent-child relationships).
656
+ InvalidSystemError: If an operation results in an inconsistent system state.
657
+
658
+ Notes:
659
+ - The system must be parsed before use to ensure consistency.
660
+ - The system supports batch operations using JAX for efficient computations.
661
+ - Joint types include revolute ("rx", "ry", "rz"), prismatic ("px", "py", "pz"), spherical, free, and more.
662
+ - Inertial properties of links are computed automatically from associated geometries.
663
+ """ # noqa: E501
664
+
406
665
  link_parents: list[int] = struct.field(False)
407
666
  links: Link
408
667
  link_types: list[str] = struct.field(False)
@@ -429,25 +688,32 @@ class System(_Base):
429
688
  omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
430
689
 
431
690
  def num_links(self) -> int:
691
+ "Returns the number of links in the system."
432
692
  return len(self.link_parents)
433
693
 
434
694
  def q_size(self) -> int:
695
+ "Returns the total number of generalized coordinates (`q`) in the system."
435
696
  return sum([Q_WIDTHS[typ] for typ in self.link_types])
436
697
 
437
698
  def qd_size(self) -> int:
699
+ "Returns the total number of generalized velocities (`qd`) in the system."
438
700
  return sum([QD_WIDTHS[typ] for typ in self.link_types])
439
701
 
440
702
  def name_to_idx(self, name: str) -> int:
703
+ "Returns the index of a link given its name."
441
704
  return self.link_names.index(name)
442
705
 
443
706
  def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
707
+ """Returns the name of a link given its index. If `allow_world` is `True`,
708
+ returns `"world"` for index `-1`."""
444
709
  if allow_world and idx == -1:
445
710
  return "world"
446
711
  assert idx >= 0, "Worldbody index has no name."
447
712
  return self.link_names[idx]
448
713
 
449
714
  def idx_map(self, type: str) -> dict:
450
- "type: is either `l` or `q` or `d`"
715
+ """Returns a dictionary mapping link names to their indices for a specified type
716
+ (`"l"`, `"q"`, or `"d"`)."""
451
717
  dict_int_slices = {}
452
718
 
453
719
  def f(_, idx_map, name: str, link_idx: int):
@@ -458,10 +724,11 @@ class System(_Base):
458
724
  return dict_int_slices
459
725
 
460
726
  def parent_name(self, name: str) -> str:
727
+ "Returns the name of the parent link for a given link."
461
728
  return self.idx_to_name(self.link_parents[self.name_to_idx(name)])
462
729
 
463
730
  def add_prefix(self, prefix: str = "") -> "System":
464
- return self.replace(link_names=[prefix + name for name in self.link_names])
731
+ return self.add_prefix_suffix(prefix=prefix)
465
732
 
466
733
  def change_model_name(
467
734
  self,
@@ -469,6 +736,7 @@ class System(_Base):
469
736
  prefix: Optional[str] = None,
470
737
  suffix: Optional[str] = None,
471
738
  ) -> "System":
739
+ "Changes the name of the system model."
472
740
  if prefix is None:
473
741
  prefix = ""
474
742
  if suffix is None:
@@ -479,6 +747,7 @@ class System(_Base):
479
747
  return self.replace(model_name=name)
480
748
 
481
749
  def change_link_name(self, old_name: str, new_name: str) -> "System":
750
+ "Renames a specific link."
482
751
  old_idx = self.name_to_idx(old_name)
483
752
  new_link_names = self.link_names.copy()
484
753
  new_link_names[old_idx] = new_name
@@ -487,6 +756,7 @@ class System(_Base):
487
756
  def add_prefix_suffix(
488
757
  self, prefix: Optional[str] = None, suffix: Optional[str] = None
489
758
  ) -> "System":
759
+ "Adds either or, or both a prefix and suffix to all link names."
490
760
  if prefix is None:
491
761
  prefix = ""
492
762
  if suffix is None:
@@ -526,6 +796,7 @@ class System(_Base):
526
796
  return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
527
797
 
528
798
  def freeze(self, name: str | list[str]):
799
+ "Freezes the specified link(s), making them immovable (uses `frozen` joint)"
529
800
  if isinstance(name, list):
530
801
  sys = self
531
802
  for n in name:
@@ -544,6 +815,7 @@ class System(_Base):
544
815
  return _update_sys_if_replace_joint_type(self, logic_freeze)
545
816
 
546
817
  def unfreeze(self, name: str, new_joint_type: str):
818
+ "Unfreezes a frozen link and assigns it a new joint type."
547
819
  assert self.link_types[self.name_to_idx(name)] == "frozen"
548
820
  assert new_joint_type != "frozen"
549
821
 
@@ -560,7 +832,8 @@ class System(_Base):
560
832
  seed: int = 1,
561
833
  warn: bool = True,
562
834
  ):
563
- "By default damping, stiffness are set to zero."
835
+ """Changes the joint type of a specified link.
836
+ By default damping, stiffness are set to zero."""
564
837
  from ring.algorithms import get_joint_model
565
838
 
566
839
  q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
@@ -594,6 +867,7 @@ class System(_Base):
594
867
 
595
868
  @staticmethod
596
869
  def joint_type_simplification(typ: str) -> str:
870
+ "Returns a simplified name of the given joint type."
597
871
  if typ[:4] == "free":
598
872
  if typ == "free_2d":
599
873
  return "free_2d"
@@ -608,23 +882,28 @@ class System(_Base):
608
882
 
609
883
  @staticmethod
610
884
  def joint_type_is_free_or_cor(typ: str) -> bool:
885
+ 'Checks if a joint type is either "free" or "cor".'
611
886
  return System.joint_type_simplification(typ) in ["free", "cor"]
612
887
 
613
888
  @staticmethod
614
889
  def joint_type_is_spherical(typ: str) -> bool:
890
+ 'Checks if a joint type is "spherical".'
615
891
  return System.joint_type_simplification(typ) == "spherical"
616
892
 
617
893
  @staticmethod
618
894
  def joint_type_is_free_or_cor_or_spherical(typ: str) -> bool:
895
+ 'Checks if a joint type is "free", "cor", or "spherical".'
619
896
  return System.joint_type_is_free_or_cor(typ) or System.joint_type_is_spherical(
620
897
  typ
621
898
  )
622
899
 
623
900
  def findall_imus(self, names: bool = True) -> list[str] | list[int]:
901
+ "Finds all IMU sensors in the system."
624
902
  bodies = [name for name in self.link_names if name[:3] == "imu"]
625
903
  return bodies if names else [self.name_to_idx(n) for n in bodies]
626
904
 
627
905
  def findall_segments(self, names: bool = True) -> list[str] | list[int]:
906
+ "Finds all non-IMU segments in the system."
628
907
  imus = self.findall_imus(names=True)
629
908
  bodies = [name for name in self.link_names if name not in imus]
630
909
  return bodies if names else [self.name_to_idx(n) for n in bodies]
@@ -633,10 +912,12 @@ class System(_Base):
633
912
  return [self.idx_to_name(i) for i in bodies]
634
913
 
635
914
  def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
915
+ "Returns all bodies directly connected to the world."
636
916
  bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
637
917
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
638
918
 
639
919
  def find_body_to_world(self, name: bool = False) -> int | str:
920
+ "Returns the root body connected to the world."
640
921
  bodies = self.findall_bodies_to_world(names=name)
641
922
  assert len(bodies) == 1
642
923
  return bodies[0]
@@ -644,6 +925,7 @@ class System(_Base):
644
925
  def findall_bodies_with_jointtype(
645
926
  self, typ: str, names: bool = False
646
927
  ) -> list[int] | list[str]:
928
+ "Returns all bodies with the specified joint type."
647
929
  bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
648
930
  return self._bodies_indices_to_bodies_name(bodies) if names else bodies
649
931
 
@@ -781,20 +1063,25 @@ class System(_Base):
781
1063
 
782
1064
  @staticmethod
783
1065
  def from_xml(path: str, seed: int = 1):
1066
+ "Loads a system from an XML file."
784
1067
  return ring.io.load_sys_from_xml(path, seed)
785
1068
 
786
1069
  @staticmethod
787
1070
  def from_str(xml: str, seed: int = 1):
1071
+ "Loads a system from an XML string."
788
1072
  return ring.io.load_sys_from_str(xml, seed)
789
1073
 
790
1074
  def to_str(self, warn: bool = True) -> str:
1075
+ "Serializes the system to an XML string."
791
1076
  return ring.io.save_sys_to_str(self, warn=warn)
792
1077
 
793
1078
  def to_xml(self, path: str) -> None:
1079
+ "Saves the system to an XML file."
794
1080
  ring.io.save_sys_to_xml(self, path)
795
1081
 
796
1082
  @classmethod
797
1083
  def create(cls, path_or_str: str, seed: int = 1) -> "System":
1084
+ "Creates a `System` instance from an XML file or string."
798
1085
  path = Path(path_or_str).with_suffix(".xml")
799
1086
 
800
1087
  exists = False
@@ -807,14 +1094,15 @@ class System(_Base):
807
1094
  if exists:
808
1095
  return cls.from_xml(path, seed=seed)
809
1096
  else:
810
- return cls.from_str(path_or_str)
1097
+ return cls.from_str(path_or_str, seed=seed)
811
1098
 
812
1099
  def coordinate_vector_to_q(
813
1100
  self,
814
1101
  q: jax.Array,
815
1102
  custom_joints: dict[str, Callable] = {},
816
1103
  ) -> jax.Array:
817
- """Map a coordinate vector `q` to the minimal coordinates vector of the sys"""
1104
+ """Converts a coordinate vector to minimal coordinates (`q`), applying
1105
+ constraints such as quaternion normalization."""
818
1106
  # Does, e.g.
819
1107
  # - normalize quaternions
820
1108
  # - hinge joints in [-pi, pi]
@@ -1026,14 +1314,37 @@ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = Fa
1026
1314
 
1027
1315
  @struct.dataclass
1028
1316
  class State(_Base):
1029
- """The static and dynamic state of a system in minimal and maximal coordinates.
1030
- Use `.create()` to create this object.
1031
-
1032
- Args:
1033
- q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
1034
- qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
1035
- x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
1036
1317
  """
1318
+ Represents the state of a dynamic system in minimal and maximal coordinates.
1319
+
1320
+ The `State` class encapsulates both the configuration (`q`) and velocity (`qd`)
1321
+ of the system in minimal coordinates, as well as the corresponding transforms (`x`)
1322
+ in maximal coordinates.
1323
+
1324
+ Attributes:
1325
+ q (jax.Array):
1326
+ The joint positions (generalized coordinates) of the system. The size
1327
+ of `q` matches `sys.q_size()`.
1328
+ qd (jax.Array):
1329
+ The joint velocities (generalized velocities) of the system. The size
1330
+ of `qd` matches `sys.qd_size()`.
1331
+ x (Transform):
1332
+ The maximal coordinate representation of all system links, expressed as
1333
+ a `Transform` object.
1334
+
1335
+ Methods:
1336
+ create(sys: System, q: Optional[jax.Array] = None,
1337
+ qd: Optional[jax.Array] = None, x: Optional[Transform] = None,
1338
+ key: Optional[jax.Array] = None,
1339
+ custom_joints: dict[str, Callable] = {}) -> State:
1340
+ Creates a `State` instance for a given system with optional initial conditions.
1341
+
1342
+ Usage:
1343
+ >>> sys = System.create("model.xml")
1344
+ >>> state = State.create(sys)
1345
+ >>> print(state.q.shape) # Should match sys.q_size()
1346
+ >>> print(state.qd.shape) # Should match sys.qd_size()
1347
+ """ # noqa: E501
1037
1348
 
1038
1349
  q: jax.Array
1039
1350
  qd: jax.Array
@@ -1048,19 +1359,37 @@ class State(_Base):
1048
1359
  x: Optional[Transform] = None,
1049
1360
  key: Optional[jax.Array] = None,
1050
1361
  custom_joints: dict[str, Callable] = {},
1051
- ):
1052
- """Create state of system.
1362
+ ) -> "State":
1363
+ """
1364
+ Creates a `State` instance for the given system with optional initial conditions.
1365
+
1366
+ If no initial values are provided, joint positions (`q`) and velocities (`qd`)
1367
+ are initialized to zero, except for free and spherical joints, which have unit quaternions.
1053
1368
 
1054
1369
  Args:
1055
- sys (System): The system for which to create a state.
1056
- q (jax.Array, optional): The joint values of the system. Defaults to None.
1057
- Which then defaults to zeros.
1058
- qd (jax.Array, optional): The joint velocities of the system.
1059
- Defaults to None. Which then defaults to zeros.
1370
+ sys (System):
1371
+ The system for which to create a state.
1372
+ q (Optional[jax.Array], default=None):
1373
+ Initial joint positions. If `None`, defaults to zeros, with unit quaternion initialization
1374
+ for free and spherical joints.
1375
+ qd (Optional[jax.Array], default=None):
1376
+ Initial joint velocities. If `None`, defaults to zeros.
1377
+ x (Optional[Transform], default=None):
1378
+ Initial maximal coordinates of the system links. If `None`, defaults to zero transforms.
1379
+ key (Optional[jax.Array], default=None):
1380
+ Random key for initializing `q` if no values are provided.
1381
+ custom_joints (dict[str, Callable], default={}):
1382
+ Custom joint functions for mapping coordinate vectors to minimal coordinates.
1060
1383
 
1061
1384
  Returns:
1062
- (State): Create State object.
1063
- """
1385
+ State: A new instance of the `State` class representing the initialized system state.
1386
+
1387
+ Example:
1388
+ >>> sys = System.create("model.xml")
1389
+ >>> state = State.create(sys)
1390
+ >>> print(state.q.shape) # Should match sys.q_size()
1391
+ >>> print(state.qd.shape) # Should match sys.qd_size()
1392
+ """ # noqa: E501
1064
1393
  if key is not None:
1065
1394
  assert q is None
1066
1395
  q = jax.random.normal(key, shape=(sys.q_size(),))
ring/io/xml/from_xml.py CHANGED
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
252
252
 
253
253
  # numpy -> jax
254
254
  # we load using numpy in order to have float64 precision
255
- sys = jax.tree_map(jax.numpy.asarray, sys)
255
+ sys = jax.tree.map(jax.numpy.asarray, sys)
256
256
 
257
257
  sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
258
258
 
ring/ml/base.py CHANGED
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
13
13
  def _to_3d(tree):
14
14
  if tree is None:
15
15
  return None
16
- return jax.tree_map(lambda arr: arr[None], tree)
16
+ return jax.tree.map(lambda arr: arr[None], tree)
17
17
 
18
18
 
19
19
  def _to_2d(tree, i: int = 0):
20
20
  if tree is None:
21
21
  return None
22
- return jax.tree_map(lambda arr: arr[i], tree)
22
+ return jax.tree.map(lambda arr: arr[i], tree)
23
23
 
24
24
 
25
25
  class AbstractFilter(ABC):
@@ -297,7 +297,8 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
297
297
 
298
298
  if self._quat_normalize:
299
299
  assert yhat.shape[-1] == 4, f"yhat.shape={yhat.shape}"
300
- yhat = ring.maths.safe_normalize(yhat)
300
+ # yhat = ring.maths.safe_normalize(yhat)
301
+ yhat = yhat / jnp.linalg.norm(yhat, axis=-1, keepdims=True)
301
302
 
302
303
  return yhat, state
303
304
 
ring/ml/ml_utils.py CHANGED
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
161
161
  metrices: NestedDict, filter_nan_inf: bool = True
162
162
  ):
163
163
  metrices = _flatten_dict(metrices)
164
- metrices = jax.tree_map(_to_float_if_not_string, metrices)
164
+ metrices = jax.tree.map(_to_float_if_not_string, metrices)
165
165
 
166
166
  if not filter_nan_inf:
167
167
  return metrices
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
216
216
  from jax.experimental import jax2tf
217
217
  import tensorflow as tf
218
218
 
219
- signature = jax.tree_map(
219
+ signature = jax.tree.map(
220
220
  lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
221
221
  )
222
222
 
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
241
241
  if validate:
242
242
  output_jax = jax_func(*input)
243
243
  output_tf = tf.saved_model.load(path)(*input)
244
- jax.tree_map(
244
+ jax.tree.map(
245
245
  lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
246
246
  output_jax,
247
247
  output_tf,
ring/ml/ringnet.py CHANGED
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
248
248
  params, state = self.forward_lam_factory(lam=lam).init(key, X)
249
249
 
250
250
  if bs is not None:
251
- state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
251
+ state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
252
252
 
253
253
  return params, state
254
254
 
ring/ml/train.py CHANGED
@@ -50,7 +50,7 @@ def _build_step_fn(
50
50
  # this vmap maps along batch-axis, not time-axis
51
51
  # time-axis is handled by `metric_fn`
52
52
  pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
53
- error_tree = jax.tree_map(pipe, y, yhat)
53
+ error_tree = jax.tree.map(pipe, y, yhat)
54
54
  return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
55
55
 
56
56
  @partial(
@@ -274,7 +274,7 @@ def _build_eval_fn(
274
274
  ), f"The metric identitifier {metric_name} is not unique"
275
275
 
276
276
  pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
277
- values.update({metric_name: jax.tree_map(pipe, y, yhat)})
277
+ values.update({metric_name: jax.tree.map(pipe, y, yhat)})
278
278
 
279
279
  return values
280
280