imt-ring 1.6.15__tar.gz → 1.6.17__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (118) hide show
  1. {imt_ring-1.6.15 → imt_ring-1.6.17}/PKG-INFO +3 -3
  2. {imt_ring-1.6.15 → imt_ring-1.6.17}/pyproject.toml +1 -1
  3. {imt_ring-1.6.15 → imt_ring-1.6.17}/readme.md +2 -2
  4. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/imt_ring.egg-info/PKG-INFO +3 -3
  5. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/imt_ring.egg-info/SOURCES.txt +1 -0
  6. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/__init__.py +4 -0
  7. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/custom_joints/__init__.py +1 -0
  8. imt_ring-1.6.17/src/ring/algorithms/custom_joints/rsaddle_joint.py +40 -0
  9. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/finalize_fns.py +9 -5
  10. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/jcalc.py +32 -1
  11. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/ringnet.py +11 -2
  12. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_rcmg.py +4 -1
  13. {imt_ring-1.6.15 → imt_ring-1.6.17}/setup.cfg +0 -0
  14. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  15. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/imt_ring.egg-info/requires.txt +0 -0
  16. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/imt_ring.egg-info/top_level.txt +0 -0
  17. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algebra.py +0 -0
  18. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/__init__.py +0 -0
  19. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/_random.py +0 -0
  20. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  21. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  22. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  23. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/dynamics.py +0 -0
  24. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/__init__.py +0 -0
  25. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/base.py +0 -0
  26. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/batch.py +0 -0
  27. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  28. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/pd_control.py +0 -0
  29. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/setup_fns.py +0 -0
  30. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/generator/types.py +0 -0
  31. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/kinematics.py +0 -0
  32. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/algorithms/sensors.py +0 -0
  33. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/base.py +0 -0
  34. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/__init__.py +0 -0
  35. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/branched.xml +0 -0
  36. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  37. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  38. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  39. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/inv_pendulum.xml +0 -0
  40. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  41. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/spherical_stiff.xml +0 -0
  42. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/symmetric.xml +0 -0
  43. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_all_1.xml +0 -0
  44. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_all_2.xml +0 -0
  45. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  46. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_control.xml +0 -0
  47. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  48. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_free.xml +0 -0
  49. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_kinematics.xml +0 -0
  50. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  51. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  52. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_randomize_position.xml +0 -0
  53. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_sensors.xml +0 -0
  54. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  55. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/examples.py +0 -0
  56. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/test_examples.py +0 -0
  57. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/xml/__init__.py +0 -0
  58. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/xml/abstract.py +0 -0
  59. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/xml/from_xml.py +0 -0
  60. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/xml/test_from_xml.py +0 -0
  61. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/xml/test_to_xml.py +0 -0
  62. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/io/xml/to_xml.py +0 -0
  63. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/maths.py +0 -0
  64. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/__init__.py +0 -0
  65. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/base.py +0 -0
  66. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/callbacks.py +0 -0
  67. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/ml_utils.py +0 -0
  68. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/optimizer.py +0 -0
  69. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  70. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  71. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/rnno_v1.py +0 -0
  72. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/train.py +0 -0
  73. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/ml/training_loop.py +0 -0
  74. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/rendering/__init__.py +0 -0
  75. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/rendering/base_render.py +0 -0
  76. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/rendering/mujoco_render.py +0 -0
  77. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/rendering/vispy_render.py +0 -0
  78. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/rendering/vispy_visuals.py +0 -0
  79. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/sim2real/__init__.py +0 -0
  80. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/sim2real/sim2real.py +0 -0
  81. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/spatial.py +0 -0
  82. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/sys_composer/__init__.py +0 -0
  83. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/sys_composer/delete_sys.py +0 -0
  84. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/sys_composer/inject_sys.py +0 -0
  85. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/sys_composer/morph_sys.py +0 -0
  86. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/__init__.py +0 -0
  87. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/backend.py +0 -0
  88. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/batchsize.py +0 -0
  89. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/colab.py +0 -0
  90. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/dataloader.py +0 -0
  91. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/hdf5.py +0 -0
  92. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/normalizer.py +0 -0
  93. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/path.py +0 -0
  94. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/randomize_sys.py +0 -0
  95. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  96. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  97. {imt_ring-1.6.15 → imt_ring-1.6.17}/src/ring/utils/utils.py +0 -0
  98. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_algebra.py +0 -0
  99. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_base.py +0 -0
  100. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_custom_joints.py +0 -0
  101. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_dynamics.py +0 -0
  102. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_generator.py +0 -0
  103. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_jcalc.py +0 -0
  104. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_jit.py +0 -0
  105. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_kinematics.py +0 -0
  106. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_maths.py +0 -0
  107. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_ml_utils.py +0 -0
  108. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_motion_artifacts.py +0 -0
  109. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_pd_control.py +0 -0
  110. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_quickstart_example.py +0 -0
  111. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_random.py +0 -0
  112. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_randomize.py +0 -0
  113. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_render.py +0 -0
  114. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_sensors.py +0 -0
  115. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_sim2real.py +0 -0
  116. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_sys_composer.py +0 -0
  117. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_train.py +0 -0
  118. {imt_ring-1.6.15 → imt_ring-1.6.17}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.15
3
+ Version: 1.6.17
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -32,11 +32,11 @@ Requires-Dist: pytest-xdist; extra == "dev"
32
32
  Requires-Dist: nbmake; extra == "dev"
33
33
 
34
34
  <p align="center">
35
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/icon.svg" height="200" />
35
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/icon.svg" height="200" />
36
36
  </p>
37
37
 
38
38
  # Recurrent Inertial Graph-based Estimator (RING)
39
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/coverage_badge.svg" height="20" />
39
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/coverage_badge.svg" height="20" />
40
40
 
41
41
  ## Installation
42
42
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.15"
7
+ version = "1.6.17"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,9 +1,9 @@
1
1
  <p align="center">
2
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/icon.svg" height="200" />
2
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/icon.svg" height="200" />
3
3
  </p>
4
4
 
5
5
  # Recurrent Inertial Graph-based Estimator (RING)
6
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/coverage_badge.svg" height="20" />
6
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/coverage_badge.svg" height="20" />
7
7
 
8
8
  ## Installation
9
9
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.15
3
+ Version: 1.6.17
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -32,11 +32,11 @@ Requires-Dist: pytest-xdist; extra == "dev"
32
32
  Requires-Dist: nbmake; extra == "dev"
33
33
 
34
34
  <p align="center">
35
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/icon.svg" height="200" />
35
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/icon.svg" height="200" />
36
36
  </p>
37
37
 
38
38
  # Recurrent Inertial Graph-based Estimator (RING)
39
- <img src="https://raw.githubusercontent.com/SimiPixel/ring/main/docs/img/coverage_badge.svg" height="20" />
39
+ <img src="https://raw.githubusercontent.com/simon-bachhuber/ring/main/docs/img/coverage_badge.svg" height="20" />
40
40
 
41
41
  ## Installation
42
42
 
@@ -20,6 +20,7 @@ src/ring/algorithms/sensors.py
20
20
  src/ring/algorithms/custom_joints/__init__.py
21
21
  src/ring/algorithms/custom_joints/rr_imp_joint.py
22
22
  src/ring/algorithms/custom_joints/rr_joint.py
23
+ src/ring/algorithms/custom_joints/rsaddle_joint.py
23
24
  src/ring/algorithms/custom_joints/suntay.py
24
25
  src/ring/algorithms/generator/__init__.py
25
26
  src/ring/algorithms/generator/base.py
