imt-ring 1.6.37__py3-none-any.whl → 1.6.39__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: imt-ring
3
- Version: 1.6.37
3
+ Version: 1.6.39
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,25 +1,25 @@
1
1
  ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=_TgFrggsZfam0VPxvD4J5xp977vgiLnKTlDIJVzik5M,35362
3
+ ring/base.py,sha256=zromjIuMpNBoyiwHa9OCyZvAz7jHjXHZIdRt8fN8PoA,50481
4
4
  ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
8
- ring/algorithms/dynamics.py,sha256=dpe-F3Yq4sY2dY6DQW3v7TnPLRdxdkePtdbGPQIrijg,10997
9
- ring/algorithms/jcalc.py,sha256=QafnCKa1mjEl7nL_KuadPJB5ebW31NKnkdcKn2YtSsM,36171
10
- ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
8
+ ring/algorithms/dynamics.py,sha256=NFOZawjrFoS5RgiWOpG1pQCC8l7RBOEZIi9ok6gvf9U,12268
9
+ ring/algorithms/jcalc.py,sha256=l6BXOmXwrZ_AKKRm4gEHq_k2LSUQ4wd--1qL1qNTcKk,46794
10
+ ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
11
+ ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
- ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
+ ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=jGQocoNZ5tkiMazBDCv-jD6FNYwebqn0_RgVFse49pg,16890
19
- ring/algorithms/generator/batch.py,sha256=P51UnAZl9TUF_eVq58VL1CsmPPStPHhRDdKjUyvu4EA,2652
20
- ring/algorithms/generator/finalize_fns.py,sha256=nY2RKiLbHriTkdec94lc4UGSZKd0v547MDNn4dr8I3E,10398
18
+ ring/algorithms/generator/base.py,sha256=klWYt6TlMluLu0ihGzmmPXBm47DOTpjXJylZVNXHVEk,22419
19
+ ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
20
+ ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
22
- ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
+ ring/algorithms/generator/pd_control.py,sha256=dHnhJZx_FqrHD4xFXpQZH-R7rputFkAVGwoBGccZnz4,5767
23
23
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
24
24
  ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
25
25
  ring/io/__init__.py,sha256=1gEJdyDCbldbbm8QeZbLmhzSKmaQ-UqTmQgu4DBH2Z4,328
@@ -47,46 +47,46 @@ ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQ
47
47
  ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
48
48
  ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
49
49
  ring/io/xml/abstract.py,sha256=8Q2ebnUYLmuS9HJAQwDVrDTrRfD5z4G5RAB7MW8Oa60,9742
50
- ring/io/xml/from_xml.py,sha256=E7JQl_scL5U4LK6mqLMr5qaiZCc6J1fInxD7uwgNCJY,9356
50
+ ring/io/xml/from_xml.py,sha256=CR3OaBxoDuHK8k5N79XziHwS90lCaaw49UGzQirWiIw,9356
51
51
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
52
52
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
53
53
  ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
- ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
+ ring/ml/base.py,sha256=phBUfTpP1Mqt8lvtilSavT7ypkaeaF1oh7nCMjV0dqg,10046
56
56
  ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
57
- ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
57
+ ring/ml/ml_utils.py,sha256=hu189AnHcmkhkpEPZZ19O0gWz3T-YKpWQW9buqDTMow,10915
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
- ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
59
+ ring/ml/ringnet.py,sha256=uybMtLdQV0oTFldlbffrqku_l79y3coJYHwxX8P9QZQ,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
61
- ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
61
+ ring/ml/train.py,sha256=_CtQM3w9L01V5yn23lz0aaIPJN5sOWlL9e7G_9__11c,10989
62
62
  ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
63
63
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
64
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
65
65
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
66
66
  ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
67
- ring/rendering/mujoco_render.py,sha256=bSj1_7YL8wZV6cp9oD2CvbkZRSuxVhmPBA2JECxrnUE,8426
68
- ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
67
+ ring/rendering/mujoco_render.py,sha256=HMvZc04I0-lXPBL3hcnBzV2bNiXQAQM7QcHlG_Obmj4,8757
68
+ ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
69
69
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
70
70
  ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
71
71
  ring/sim2real/sim2real.py,sha256=B4nqBBnjGXhM-7PfTyxEq44ZidGNghqaq--qdFILX5A,9675
72
72
  ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E,193
73
73
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
74
- ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
74
+ ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
75
75
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
76
76
  ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
77
77
  ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
- ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
78
+ ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
- ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
- ring/utils/dataloader_torch.py,sha256=bravdBqbkxxcQDieg6OnmArGwGcWpMI3MmNFwTCt0qg,3808
82
- ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
83
- ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
80
+ ring/utils/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
81
+ ring/utils/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
82
+ ring/utils/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
83
+ ring/utils/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
84
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
85
85
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
86
86
  ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.37.dist-info/METADATA,sha256=NN6c8jI6u0PjCtw3ZXs9Ktk6ODPI9igC-qyS1aUIUsI,4251
90
- imt_ring-1.6.37.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
91
- imt_ring-1.6.37.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.37.dist-info/RECORD,,
89
+ imt_ring-1.6.39.dist-info/METADATA,sha256=v0rBTnCQP-SWJU153byfX31HUCcrWHhCg_EmecLbLf4,4251
90
+ imt_ring-1.6.39.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
91
+ imt_ring-1.6.39.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.39.dist-info/RECORD,,
@@ -184,7 +184,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
184
184
 
185
185
  suntay_link_name = _utils_find_suntay_joint(sys)
186
186
 
187
- params = jax.tree_map(
187
+ params = jax.tree.map(
188
188
  lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
189
189
  sys.links.joint_params[name],
190
190
  )
@@ -303,7 +303,33 @@ def step(
303
303
  taus: Optional[jax.Array] = None,
304
304
  n_substeps: int = 1,
305
305
  ) -> base.State:
306
- "Steps the dynamics. Returns the state of next timestep."
306
+ """
307
+ Advances the system dynamics by a single timestep using semi-implicit Euler integration.
308
+
309
+ This function updates the system's state by integrating the equations of motion
310
+ over a timestep, potentially with multiple substeps for improved numerical stability.
311
+ The method ensures that the system's kinematics are updated before each integration step.
312
+
313
+ Args:
314
+ sys (base.System):
315
+ The system to simulate, containing link information, joint dynamics, and integration parameters.
316
+ state (base.State):
317
+ The current state of the system, including joint positions (`q`), velocities (`qd`), and transforms (`x`).
318
+ taus (Optional[jax.Array], optional):
319
+ The control torques applied to the system joints. If `None`, zero torques are applied.
320
+ Defaults to `None`.
321
+ n_substeps (int, optional):
322
+ The number of integration substeps per timestep to improve numerical accuracy.
323
+ Defaults to `1`.
324
+
325
+ Returns:
326
+ base.State:
327
+ The updated state of the system after integration.
328
+
329
+ Raises:
330
+ AssertionError: If the system's degrees of freedom (`q` and `qd`) do not match expectations.
331
+ AssertionError: If an unsupported integration method is specified in `sys.integration_method`.
332
+ """ # noqa: E501
307
333
  assert sys.q_size() == state.q.size
308
334
  if taus is None:
309
335
  taus = jnp.zeros_like(state.qd)
@@ -53,7 +53,85 @@ class RCMG:
53
53
  cor: bool = False,
54
54
  disable_tqdm: bool = False,
55
55
  ) -> None:
56
- "Random Chain Motion Generator"
56
+ """
57
+ Initializes the Random Chain Motion Generator (RCMG).
58
+
59
+ The RCMG generates synthetic joint motion sequences for kinematic and dynamic
60
+ systems based on predefined motion configurations. It allows for system
61
+ randomization, augmentation with IMU and joint axis data, and optional
62
+ dynamic simulation.
63
+
64
+ Args:
65
+ sys (base.System | list[base.System]):
66
+ The system(s) for which motion should be generated.
67
+ config (jcalc.MotionConfig | list[jcalc.MotionConfig], optional):
68
+ Motion configuration(s) defining velocity limits, interpolation methods,
69
+ and range constraints. Defaults to `jcalc.MotionConfig()`.
70
+ setup_fn (Optional[types.SETUP_FN], optional):
71
+ A function to modify the system before motion generation. Defaults to `None`.
72
+ finalize_fn (Optional[types.FINALIZE_FN], optional):
73
+ A function to modify outputs after motion generation. Defaults to `None`.
74
+ add_X_imus (bool, optional):
75
+ Whether to add IMU sensor data to the output. Defaults to `False`.
76
+ add_X_imus_kwargs (dict, optional):
77
+ Additional keyword arguments for IMU data processing. Defaults to `{}`.
78
+ add_X_jointaxes (bool, optional):
79
+ Whether to add joint axis data to the output. Defaults to `False`.
80
+ add_X_jointaxes_kwargs (dict, optional):
81
+ Additional keyword arguments for joint axis data processing. Defaults to `{}`.
82
+ add_y_relpose (bool, optional):
83
+ Whether to add relative pose targets to the output. Defaults to `False`.
84
+ add_y_rootincl (bool, optional):
85
+ Whether to add root inclination targets to the output. Defaults to `False`.
86
+ add_y_rootincl_kwargs (dict, optional):
87
+ Additional keyword arguments for root inclination processing. Defaults to `{}`.
88
+ add_y_rootfull (bool, optional):
89
+ Whether to add full root state targets to the output. Defaults to `False`.
90
+ add_y_rootfull_kwargs (dict, optional):
91
+ Additional keyword arguments for full root state processing. Defaults to `{}`.
92
+ sys_ml (Optional[base.System], optional):
93
+ System that defines the graph and naming structure of the `X` and `y` outputs. Defaults to `None` which then uses the first provided system.
94
+ randomize_positions (bool, optional):
95
+ Whether to randomize positions based on `pos_min` and `pos_max`. Defaults to `False`.
96
+ randomize_motion_artifacts (bool, optional):
97
+ Whether to randomize the IMU motion artifact simulation. This randomises the spring stiffness and spring damping parameters of the passive free joint that is added between nonrigid and rigid IMU. Defaults to `False`.
98
+ randomize_joint_params (bool, optional):
99
+ Whether to randomize joint parameters by calling `JointModel.init_joint_params` before every sequence generation. Defaults to `False`.
100
+ randomize_hz (bool, optional):
101
+ Whether to randomize the sampling frequency of the generated data. Defaults to `False`.
102
+ randomize_hz_kwargs (dict, optional):
103
+ Additional keyword arguments for sampling frequency randomization. Defaults to `{}`.
104
+ imu_motion_artifacts (bool, optional):
105
+ Whether to simulate nonrigid IMU motion artifacts. Defaults to `False`.
106
+ imu_motion_artifacts_kwargs (dict, optional):
107
+ Additional keyword arguments for IMU motion artifact simulation. Defaults to `{}`.
108
+ dynamic_simulation (bool, optional):
109
+ Whether to use a physics-based simulation to generate motion instead of purely
110
+ kinematic methods. Defaults to `False`.
111
+ dynamic_simulation_kwargs (dict, optional):
112
+ Additional keyword arguments for dynamic simulation. Defaults to `{}`.
113
+ output_transform (Optional[Callable], optional):
114
+ A function to transform the generated output data. Defaults to `None`.
115
+ keep_output_extras (bool, optional):
116
+ Whether to keep additional output metadata. Defaults to `False`.
117
+ use_link_number_in_Xy (bool, optional):
118
+ Whether to replace joint names with numerical indices in the output. Defaults to `False`.
119
+ cor (bool, optional):
120
+ Whether to replace free joints with center-of-rotation (COR) 9D free joint. Defaults to `False`.
121
+ disable_tqdm (bool, optional):
122
+ Whether to disable progress bars during generation. Defaults to `False`.
123
+
124
+ Raises:
125
+ AssertionError: If any of the provided `MotionConfig` instances are infeasible.
126
+
127
+ Notes:
128
+ - This class supports batch generation, lazy and eager data loading, and
129
+ motion augmentation.
130
+ - If `randomize_hz=True`, the time step (`dt`) varies according to the specified
131
+ sampling rates.
132
+ - When `cor=True`, free joints are replaced with center-of-rotation models,
133
+ affecting joint motion behavior.
134
+ """ # noqa: E501
57
135
 
58
136
  # add some default values
59
137
  randomize_hz_kwargs_defaults = dict(add_dt=True)
@@ -139,6 +217,7 @@ class RCMG:
139
217
  def to_lazy_gen(
140
218
  self, sizes: int | list[int] = 1, jit: bool = True
141
219
  ) -> types.BatchedGenerator:
220
+ "Returns a generator `X, y = gen(key)` that lazily generates batched sequences."
142
221
  return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
143
222
 
144
223
  @staticmethod
@@ -201,7 +280,7 @@ class RCMG:
201
280
  ),
202
281
  verbose: bool = True,
203
282
  ):
204
-
283
+ "Stores unbatched sequences as numpy arrays into folder."
205
284
  i = 0
206
285
 
207
286
  def callback(data: list[PyTree[np.ndarray]]) -> None:
@@ -237,6 +316,7 @@ class RCMG:
237
316
  shuffle: bool = True,
238
317
  transform=None,
239
318
  ) -> types.BatchedGenerator:
319
+ "Returns a generator `X, y = gen(key)` that returns precomputed batched sequences." # noqa: E501
240
320
  data = self.to_list(sizes, seed)
241
321
  assert len(data) >= batchsize
242
322
  return self.eager_gen_from_list(data, batchsize, shuffle, transform)
@@ -80,11 +80,11 @@ def generators_eager(
80
80
  # converts also to numpy; but with np.array.flags.writeable = False
81
81
  sample = jax.device_get(sample)
82
82
  # this then sets this flag to True
83
- sample = jax.tree_map(np.array, sample)
83
+ sample = jax.tree.map(np.array, sample)
84
84
 
85
85
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
86
86
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
87
- callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
87
+ callback([jax.tree.map(lambda a: a[i].copy(), sample) for i in range(size)])
88
88
 
89
89
  # cleanup
90
90
  del sample, sample_flat
@@ -311,7 +311,7 @@ def _expand_then_flatten(Xy):
311
311
 
312
312
  X, y = _flatten(X), _flatten(y)
313
313
  if not batched:
314
- X, y = jax.tree_map(lambda arr: arr[0], (X, y))
314
+ X, y = jax.tree.map(lambda arr: arr[0], (X, y))
315
315
  return X, y
316
316
 
317
317
 
@@ -86,7 +86,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
86
86
  controller_state: PDControllerState, sys: base.System, state: base.State
87
87
  ) -> jax.Array:
88
88
  taus = jnp.zeros((sys.qd_size()))
89
- q_ref, qd_ref = jax.tree_map(
89
+ q_ref, qd_ref = jax.tree.map(
90
90
  lambda arr: jax.lax.dynamic_index_in_dim(
91
91
  arr, controller_state.i, keepdims=False
92
92
  ),
ring/algorithms/jcalc.py CHANGED
@@ -19,6 +19,87 @@ from ring.algorithms._random import TimeDependentFloat
19
19
 
20
20
  @dataclass
21
21
  class MotionConfig:
22
+ """
23
+ Configuration for joint motion generation in kinematic and dynamic simulations.
24
+
25
+ This class defines the constraints and parameters for generating random joint motions,
26
+ including angular and positional velocity limits, interpolation methods, and range
27
+ restrictions for various joint types.
28
+
29
+ Attributes:
30
+ T (float): Total duration of the motion sequence (in seconds).
31
+ t_min (float): Minimum time interval between two generated joint states.
32
+ t_max (float | TimeDependentFloat): Maximum time interval between two generated joint states.
33
+
34
+ dang_min (float | TimeDependentFloat): Minimum angular velocity (rad/s).
35
+ dang_max (float | TimeDependentFloat): Maximum angular velocity (rad/s).
36
+ dang_min_free_spherical (float | TimeDependentFloat): Minimum angular velocity for free and spherical joints.
37
+ dang_max_free_spherical (float | TimeDependentFloat): Maximum angular velocity for free and spherical joints.
38
+
39
+ delta_ang_min (float | TimeDependentFloat): Minimum allowed change in joint angle (radians).
40
+ delta_ang_max (float | TimeDependentFloat): Maximum allowed change in joint angle (radians).
41
+ delta_ang_min_free_spherical (float | TimeDependentFloat): Minimum allowed change in angle for free/spherical joints.
42
+ delta_ang_max_free_spherical (float | TimeDependentFloat): Maximum allowed change in angle for free/spherical joints.
43
+
44
+ dpos_min (float | TimeDependentFloat): Minimum translational velocity.
45
+ dpos_max (float | TimeDependentFloat): Maximum translational velocity.
46
+ pos_min (float | TimeDependentFloat): Minimum position constraint.
47
+ pos_max (float | TimeDependentFloat): Maximum position constraint.
48
+
49
+ pos_min_p3d_x (float | TimeDependentFloat): Minimum position in x-direction for P3D joints.
50
+ pos_max_p3d_x (float | TimeDependentFloat): Maximum position in x-direction for P3D joints.
51
+ pos_min_p3d_y (float | TimeDependentFloat): Minimum position in y-direction for P3D joints.
52
+ pos_max_p3d_y (float | TimeDependentFloat): Maximum position in y-direction for P3D joints.
53
+ pos_min_p3d_z (float | TimeDependentFloat): Minimum position in z-direction for P3D joints.
54
+ pos_max_p3d_z (float | TimeDependentFloat): Maximum position in z-direction for P3D joints.
55
+
56
+ cdf_bins_min (int): Minimum number of bins for cumulative distribution function (CDF)-based random sampling.
57
+ cdf_bins_max (Optional[int]): Maximum number of bins for CDF-based sampling.
58
+
59
+ randomized_interpolation_angle (bool): Whether to use randomized interpolation for angular motion.
60
+ randomized_interpolation_position (bool): Whether to use randomized interpolation for positional motion.
61
+ interpolation_method (str): Interpolation method to be used (default: "cosine").
62
+
63
+ range_of_motion_hinge (bool): Whether to enforce range-of-motion constraints on hinge joints.
64
+ range_of_motion_hinge_method (str): Method used for range-of-motion enforcement (e.g., "uniform", "sigmoid").
65
+
66
+ rom_halfsize (float | TimeDependentFloat): Half-size of the range of motion restriction.
67
+
68
+ ang0_min (float): Minimum initial joint angle.
69
+ ang0_max (float): Maximum initial joint angle.
70
+ pos0_min (float): Minimum initial joint position.
71
+ pos0_max (float): Maximum initial joint position.
72
+
73
+ cor_t_min (float): Minimum time step for center-of-rotation (COR) joints.
74
+ cor_t_max (float | TimeDependentFloat): Maximum time step for COR joints.
75
+ cor_dpos_min (float | TimeDependentFloat): Minimum velocity for COR translation.
76
+ cor_dpos_max (float | TimeDependentFloat): Maximum velocity for COR translation.
77
+ cor_pos_min (float | TimeDependentFloat): Minimum position for COR translation.
78
+ cor_pos_max (float | TimeDependentFloat): Maximum position for COR translation.
79
+ cor_pos0_min (float): Initial minimum position for COR translation.
80
+ cor_pos0_max (float): Initial maximum position for COR translation.
81
+
82
+ joint_type_specific_overwrites (dict[str, dict[str, Any]]):
83
+ A dictionary mapping joint types to specific motion configuration overrides.
84
+
85
+ Methods:
86
+ is_feasible:
87
+ Checks if the motion configuration satisfies all constraints.
88
+
89
+ to_nomotion_config:
90
+ Returns a new `MotionConfig` where all velocities and angle changes are set to zero.
91
+
92
+ overwrite_for_joint_type:
93
+ Applies specific configuration changes for a given joint type.
94
+ Note: These changes affect all instances of `MotionConfig` for this joint type.
95
+
96
+ overwrite_for_subsystem:
97
+ Modifies the motion configuration for all joints in a subsystem rooted at `link_name`.
98
+
99
+ from_register:
100
+ Retrieves a predefined `MotionConfig` from the global registry.
101
+ """ # noqa: E501
102
+
22
103
  T: float = 60.0 # length of random motion
23
104
  t_min: float = 0.05 # min time between two generated angles
24
105
  t_max: float | TimeDependentFloat = 0.30 # max time ..
@@ -412,6 +493,30 @@ def _find_interval(t: jax.Array, boundaries: jax.Array):
412
493
  def join_motionconfigs(
413
494
  configs: list[MotionConfig], boundaries: list[float]
414
495
  ) -> MotionConfig:
496
+ """
497
+ Joins multiple `MotionConfig` objects in time, transitioning between them at specified boundaries.
498
+
499
+ This function takes a list of `MotionConfig` instances and a corresponding list of boundary times,
500
+ and constructs a new `MotionConfig` that varies in time according to the provided segments.
501
+
502
+ Args:
503
+ configs (list[MotionConfig]): A list of `MotionConfig` objects to be joined.
504
+ boundaries (list[float]): A list of time values where transitions between `configs` occur.
505
+ Must have one element less than `configs`, as each boundary defines the transition point
506
+ between two consecutive configurations.
507
+
508
+ Returns:
509
+ MotionConfig: A new `MotionConfig` object where time-dependent fields transition based on the
510
+ specified boundaries.
511
+
512
+ Raises:
513
+ AssertionError: If the number of boundaries does not match `len(configs) - 1`.
514
+ AssertionError: If time-independent fields have differing values across `configs`.
515
+
516
+ Notes:
517
+ - Only fields that are time-dependent (`float | TimeDependentFloat`) will change over time.
518
+ - Time-independent fields must be the same in all `configs`, or an error is raised.
519
+ """ # noqa: E501
415
520
  # to avoid a circular import due to `ring.utils.randomize_sys` importing `jcalc`
416
521
  from ring.utils import tree_equal
417
522
 
@@ -517,6 +622,55 @@ INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
517
622
 
518
623
  @dataclass
519
624
  class JointModel:
625
+ """
626
+ Represents the kinematic and dynamic properties of a joint type.
627
+
628
+ A `JointModel` defines the mathematical functions required to compute joint
629
+ transformations, motion, control terms, and inverse kinematics. It is used to
630
+ describe the behavior of various joint types, including revolute, prismatic,
631
+ spherical, and free joints.
632
+
633
+ Attributes:
634
+ transform (Callable[[jax.Array, jax.Array], base.Transform]):
635
+ Computes the transformation (position and orientation) of the joint
636
+ given the joint state `q` and joint parameters.
637
+
638
+ motion (list[base.Motion | Callable[[jax.Array], base.Motion]]):
639
+ Defines the joint motion model. It can be a list of `Motion` objects
640
+ or callables that return `Motion` based on joint parameters.
641
+
642
+ rcmg_draw_fn (Optional[DRAW_FN]):
643
+ Function used to generate a reference motion trajectory for the joint
644
+ using the Random Chain Motion Generator (RCMG).
645
+
646
+ p_control_term (Optional[P_CONTROL_TERM]):
647
+ Function that computes the proportional control term for the joint.
648
+
649
+ qd_from_q (Optional[QD_FROM_Q]):
650
+ Function to compute joint velocity (`qd`) from joint positions (`q`).
651
+
652
+ coordinate_vector_to_q (Optional[COORDINATE_VECTOR_TO_Q]):
653
+ Function that maps a coordinate vector to a valid joint state `q`,
654
+ ensuring constraints (e.g., wrapping angles or normalizing quaternions).
655
+
656
+ inv_kin (Optional[INV_KIN]):
657
+ Function that computes the inverse kinematics for the joint, mapping
658
+ a desired transform to joint coordinates `q`.
659
+
660
+ init_joint_params (Optional[INIT_JOINT_PARAMS]):
661
+ Function that initializes joint-specific parameters.
662
+
663
+ utilities (Optional[dict[str, Any]]):
664
+ Additional utility functions or metadata related to the joint model.
665
+
666
+ Notes:
667
+ - The `transform` function is essential for computing the joint's spatial
668
+ transformation based on its generalized coordinates.
669
+ - The `motion` attribute describes how forces and torques affect the joint.
670
+ - The `rcmg_draw_fn` is used for RCMG motion generation.
671
+ - The `coordinate_vector_to_q` is critical for maintaining valid joint states.
672
+ """ # noqa: E501
673
+
520
674
  # (q, params) -> Transform
521
675
  transform: Callable[[jax.Array, jax.Array], base.Transform]
522
676
  # len(motion) == len(qd)
@@ -1079,6 +1233,50 @@ def register_new_joint_type(
1079
1233
  qd_width: Optional[int] = None,
1080
1234
  overwrite: bool = False,
1081
1235
  ):
1236
+ """
1237
+ Registers a new joint type with its corresponding `JointModel` and kinematic properties.
1238
+
1239
+ This function allows the addition of custom joint types to the system by associating
1240
+ them with a `JointModel`, specifying their state and velocity dimensions, and optionally
1241
+ overwriting existing joint definitions.
1242
+
1243
+ Args:
1244
+ joint_type (str):
1245
+ Name of the new joint type to register.
1246
+ joint_model (JointModel):
1247
+ The `JointModel` instance defining the kinematic and dynamic properties of the joint.
1248
+ q_width (int):
1249
+ Number of generalized coordinates (degrees of freedom) required to represent the joint.
1250
+ qd_width (Optional[int], default=None):
1251
+ Number of velocity coordinates associated with the joint. Defaults to `q_width`.
1252
+ overwrite (bool, default=False):
1253
+ If `True`, allows overwriting an existing joint type. Otherwise, raises an error if
1254
+ the joint type already exists.
1255
+
1256
+ Raises:
1257
+ AssertionError:
1258
+ - If `joint_type` is `"default"` (reserved name).
1259
+ - If `joint_type` already exists and `overwrite=False`.
1260
+ - If `qd_width` is not provided and does not default to `q_width`.
1261
+ - If `joint_model.motion` length does not match `qd_width`.
1262
+
1263
+ Notes:
1264
+ - The function updates global dictionaries that store joint properties, including:
1265
+ - `_joint_types`: Maps joint type names to `JointModel` instances.
1266
+ - `base.Q_WIDTHS`: Stores the number of state coordinates for each joint type.
1267
+ - `base.QD_WIDTHS`: Stores the number of velocity coordinates for each joint type.
1268
+ - If `overwrite=True`, existing entries are removed before adding the new joint type.
1269
+ - Ensures consistency between motion definitions and velocity coordinate dimensions.
1270
+
1271
+ Example:
1272
+ ```python
1273
+ new_joint = JointModel(
1274
+ transform=my_transform_fn,
1275
+ motion=[base.Motion.create(ang=jnp.array([1, 0, 0]))],
1276
+ )
1277
+ register_new_joint_type("custom_hinge", new_joint, q_width=1)
1278
+ ```
1279
+ """ # noqa: E501
1082
1280
  # this name is used
1083
1281
  assert joint_type != "default", "Please use another name."
1084
1282
 
@@ -4,6 +4,7 @@ import jax
4
4
  import jax.numpy as jnp
5
5
  import jaxopt
6
6
  from jaxopt._src.base import Solver
7
+
7
8
  from ring import algebra
8
9
  from ring import base
9
10
  from ring import maths
@@ -171,7 +172,7 @@ def inverse_kinematics_endeffector(
171
172
 
172
173
  # find result of best q0 initial value
173
174
  best_q_index = jnp.argmin(values)
174
- best_q, best_q_value = jax.tree_map(
175
+ best_q, best_q_value = jax.tree.map(
175
176
  lambda arr: jax.lax.dynamic_index_in_dim(
176
177
  arr, best_q_index, keepdims=False
177
178
  ),
@@ -244,7 +244,7 @@ def imu(
244
244
  measurements["mag"] = magnetometer(xs.rot, magvec)
245
245
 
246
246
  if smoothen_degree is not None:
247
- measurements = jax.tree_map(
247
+ measurements = jax.tree.map(
248
248
  lambda arr: _moving_average(arr, smoothen_degree),
249
249
  measurements,
250
250
  )
@@ -257,7 +257,7 @@ def imu(
257
257
  delay = half_window
258
258
 
259
259
  if delay is not None and delay > 0:
260
- measurements = jax.tree_map(
260
+ measurements = jax.tree.map(
261
261
  lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
262
262
  )
263
263
 
@@ -473,7 +473,7 @@ def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
473
473
  X[name] = {"joint_axes": joint_axes}
474
474
 
475
475
  sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
476
- X = jax.tree_map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
476
+ X = jax.tree.map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
477
477
  return X
478
478
 
479
479
 
@@ -498,12 +498,12 @@ _quasi_physical_sys_str = r"""
498
498
  <x_xy>
499
499
  <options gravity="0 0 0"/>
500
500
  <worldbody>
501
- <body name="IMU" joint="p3d" damping="0.1 0.1 0.1" spring_stiff="3 3 3">
502
- <geom type="box" mass="0.002" dim="0.01 0.01 0.01"/>
501
+ <body name="IMU" joint="free" damping="1 1 1 10 10 10" spring_stiff="20 20 20 500 500 500">
502
+ <geom type="box" mass="1" dim="0.01 0.01 0.01"/>
503
503
  </body>
504
504
  </worldbody>
505
505
  </x_xy>
506
- """
506
+ """ # noqa: E501
507
507
 
508
508
 
509
509
  def _quasi_physical_simulation_beautiful(
@@ -512,12 +512,14 @@ def _quasi_physical_simulation_beautiful(
512
512
  sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)
513
513
 
514
514
  def step_dynamics(state: base.State, x):
515
- state = algorithms.step(sys.replace(link_spring_zeropoint=x.pos), state)
515
+ state = algorithms.step(
516
+ sys.replace(link_spring_zeropoint=jnp.concatenate((x.rot, x.pos))), state
517
+ )
516
518
  return state, state.q
517
519
 
518
- state = base.State.create(sys, q=xs.pos[0])
519
- _, pos = jax.lax.scan(step_dynamics, state, xs)
520
- return xs.replace(pos=pos)
520
+ state = base.State.create(sys, q=jnp.concatenate((xs.rot[0], xs.pos[0])))
521
+ _, qs = jax.lax.scan(step_dynamics, state, xs)
522
+ return xs.replace(rot=qs[:, :4], pos=qs[:, 4:])
521
523
 
522
524
 
523
525
  _constants = {