imt-ring 1.6.24__py3-none-any.whl → 1.6.25__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.24
3
+ Version: 1.6.25
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -54,7 +54,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
55
  ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
56
56
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
57
- ring/ml/ml_utils.py,sha256=xqy9BnLy8IKVqkFS9mlZsGJXSbThI9zZxZ5rhl8LSI8,7144
57
+ ring/ml/ml_utils.py,sha256=uAQ6qXFT2UxILwbKzFNPxaBeY4X56l9pixdv91MKQis,9072
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
59
  ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
@@ -86,7 +86,7 @@ ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3
86
86
  ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.24.dist-info/METADATA,sha256=vaXarRf1r5xZeGK-av_regQ2LgaCTnb0Th43bDLXgN8,4089
90
- imt_ring-1.6.24.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.24.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.24.dist-info/RECORD,,
89
+ imt_ring-1.6.25.dist-info/METADATA,sha256=8-77JWmLIy6E3nJVd2VqfxwoKHt9b26ruipjMKR2K8I,4089
90
+ imt_ring-1.6.25.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
+ imt_ring-1.6.25.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.25.dist-info/RECORD,,
ring/ml/ml_utils.py CHANGED
@@ -243,5 +243,60 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
243
243
  )
244
244
 
245
245
 
246
def to_onnx(
    fn,
    output_path,
    *args: np.ndarray,
    in_args_names: Optional[list[str]] = None,
    out_args_names: Optional[list[str]] = None,
    validate: bool = False,
):
    """Export the JAX function `fn` to an ONNX file at `output_path`.

    The function is converted with `jax2tf` (XLA disabled), wrapped in a
    `tf.function`, and exported via `tf2onnx`. The input signature is built
    from the shape and dtype of every example in `args`.

    Args:
        fn: JAX-traceable callable that accepts `len(args)` positional arrays.
        output_path: Destination of the ONNX file; the file is rewritten in
            place when inputs/outputs are renamed.
        *args: Example inputs. Used to derive the input signature and, when
            `validate=True`, as the inputs of the validation run.
        in_args_names: Optional new names for the graph inputs, in graph order.
        out_args_names: Optional new names for the graph outputs, in graph
            order.
        validate: If True, run the exported model with `onnxruntime` on `args`
            and compare every output against `fn(*args)`
            (`atol=1e-5, rtol=1e-5`).

    Raises:
        AssertionError: If the number of supplied names does not match the
            number of graph inputs/outputs, or if validation fails. Raised
            explicitly (instead of via `assert`) so that the checks are not
            stripped when Python runs with `-O`.
    """
    # heavy optional dependencies are imported lazily, only when exporting
    import jax
    import jax.experimental.jax2tf as jax2tf
    import tensorflow as tf
    import tf2onnx

    tf_fn = tf.function(jax2tf.convert(fn, enable_xla=False))
    tf_args = [tf.TensorSpec(np.shape(x), np.result_type(x)) for x in args]
    tf2onnx.convert.from_function(
        tf_fn, input_signature=tf_args, output_path=output_path
    )

    if in_args_names is not None or out_args_names is not None:
        import onnx
        from sor4onnx import rename

        model = onnx.load(output_path)

        if in_args_names is not None:
            old_names = [inp.name for inp in model.graph.input]
            if len(old_names) != len(in_args_names):
                raise AssertionError(
                    f"Expected {len(old_names)} input names, "
                    f"got {len(in_args_names)}"
                )
            # NOTE(review): presumably `rename` returns the updated graph when
            # no file paths are given -- confirm against the sor4onnx docs
            for old_name, new_name in zip(old_names, in_args_names):
                model = rename([old_name, new_name], None, model, None, mode="inputs")

        if out_args_names is not None:
            old_names = [out.name for out in model.graph.output]
            if len(old_names) != len(out_args_names):
                raise AssertionError(
                    f"Expected {len(old_names)} output names, "
                    f"got {len(out_args_names)}"
                )
            for old_name, new_name in zip(old_names, out_args_names):
                model = rename([old_name, new_name], None, model, None, mode="outputs")

        onnx.save(model, output_path)

    if validate:
        import onnxruntime as ort

        # Flatten the JAX result so that a single returned array is compared
        # as one output instead of being iterated over its leading axis.
        output_jax = jax.tree_util.tree_leaves(fn(*args))
        session = ort.InferenceSession(output_path)
        input_names = [inp.name for inp in session.get_inputs()]
        output_onnx = session.run(
            None, {name: np.array(arg) for name, arg in zip(input_names, args)}
        )

        # `zip` would silently drop surplus outputs, so check the count first
        if len(output_jax) != len(output_onnx):
            raise AssertionError(
                f"JAX function returned {len(output_jax)} outputs but the ONNX "
                f"model returned {len(output_onnx)}"
            )

        for i, (o1, o2) in enumerate(zip(output_jax, output_onnx)):
            if not np.allclose(o1, o2, atol=1e-5, rtol=1e-5):
                raise AssertionError(
                    f"Output {i} of the ONNX model deviates from the JAX output"
                )

        if out_args_names is not None:
            actual_names = [out.name for out in session.get_outputs()]
            if actual_names != list(out_args_names):
                raise AssertionError(
                    f"ONNX output names {actual_names} != {list(out_args_names)}"
                )
299
+
300
+
246
301
  def _unknown_link_names(N: int):
247
302
  return [f"link{i}" for i in range(N)]