@@ -121,6 +121,7 @@ _UNIQUE_ID = None
121
121
  def setup(
122
122
  rr_joint_kwargs: None | dict = dict(),
123
123
  rr_imp_joint_kwargs: None | dict = dict(),
124
+ rsaddle_joint_kwargs: None | dict = dict(),
124
125
  suntay_joint_kwargs: None | dict = None,
125
126
  train_timing_start: None | float = None,
126
127
  unique_id: None | str = None,
@@ -138,6 +139,9 @@ def setup(
138
139
  if rr_imp_joint_kwargs is not None:
139
140
  custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
140
141
 
142
+ if rsaddle_joint_kwargs is not None:
143
+ custom_joints.register_rsaddle_joint(**rsaddle_joint_kwargs)
144
+
141
145
  if suntay_joint_kwargs is not None:
142
146
  custom_joints.register_suntay(**suntay_joint_kwargs)
143
147
 
@@ -1,5 +1,6 @@
1
1
  from .rr_imp_joint import register_rr_imp_joint
2
2
  from .rr_joint import register_rr_joint
3
+ from .rsaddle_joint import register_rsaddle_joint
3
4
  from .suntay import ConstantValue_DrawnFnPair
4
5
  from .suntay import GP_DrawFnPair
5
6
  from .suntay import MLP_DrawnFnPair
@@ -0,0 +1,40 @@
1
+ import jax.numpy as jnp
2
+
3
+ import ring
4
+ from ring import maths
5
+ from ring.algorithms.jcalc import _draw_saddle
6
+ from ring.algorithms.jcalc import _p_control_term_rxyz
7
+ from ring.algorithms.jcalc import _qd_from_q_cartesian
8
+
9
+
10
+ def register_rsaddle_joint():
11
+ def _transform(q, params):
12
+ axes = params["joint_axes"]
13
+ rot1 = maths.quat_rot_axis(axes[0], q[0])
14
+ rot2 = maths.quat_rot_axis(axes[1], q[1])
15
+ rot = maths.quat_mul(rot2, rot1)
16
+ return ring.Transform.create(rot=rot)
17
+
18
+ def _motion_fn_gen(i: int):
19
+ def _motion_fn(params):
20
+ axis = params["joint_axes"][i]
21
+ return ring.base.Motion.create(ang=axis)
22
+
23
+ return _motion_fn
24
+
25
+ joint_model = ring.JointModel(
26
+ _transform,
27
+ motion=[_motion_fn_gen(i) for i in range(2)],
28
+ rcmg_draw_fn=_draw_saddle,
29
+ p_control_term=_p_control_term_rxyz,
30
+ qd_from_q=_qd_from_q_cartesian,
31
+ init_joint_params=_draw_random_joint_axes,
32
+ )
33
+
34
+ ring.register_new_joint_type("rsaddle", joint_model, 2, overwrite=True)
35
+
36
+
37
+ def _draw_random_joint_axes(key):
38
+ return dict(
39
+ joint_axes=maths.rotate(jnp.array([1.0, 0, 0]), maths.quat_random(key, (2,)))
40
+ )
@@ -160,6 +160,7 @@ _P_gains = {
160
160
  "spherical": jnp.array(3 * [P_rot]),
161
161
  "p3d": jnp.array(3 * [P_pos]),
162
162
  "saddle": jnp.array([P_rot, P_rot]),
163
+ "rsaddle": jnp.array([P_rot, P_rot]),
163
164
  "frozen": jnp.array([]),
164
165
  "suntay": jnp.array([P_rot]),
165
166
  }
@@ -182,13 +183,16 @@ class DynamicalSimulation:
182
183
 
183
184
  @staticmethod
184
185
  def assert_test_system(sys: base.System) -> None:
185
- "test that system has no zero mass bodies and no joints without damping"
186
+ "test that system has no zero mass leaf bodies and no joints without damping"
186
187
 
187
188
  def f(_, __, n, m, d):
188
- assert d.size == 0 or m > 0, (
189
- "Dynamic simulation is set to `True` which requires masses >= 0, "
190
- f"but found body `{n}` with mass={float(m[0])}. This can lead to NaNs."
191
- )
189
+ is_leaf_body = len(sys.children(n)) == 0
190
+ if is_leaf_body:
191
+ assert d.size == 0 or m > 0, (
192
+ "Dynamic simulation is set to `True` which requires masses >= 0, "
193
+ f"but found body `{n}` with mass={float(m[0])}. This can lead to "
194
+ "NaNs."
195
+ )
192
196
 
193
197
  assert d.size == 0 or all(d > 0.0), (
194
198
  "Dynamic simulation is set to `True` which requires dampings > 0, "
@@ -1,3 +1,4 @@
1
+ from collections import defaultdict
1
2
  from dataclasses import asdict
2
3
  from dataclasses import dataclass
3
4
  from dataclasses import field
@@ -87,11 +88,39 @@ class MotionConfig:
87
88
  assert nomotion_config.is_feasible()
88
89
  return nomotion_config
89
90
 
91
+ @staticmethod
92
+ def overwrite_for_joint_type(joint_type: str, **changes) -> None:
93
+ """Changes values of the `MotionConfig` used by the draw_fn for only a specific
94
+ joint.
95
+ """
96
+ previous_changes = _overwrite_for_joint_type_changes[joint_type]
97
+ for change in changes:
98
+ assert change not in previous_changes, f"For jointtype={joint_type} you "
99
+ f"previously changed the value={change}. You can't change it again, this "
100
+ "is not supported."
101
+ previous_changes.update(changes)
102
+
103
+ jm = get_joint_model(joint_type)
104
+
105
+ def draw_fn(config, *args):
106
+ return jm.rcmg_draw_fn(replace(config, **changes), *args)
107
+
108
+ register_new_joint_type(
109
+ joint_type,
110
+ replace(jm, rcmg_draw_fn=draw_fn),
111
+ base.Q_WIDTHS[joint_type],
112
+ base.QD_WIDTHS[joint_type],
113
+ overwrite=True,
114
+ )
115
+
90
116
  @staticmethod
91
117
  def from_register(name: str) -> "MotionConfig":
92
118
  return _registered_motion_configs[name]
93
119
 
94
120
 
121
+ _overwrite_for_joint_type_changes: dict[str, dict] = defaultdict(lambda: dict())
122
+
123
+
95
124
  _registered_motion_configs = {
96
125
  "hinUndHer": MotionConfig(
97
126
  t_min=0.3,
@@ -222,7 +251,9 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
222
251
  def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
223
252
  return (x0_min >= x_min) and (x0_max <= x_max)
224
253
 
225
- cond2 = inside_box_checks(c.pos_min, c.pos_max, c.pos0_min, c.pos0_max)
254
+ cond2 = inside_box_checks(
255
+ _to_float(c.pos_min, 0.0), _to_float(c.pos_max, 0.0), c.pos0_min, c.pos0_max
256
+ )
226
257
 
227
258
  return cond1 and cond2
228
259
 
@@ -87,6 +87,7 @@ def make_ring(
87
87
  link_output_normalize: bool = True,
88
88
  link_output_transform: Optional[Callable] = None,
89
89
  layernorm: bool = True,
90
+ layernorm_trainable: bool = True,
90
91
  ) -> SimpleNamespace:
91
92
 
92
93
  if link_output_normalize:
@@ -104,7 +105,11 @@ def make_ring(
104
105
  )
105
106
 
106
107
  inner_cell = StackedRNNCell(
107
- celltype, hidden_state_dim, stack_rnn_cells, layernorm=layernorm
108
+ celltype,
109
+ hidden_state_dim,
110
+ stack_rnn_cells,
111
+ layernorm=layernorm,
112
+ layernorm_trainable=layernorm_trainable,
108
113
  )
109
114
  send_output = hk.nets.MLP([hidden_state_dim, link_output_dim])
110
115
  state = hk.get_state(
@@ -143,6 +148,7 @@ class StackedRNNCell(hk.Module):
143
148
  hidden_state_dim,
144
149
  stacks: int,
145
150
  layernorm: bool = False,
151
+ layernorm_trainable: bool = True,
146
152
  name: str | None = None,
147
153
  ):
148
154
  super().__init__(name)
@@ -150,6 +156,7 @@ class StackedRNNCell(hk.Module):
150
156
 
151
157
  self.cells = [cell(hidden_state_dim) for _ in range(stacks)]
152
158
  self.layernorm = layernorm
159
+ self.layernorm_trainable = layernorm_trainable
153
160
 
154
161
  def __call__(self, x, state):
155
162
  output = x
@@ -159,7 +166,9 @@ class StackedRNNCell(hk.Module):
159
166
  next_state.append(next_state_i)
160
167
 
161
168
  if self.layernorm:
162
- output = hk.LayerNorm(-1, True, True)(output)
169
+ output = hk.LayerNorm(
170
+ -1, self.layernorm_trainable, self.layernorm_trainable
171
+ )(output)
163
172
 
164
173
  return output, jnp.stack(next_state)
165
174
 
@@ -37,6 +37,7 @@ def test_batch_generator(N: int, seed: int):
37
37
  def test_initial_ang_pos_values():
38
38
  T = 1.0
39
39
  bs = 8
40
+ pos_min, pos_max = -5.0, 5.0
40
41
  # system consists only of prismatic, and then revolute joint
41
42
  sys = ring.io.load_example("test_ang0_pos0")
42
43
 
@@ -48,6 +49,8 @@ def test_initial_ang_pos_values():
48
49
  ang0_max=ang0_max,
49
50
  pos0_min=pos0_min,
50
51
  pos0_max=pos0_max,
52
+ pos_min=pos_min,
53
+ pos_max=pos_max,
51
54
  T=T,
52
55
  ),
53
56
  finalize_fn=lambda key, q, x, sys: (q, x),
@@ -55,7 +58,7 @@ def test_initial_ang_pos_values():
55
58
  q, _ = gen(jax.random.PRNGKey(1))
56
59
  return q
57
60
 
58
- for init_val in np.linspace(-5.0, 5.0, num=10):
61
+ for init_val in np.linspace(pos_min, pos_max, num=10):
59
62
  q = rcmg(init_val, init_val, init_val, init_val)
60
63
  np.testing.assert_allclose(q[:, 0, 0], init_val * jnp.ones((bs,)))
61
64
  np.testing.assert_allclose(q[:, 0, 1], wrap_to_pi(init_val * jnp.ones((bs,))))
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes