imt-ring 1.6.36__tar.gz → 1.6.38__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. {imt_ring-1.6.36 → imt_ring-1.6.38}/PKG-INFO +2 -2
  2. {imt_ring-1.6.36 → imt_ring-1.6.38}/pyproject.toml +1 -1
  3. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/PKG-INFO +2 -2
  4. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/suntay.py +1 -1
  5. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/batch.py +2 -2
  6. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/finalize_fns.py +1 -1
  7. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/pd_control.py +1 -1
  8. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/kinematics.py +2 -1
  9. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/sensors.py +12 -10
  10. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/base.py +1 -1
  11. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/from_xml.py +1 -1
  12. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/base.py +2 -2
  13. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/ml_utils.py +3 -3
  14. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/ringnet.py +1 -1
  15. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/train.py +2 -2
  16. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/mujoco_render.py +11 -7
  17. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/vispy_render.py +5 -4
  18. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/inject_sys.py +3 -2
  19. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/batchsize.py +3 -3
  20. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/dataloader.py +4 -3
  21. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/dataloader_torch.py +14 -5
  22. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/hdf5.py +1 -1
  23. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/normalizer.py +6 -5
  24. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/utils.py +18 -2
  25. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_ml_utils.py +1 -1
  26. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_sim2real.py +3 -2
  27. {imt_ring-1.6.36 → imt_ring-1.6.38}/readme.md +0 -0
  28. {imt_ring-1.6.36 → imt_ring-1.6.38}/setup.cfg +0 -0
  29. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  30. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  31. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/requires.txt +0 -0
  32. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/imt_ring.egg-info/top_level.txt +0 -0
  33. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/__init__.py +0 -0
  34. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algebra.py +0 -0
  35. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/__init__.py +0 -0
  36. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/_random.py +0 -0
  37. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  38. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  39. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  40. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
  41. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/dynamics.py +0 -0
  42. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/__init__.py +0 -0
  43. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/base.py +0 -0
  44. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  45. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/setup_fns.py +0 -0
  46. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/generator/types.py +0 -0
  47. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/algorithms/jcalc.py +0 -0
  48. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/__init__.py +0 -0
  49. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/branched.xml +0 -0
  50. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  51. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  52. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  53. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/inv_pendulum.xml +0 -0
  54. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  55. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/spherical_stiff.xml +0 -0
  56. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/symmetric.xml +0 -0
  57. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_all_1.xml +0 -0
  58. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_all_2.xml +0 -0
  59. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  60. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_control.xml +0 -0
  61. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  62. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_free.xml +0 -0
  63. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_kinematics.xml +0 -0
  64. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  65. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  66. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_randomize_position.xml +0 -0
  67. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_sensors.xml +0 -0
  68. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  69. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/examples.py +0 -0
  70. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/test_examples.py +0 -0
  71. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/__init__.py +0 -0
  72. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/abstract.py +0 -0
  73. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/test_from_xml.py +0 -0
  74. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/test_to_xml.py +0 -0
  75. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/io/xml/to_xml.py +0 -0
  76. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/maths.py +0 -0
  77. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/__init__.py +0 -0
  78. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/callbacks.py +0 -0
  79. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/optimizer.py +0 -0
  80. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  81. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  82. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/rnno_v1.py +0 -0
  83. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/ml/training_loop.py +0 -0
  84. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/__init__.py +0 -0
  85. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/base_render.py +0 -0
  86. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/rendering/vispy_visuals.py +0 -0
  87. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sim2real/__init__.py +0 -0
  88. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sim2real/sim2real.py +0 -0
  89. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/spatial.py +0 -0
  90. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/__init__.py +0 -0
  91. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/delete_sys.py +0 -0
  92. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/sys_composer/morph_sys.py +0 -0
  93. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/__init__.py +0 -0
  94. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/backend.py +0 -0
  95. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/colab.py +0 -0
  96. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/path.py +0 -0
  97. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/randomize_sys.py +0 -0
  98. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  99. {imt_ring-1.6.36 → imt_ring-1.6.38}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  100. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_algebra.py +0 -0
  101. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_base.py +0 -0
  102. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_custom_joints.py +0 -0
  103. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_dynamics.py +0 -0
  104. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_generator.py +0 -0
  105. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_jcalc.py +0 -0
  106. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_jit.py +0 -0
  107. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_kinematics.py +0 -0
  108. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_maths.py +0 -0
  109. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_motion_artifacts.py +0 -0
  110. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_pd_control.py +0 -0
  111. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_quickstart_example.py +0 -0
  112. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_random.py +0 -0
  113. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_randomize.py +0 -0
  114. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_rcmg.py +0 -0
  115. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_render.py +0 -0
  116. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_sensors.py +0 -0
  117. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_sys_composer.py +0 -0
  118. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_train.py +0 -0
  119. {imt_ring-1.6.36 → imt_ring-1.6.38}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: imt-ring
3
- Version: 1.6.36
3
+ Version: 1.6.38
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.36"
7
+ version = "1.6.38"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: imt-ring
3
- Version: 1.6.36
3
+ Version: 1.6.38
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -184,7 +184,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
184
184
 
185
185
  suntay_link_name = _utils_find_suntay_joint(sys)
186
186
 
187
- params = jax.tree_map(
187
+ params = jax.tree.map(
188
188
  lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
189
189
  sys.links.joint_params[name],
190
190
  )
@@ -80,11 +80,11 @@ def generators_eager(
80
80
  # converts also to numpy; but with np.array.flags.writeable = False
81
81
  sample = jax.device_get(sample)
82
82
  # this then sets this flag to True
83
- sample = jax.tree_map(np.array, sample)
83
+ sample = jax.tree.map(np.array, sample)
84
84
 
85
85
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
86
86
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
87
- callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
87
+ callback([jax.tree.map(lambda a: a[i].copy(), sample) for i in range(size)])
88
88
 
89
89
  # cleanup
90
90
  del sample, sample_flat
@@ -311,7 +311,7 @@ def _expand_then_flatten(Xy):
311
311
 
312
312
  X, y = _flatten(X), _flatten(y)
313
313
  if not batched:
314
- X, y = jax.tree_map(lambda arr: arr[0], (X, y))
314
+ X, y = jax.tree.map(lambda arr: arr[0], (X, y))
315
315
  return X, y
316
316
 
317
317
 
@@ -86,7 +86,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
86
86
  controller_state: PDControllerState, sys: base.System, state: base.State
87
87
  ) -> jax.Array:
88
88
  taus = jnp.zeros((sys.qd_size()))
89
- q_ref, qd_ref = jax.tree_map(
89
+ q_ref, qd_ref = jax.tree.map(
90
90
  lambda arr: jax.lax.dynamic_index_in_dim(
91
91
  arr, controller_state.i, keepdims=False
92
92
  ),
@@ -4,6 +4,7 @@ import jax
4
4
  import jax.numpy as jnp
5
5
  import jaxopt
6
6
  from jaxopt._src.base import Solver
7
+
7
8
  from ring import algebra
8
9
  from ring import base
9
10
  from ring import maths
@@ -171,7 +172,7 @@ def inverse_kinematics_endeffector(
171
172
 
172
173
  # find result of best q0 initial value
173
174
  best_q_index = jnp.argmin(values)
174
- best_q, best_q_value = jax.tree_map(
175
+ best_q, best_q_value = jax.tree.map(
175
176
  lambda arr: jax.lax.dynamic_index_in_dim(
176
177
  arr, best_q_index, keepdims=False
177
178
  ),
@@ -244,7 +244,7 @@ def imu(
244
244
  measurements["mag"] = magnetometer(xs.rot, magvec)
245
245
 
246
246
  if smoothen_degree is not None:
247
- measurements = jax.tree_map(
247
+ measurements = jax.tree.map(
248
248
  lambda arr: _moving_average(arr, smoothen_degree),
249
249
  measurements,
250
250
  )
@@ -257,7 +257,7 @@ def imu(
257
257
  delay = half_window
258
258
 
259
259
  if delay is not None and delay > 0:
260
- measurements = jax.tree_map(
260
+ measurements = jax.tree.map(
261
261
  lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
262
262
  )
263
263
 
@@ -473,7 +473,7 @@ def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
473
473
  X[name] = {"joint_axes": joint_axes}
474
474
 
475
475
  sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
476
- X = jax.tree_map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
476
+ X = jax.tree.map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
477
477
  return X
478
478
 
479
479
 
@@ -498,12 +498,12 @@ _quasi_physical_sys_str = r"""
498
498
  <x_xy>
499
499
  <options gravity="0 0 0"/>
500
500
  <worldbody>
501
- <body name="IMU" joint="p3d" damping="0.1 0.1 0.1" spring_stiff="3 3 3">
502
- <geom type="box" mass="0.002" dim="0.01 0.01 0.01"/>
501
+ <body name="IMU" joint="free" damping="1 1 1 10 10 10" spring_stiff="20 20 20 500 500 500">
502
+ <geom type="box" mass="1" dim="0.01 0.01 0.01"/>
503
503
  </body>
504
504
  </worldbody>
505
505
  </x_xy>
506
- """
506
+ """ # noqa: E501
507
507
 
508
508
 
509
509
  def _quasi_physical_simulation_beautiful(
@@ -512,12 +512,14 @@ def _quasi_physical_simulation_beautiful(
512
512
  sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)
513
513
 
514
514
  def step_dynamics(state: base.State, x):
515
- state = algorithms.step(sys.replace(link_spring_zeropoint=x.pos), state)
515
+ state = algorithms.step(
516
+ sys.replace(link_spring_zeropoint=jnp.concatenate((x.rot, x.pos))), state
517
+ )
516
518
  return state, state.q
517
519
 
518
- state = base.State.create(sys, q=xs.pos[0])
519
- _, pos = jax.lax.scan(step_dynamics, state, xs)
520
- return xs.replace(pos=pos)
520
+ state = base.State.create(sys, q=jnp.concatenate((xs.rot[0], xs.pos[0])))
521
+ _, qs = jax.lax.scan(step_dynamics, state, xs)
522
+ return xs.replace(rot=qs[:, :4], pos=qs[:, 4:])
521
523
 
522
524
 
523
525
  _constants = {
@@ -807,7 +807,7 @@ class System(_Base):
807
807
  if exists:
808
808
  return cls.from_xml(path, seed=seed)
809
809
  else:
810
- return cls.from_str(path_or_str)
810
+ return cls.from_str(path_or_str, seed=seed)
811
811
 
812
812
  def coordinate_vector_to_q(
813
813
  self,
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
252
252
 
253
253
  # numpy -> jax
254
254
  # we load using numpy in order to have float64 precision
255
- sys = jax.tree_map(jax.numpy.asarray, sys)
255
+ sys = jax.tree.map(jax.numpy.asarray, sys)
256
256
 
257
257
  sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
258
258
 
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
13
13
  def _to_3d(tree):
14
14
  if tree is None:
15
15
  return None
16
- return jax.tree_map(lambda arr: arr[None], tree)
16
+ return jax.tree.map(lambda arr: arr[None], tree)
17
17
 
18
18
 
19
19
  def _to_2d(tree, i: int = 0):
20
20
  if tree is None:
21
21
  return None
22
- return jax.tree_map(lambda arr: arr[i], tree)
22
+ return jax.tree.map(lambda arr: arr[i], tree)
23
23
 
24
24
 
25
25
  class AbstractFilter(ABC):
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
161
161
  metrices: NestedDict, filter_nan_inf: bool = True
162
162
  ):
163
163
  metrices = _flatten_dict(metrices)
164
- metrices = jax.tree_map(_to_float_if_not_string, metrices)
164
+ metrices = jax.tree.map(_to_float_if_not_string, metrices)
165
165
 
166
166
  if not filter_nan_inf:
167
167
  return metrices
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
216
216
  from jax.experimental import jax2tf
217
217
  import tensorflow as tf
218
218
 
219
- signature = jax.tree_map(
219
+ signature = jax.tree.map(
220
220
  lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
221
221
  )
222
222
 
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
241
241
  if validate:
242
242
  output_jax = jax_func(*input)
243
243
  output_tf = tf.saved_model.load(path)(*input)
244
- jax.tree_map(
244
+ jax.tree.map(
245
245
  lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
246
246
  output_jax,
247
247
  output_tf,
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
248
248
  params, state = self.forward_lam_factory(lam=lam).init(key, X)
249
249
 
250
250
  if bs is not None:
251
- state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
251
+ state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
252
252
 
253
253
  return params, state
254
254
 
@@ -50,7 +50,7 @@ def _build_step_fn(
50
50
  # this vmap maps along batch-axis, not time-axis
51
51
  # time-axis is handled by `metric_fn`
52
52
  pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
53
- error_tree = jax.tree_map(pipe, y, yhat)
53
+ error_tree = jax.tree.map(pipe, y, yhat)
54
54
  return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
55
55
 
56
56
  @partial(
@@ -274,7 +274,7 @@ def _build_eval_fn(
274
274
  ), f"The metric identitifier {metric_name} is not unique"
275
275
 
276
276
  pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
277
- values.update({metric_name: jax.tree_map(pipe, y, yhat)})
277
+ values.update({metric_name: jax.tree.map(pipe, y, yhat)})
278
278
 
279
279
  return values
280
280
 
@@ -10,8 +10,8 @@ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6
10
10
  _skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
11
11
 
12
12
 
13
- def _floor(floor_z: float) -> str:
14
- return f"""<geom name="floor" pos="0 0 {floor_z}" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
13
+ def _floor(z: float, material: str) -> str:
14
+ return f"""<geom name="floor" pos="0 0 {z}" size="0 0 1" type="plane" material="{material}" mass="0"/>""" # noqa: E501
15
15
 
16
16
 
17
17
  def _build_model_of_geoms(
@@ -19,7 +19,7 @@ def _build_model_of_geoms(
19
19
  cameras: dict[int, Sequence[str]],
20
20
  lights: dict[int, Sequence[str]],
21
21
  floor: bool,
22
- floor_z: float,
22
+ floor_kwargs: dict,
23
23
  stars: bool,
24
24
  debug: bool,
25
25
  ) -> mujoco.MjModel:
@@ -77,10 +77,13 @@ def _build_model_of_geoms(
77
77
  xml_str = f""" # noqa: E501
78
78
  <mujoco>
79
79
  <asset>
80
- <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
80
+ <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".3 .3 .3"/>
81
81
  <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
82
82
  <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
83
83
  <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
84
+ <material name="beige" rgba="0.76 0.80 0.50 1.0" specular="0.3" shininess="0.1" />
85
+ <material name="white" rgba="0.9 0.9 0.9 1.0" reflectance="0"/>
86
+ <material name="gray" rgba="0.4 0.5 0.5 1.0" reflectance="0.25"/>
84
87
  {_skybox if stars else ''}
85
88
  <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
86
89
  <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
@@ -98,7 +101,7 @@ def _build_model_of_geoms(
98
101
  <camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
99
102
  <camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
100
103
  <camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
101
- {_floor(floor_z) if floor else ''}
104
+ {_floor(**floor_kwargs) if floor else ''}
102
105
  {inside_worldbody_cameras}
103
106
  {inside_worldbody_lights}
104
107
  {inside_worldbody}
@@ -176,6 +179,7 @@ class MujocoScene:
176
179
  show_stars: bool = True,
177
180
  show_floor: bool = True,
178
181
  floor_z: float = -0.84,
182
+ floor_material: str = "matplane",
179
183
  debug: bool = False,
180
184
  ) -> None:
181
185
  self.debug = debug
@@ -190,7 +194,7 @@ class MujocoScene:
190
194
  self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
191
195
  self.show_stars = show_stars
192
196
  self.show_floor = show_floor
193
- self.floor_z = floor_z
197
+ self.floor_kwargs = dict(z=floor_z, material=floor_material)
194
198
 
195
199
  def init(self, geoms: list[base.Geometry]):
196
200
  self._parent_ids = list(set([geom.link_idx for geom in geoms]))
@@ -199,7 +203,7 @@ class MujocoScene:
199
203
  self.add_cameras,
200
204
  self.add_lights,
201
205
  floor=self.show_floor,
202
- floor_z=self.floor_z,
206
+ floor_kwargs=self.floor_kwargs,
203
207
  stars=self.show_stars,
204
208
  debug=self.debug,
205
209
  )
@@ -7,14 +7,15 @@ from typing import Optional, TypeVar
7
7
  import jax
8
8
  import jax.numpy as jnp
9
9
  import numpy as np
10
- from ring import algebra
11
- from ring import base
12
- from ring import maths
13
10
  from tree_utils import PyTree
14
11
  from tree_utils import tree_batch
15
12
  from vispy import scene
16
13
  from vispy.scene import MatrixTransform
17
14
 
15
+ from ring import algebra
16
+ from ring import base
17
+ from ring import maths
18
+
18
19
  from . import vispy_visuals
19
20
 
20
21
  Camera = TypeVar("Camera")
@@ -192,7 +193,7 @@ class Scene(ABC):
192
193
 
193
194
  # step 3: update visuals
194
195
  for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
195
- t = jax.tree_map(lambda arr: arr[i], transform_per_visual)
196
+ t = jax.tree.map(lambda arr: arr[i], transform_per_visual)
196
197
  if self._fresh_init:
197
198
  self._init_visual(visual, t, geom)
198
199
  else:
@@ -2,12 +2,13 @@ from typing import Optional
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- from ring import base
6
5
  from tree_utils import tree_batch
7
6
 
7
+ from ring import base
8
+
8
9
 
9
10
  def _tree_nan_like(tree, repeats: int):
10
- return jax.tree_map(
11
+ return jax.tree.map(
11
12
  lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
12
13
  )
13
14
 
@@ -39,19 +39,19 @@ def merge_batchsize(
39
39
  tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
40
40
  ) -> PyTree:
41
41
  if third_dim_also:
42
- return jax.tree_map(
42
+ return jax.tree.map(
43
43
  lambda arr: arr.reshape(
44
44
  (pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
45
45
  ),
46
46
  tree,
47
47
  )
48
- return jax.tree_map(
48
+ return jax.tree.map(
49
49
  lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
50
50
  )
51
51
 
52
52
 
53
53
  def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
54
- return jax.tree_map(
54
+ return jax.tree.map(
55
55
  lambda arr: arr.reshape(
56
56
  (
57
57
  pmap_size,
@@ -4,14 +4,15 @@ from typing import Callable, Optional
4
4
 
5
5
  import jax
6
6
  import numpy as np
7
- from ring.utils import parse_path
8
- from ring.utils import pickle_load
9
7
  import torch
10
8
  from torch.utils.data import DataLoader
11
9
  from torch.utils.data import Dataset
12
10
  import tqdm
13
11
  from tree_utils import PyTree
14
12
 
13
+ from ring.utils import parse_path
14
+ from ring.utils import pickle_load
15
+
15
16
 
16
17
  def make_generator(
17
18
  *paths,
@@ -103,7 +104,7 @@ def pytorch_generator(
103
104
  dl_iter = iter(dl)
104
105
 
105
106
  def to_numpy(tree: PyTree[torch.Tensor]):
106
- return jax.tree_map(lambda tensor: tensor.numpy(), tree)
107
+ return jax.tree.map(lambda tensor: tensor.numpy(), tree)
107
108
 
108
109
  def generator(_):
109
110
  nonlocal dl, dl_iter
@@ -1,16 +1,25 @@
1
1
  import os
2
+ import pickle
2
3
  from typing import Any, Optional
3
4
  import warnings
4
5
 
5
- import jax
6
6
  import numpy as np
7
7
  import torch
8
8
  from torch.utils.data import DataLoader
9
9
  from torch.utils.data import Dataset
10
+ import tree
10
11
  from tree_utils import PyTree
11
12
 
12
- from ring.utils import parse_path
13
- from ring.utils import pickle_load
13
+ from ring.utils.path import parse_path
14
+
15
+
16
+ def pickle_load(
17
+ path,
18
+ ):
19
+ path = parse_path(path, extension="pickle", require_is_file=True)
20
+ with open(path, "rb") as file:
21
+ obj = pickle.load(file)
22
+ return obj
14
23
 
15
24
 
16
25
  class FolderOfFilesDataset(Dataset):
@@ -60,8 +69,8 @@ def dataset_to_generator(
60
69
  )
61
70
  dl_iter = iter(dl)
62
71
 
63
- def to_numpy(tree: PyTree[torch.Tensor]):
64
- return jax.tree_map(lambda tensor: tensor.numpy(), tree)
72
+ def to_numpy(data: PyTree[torch.Tensor]):
73
+ return tree.map_structure(lambda tensor: tensor.numpy(), data)
65
74
 
66
75
  def generator(_):
67
76
  nonlocal dl, dl_iter
@@ -121,7 +121,7 @@ def _parse_path(
121
121
 
122
122
  def _tree_concat(trees: list):
123
123
  # otherwise scalar-arrays will lead to indexing error
124
- trees = jax.tree_map(lambda arr: np.atleast_1d(arr), trees)
124
+ trees = jax.tree.map(lambda arr: np.atleast_1d(arr), trees)
125
125
 
126
126
  if len(trees) == 0:
127
127
  return trees
@@ -3,9 +3,10 @@ from typing import Callable, TypeVar
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
- from ring.algorithms.generator import types
7
6
  import tree_utils
8
7
 
8
+ from ring.algorithms.generator import types
9
+
9
10
  KEY = jax.random.PRNGKey(777)
10
11
  KEY_PERMUTATION = jax.random.PRNGKey(888)
11
12
 
@@ -37,12 +38,12 @@ def make_normalizer_from_generator(
37
38
  # permute 0-th axis, since batchsize of generator might be larger than
38
39
  # `approx_with_large_batchsize`, then we would not get a representative
39
40
  # subsample otherwise
40
- Xs = jax.tree_map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
41
+ Xs = jax.tree.map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
41
42
  Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)
42
43
 
43
44
  # obtain statistics
44
- mean = jax.tree_map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
45
- std = jax.tree_map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
45
+ mean = jax.tree.map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
46
+ std = jax.tree.map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
46
47
 
47
48
  if verbose:
48
49
  print("Mean: ", mean)
@@ -51,6 +52,6 @@ def make_normalizer_from_generator(
51
52
  eps = 1e-8
52
53
 
53
54
  def normalizer(X):
54
- return jax.tree_map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
55
+ return jax.tree.map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
55
56
 
56
57
  return normalizer
@@ -1,5 +1,6 @@
1
1
  from importlib import import_module as _import_module
2
2
  import io
3
+ from pathlib import Path
3
4
  import pickle
4
5
  import random
5
6
  from typing import Optional
@@ -152,13 +153,28 @@ def import_lib(
152
153
 
153
154
  def pickle_save(obj, path, overwrite: bool = False):
154
155
  path = parse_path(path, extension="pickle", file_exists_ok=overwrite)
155
- with open(path, "wb") as file:
156
- pickle.dump(obj, file, protocol=5)
156
+ try:
157
+ with open(path, "wb") as file:
158
+ pickle.dump(obj, file, protocol=5)
159
+ except OSError as e:
160
+ print(
161
+ f"saving with `pickle` throws exception {e}. "
162
+ + "Attempting to save using `joblib`"
163
+ )
164
+ path = parse_path(path, extension="joblib", file_exists_ok=overwrite)
165
+ import joblib
166
+
167
+ joblib.dump(obj, path)
157
168
 
158
169
 
159
170
  def pickle_load(
160
171
  path,
161
172
  ):
173
+ if Path(path).suffix == ".joblib":
174
+ import joblib
175
+
176
+ return joblib.load(path)
177
+
162
178
  path = parse_path(path, extension="pickle", require_is_file=True)
163
179
  with open(path, "rb") as file:
164
180
  obj = pickle.load(file)
@@ -41,7 +41,7 @@ def test_save_load_generators():
41
41
  data = rcmg.to_list()[0]
42
42
  rcmg.to_pickle(path)
43
43
 
44
- data_list = [jax.tree_map(lambda a: a[0], utils.pickle_load(path))]
44
+ data_list = [jax.tree.map(lambda a: a[0], utils.pickle_load(path))]
45
45
  gen_reloaded = ring.RCMG.eager_gen_from_list(data_list, 1)
46
46
  data_reloaded = unbatch_gen(gen_reloaded)(jax.random.PRNGKey(1))
47
47
 
@@ -2,6 +2,7 @@ from _compat import unbatch_gen
2
2
  import jax
3
3
  import jax.numpy as jnp
4
4
  import numpy as np
5
+
5
6
  import ring
6
7
  from ring import maths
7
8
  from ring import sim2real
@@ -49,7 +50,7 @@ def test_forward_kinematics_omc():
49
50
  # t1_omc should be used when p == -1, else t1_sys
50
51
  @jax.vmap
51
52
  def merge_transform1(t1_omc):
52
- return jax.tree_map(
53
+ return jax.tree.map(
53
54
  lambda a, b: jnp.where(
54
55
  jnp.repeat(
55
56
  jnp.array(sys.link_parents)[:, None] == -1,
@@ -138,7 +139,7 @@ def test_zip_unzip_scale():
138
139
  t1, t2 = sim2real.unzip_xs(sys, xs)
139
140
  xs_re = sim2real.zip_xs(sys, t1, t2)
140
141
 
141
- jax.tree_map(
142
+ jax.tree.map(
142
143
  lambda a, b: np.testing.assert_allclose(a, b, rtol=1e-3, atol=1e-5),
143
144
  xs,
144
145
  xs_re,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes