imt-ring 1.6.25__py3-none-any.whl → 1.6.27__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.25
3
+ Version: 1.6.27
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -47,24 +47,24 @@ ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQ
47
47
  ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
48
48
  ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
49
49
  ring/io/xml/abstract.py,sha256=8Q2ebnUYLmuS9HJAQwDVrDTrRfD5z4G5RAB7MW8Oa60,9742
50
- ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
+ ring/io/xml/from_xml.py,sha256=E7JQl_scL5U4LK6mqLMr5qaiZCc6J1fInxD7uwgNCJY,9356
51
51
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
52
52
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
53
53
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
55
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
56
56
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
57
- ring/ml/ml_utils.py,sha256=uAQ6qXFT2UxILwbKzFNPxaBeY4X56l9pixdv91MKQis,9072
57
+ ring/ml/ml_utils.py,sha256=Zm4spN0Xn-2avYu9xt3NikCLVjYM1Gh59a6XU9jGxoU,10792
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
61
- ring/ml/train.py,sha256=XuUUB0NhvByGtZDtS_weyp-TKPG9ErnKixS4NqB8q6M,10822
61
+ ring/ml/train.py,sha256=-6SzQKjIgktgRjaXKVg_1dqcBmAJggZSVwDnau1FnxI,10832
62
62
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
63
63
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
64
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
65
65
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
66
66
  ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
67
- ring/rendering/mujoco_render.py,sha256=R8qxqItakBlptbQpCzsZoVfdWYhSMwZYQzaCKbUigYU,7987
67
+ ring/rendering/mujoco_render.py,sha256=_aesWMf_KfxvG8JaXTj4SNmRvzsJrluSMz0iHTbXbLg,8256
68
68
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
69
69
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
70
70
  ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.25.dist-info/METADATA,sha256=8-77JWmLIy6E3nJVd2VqfxwoKHt9b26ruipjMKR2K8I,4089
90
- imt_ring-1.6.25.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.25.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.25.dist-info/RECORD,,
89
+ imt_ring-1.6.27.dist-info/METADATA,sha256=ekZVHth31C6ZXF_k2J_XnfDeSWCao-pF_fT9BDdfOAs,4089
90
+ imt_ring-1.6.27.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.27.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.27.dist-info/RECORD,,
ring/io/xml/from_xml.py CHANGED
@@ -2,6 +2,7 @@ from xml.etree import ElementTree
2
2
 
3
3
  import jax
4
4
  import numpy as np
5
+
5
6
  from ring import base
6
7
  from ring.algorithms import jcalc
7
8
  from ring.utils import parse_path
@@ -181,7 +182,10 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
181
182
 
182
183
  link_parents[current_link_idx] = parent
183
184
  link_types[current_link_idx] = current_link_typ
184
- link_names[current_link_idx] = body.attrib["name"]
185
+ current_name = body.attrib["name"]
186
+ link_names[current_link_idx] = (
187
+ current_name if isinstance(current_name, str) else str(int(current_name))
188
+ )
185
189
 
186
190
  transform = abstract.AbsTrans.from_xml(body.attrib)
187
191
  pos_min, pos_max = abstract.AbsPosMinMax.from_xml(body.attrib, transform.pos)
ring/ml/ml_utils.py CHANGED
@@ -251,6 +251,48 @@ def to_onnx(
251
251
  out_args_names: Optional[list[str]] = None,
252
252
  validate: bool = False,
253
253
  ):
254
+ """
255
+ Converts a JAX function to ONNX format, with optional input/output renaming and validation.
256
+
257
+ Args:
258
+ fn (callable): The JAX function to be converted.
259
+ output_path (str): Path where the ONNX model will be saved.
260
+ *args (tuple[np.ndarray]): Input arguments for the JAX function.
261
+ in_args_names (Optional[list[str]]): Names for the ONNX model's input tensors. Defaults to None.
262
+ out_args_names (Optional[list[str]]): Names for the ONNX model's output tensors. Defaults to None.
263
+ validate (bool): Whether to validate the ONNX model against the JAX function's outputs. Defaults to False.
264
+
265
+ Raises:
266
+ AssertionError: If the number of provided names does not match the number of inputs/outputs.
267
+ AssertionError: If the ONNX model's outputs do not match the JAX function's outputs within tolerance.
268
+ ValueError: If any error occurs during ONNX conversion, saving, or validation.
269
+
270
+ Notes:
271
+ - The function uses `jax2tf` to convert the JAX function to TensorFlow format,
272
+ and `tf2onnx` for ONNX conversion.
273
+ - Input and output tensor names in the ONNX model can be renamed using `sor4onnx.rename`.
274
+ - Validation compares outputs of the JAX function and the ONNX model using ONNX Runtime.
275
+
276
+ Example:
277
+ ```
278
+ import jax.numpy as jnp
279
+
280
+ def my_fn(x, y):
281
+ return x + y, x * y
282
+
283
+ x = jnp.array([1, 2, 3])
284
+ y = jnp.array([4, 5, 6])
285
+
286
+ to_onnx(
287
+ my_fn,
288
+ "model.onnx",
289
+ x, y,
290
+ in_args_names=["input1", "input2"],
291
+ out_args_names=["sum", "product"],
292
+ validate=True,
293
+ )
294
+ ```
295
+ """ # noqa: E501
254
296
  import jax.experimental.jax2tf as jax2tf
255
297
  import tensorflow as tf
256
298
  import tf2onnx
ring/ml/train.py CHANGED
@@ -45,7 +45,7 @@ def _build_step_fn(
45
45
 
46
46
  @partial(jax.value_and_grad, has_aux=True)
47
47
  def loss_fn(params, state, X, y):
48
- yhat, state = filter.apply(params=params, state=state, X=X)
48
+ yhat, state = filter.apply(params=params, state=state, X=X, y=y)
49
49
  # this vmap maps along batch-axis, not time-axis
50
50
  # time-axis is handled by `metric_fn`
51
51
  pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
@@ -261,7 +261,7 @@ def _build_eval_fn(
261
261
  """Build function that evaluates the filter performance."""
262
262
 
263
263
  def eval_fn(params, state, X, y):
264
- yhat, _ = filter.apply(params=params, state=state, X=X)
264
+ yhat, _ = filter.apply(params=params, state=state, X=X, y=y)
265
265
 
266
266
  y = _arr_to_dict(y, link_names)
267
267
  yhat = _arr_to_dict(yhat, link_names)
@@ -8,7 +8,7 @@ from ring import maths
8
8
 
9
9
  _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
10
10
  _skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
11
- _floor = """<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
11
+ _floor = """<geom name="floor" pos="0 0 -0.84" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
12
12
 
13
13
 
14
14
  def _build_model_of_geoms(
@@ -75,6 +75,8 @@ def _build_model_of_geoms(
75
75
  <asset>
76
76
  <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
77
77
  <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
78
+ <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
79
+ <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
78
80
  {_skybox if stars else ''}
79
81
  <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
80
82
  <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>