imt-ring 1.3.7__tar.gz → 1.3.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.7 → imt_ring-1.3.9}/PKG-INFO +1 -1
- {imt_ring-1.3.7 → imt_ring-1.3.9}/pyproject.toml +1 -1
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/imt_ring.egg-info/SOURCES.txt +1 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/dynamics.py +11 -5
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/base.py +11 -13
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/batch.py +6 -2
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/motion_artifacts.py +12 -6
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/pd_control.py +2 -1
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/base.py +1 -3
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/__init__.py +15 -3
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/callbacks.py +3 -2
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/ringnet.py +10 -2
- imt_ring-1.3.9/src/ring/ml/rnno_v1.py +41 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/train.py +1 -2
- {imt_ring-1.3.7 → imt_ring-1.3.9}/readme.md +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/setup.cfg +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algebra.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/custom_joints/suntay.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/randomize.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/transforms.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/jcalc.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/algorithms/sensors.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/examples.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/maths.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/base.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/ml_utils.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/rendering/mujoco_render.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/spatial.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/__init__.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/batchsize.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/hdf5.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/normalizer.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/path.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/src/ring/utils/utils.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_algebra.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_base.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_dynamics.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_generator.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_jcalc.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_jit.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_kinematics.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_maths.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_ml_utils.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_pd_control.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_random.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_randomize.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_rcmg.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_render.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_sensors.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_sim2real.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_train.py +0 -0
- {imt_ring-1.3.7 → imt_ring-1.3.9}/tests/test_utils.py +0 -0
@@ -1,7 +1,9 @@
|
|
1
1
|
from typing import Optional, Tuple
|
2
|
+
import warnings
|
2
3
|
|
3
4
|
import jax
|
4
5
|
import jax.numpy as jnp
|
6
|
+
|
5
7
|
from ring import algebra
|
6
8
|
from ring import base
|
7
9
|
from ring import maths
|
@@ -213,7 +215,7 @@ def forward_dynamics(
|
|
213
215
|
q: jax.Array,
|
214
216
|
qd: jax.Array,
|
215
217
|
tau: jax.Array,
|
216
|
-
mass_mat_inv: jax.Array,
|
218
|
+
# mass_mat_inv: jax.Array,
|
217
219
|
) -> Tuple[jax.Array, jax.Array]:
|
218
220
|
C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
|
219
221
|
mass_matrix = compute_mass_matrix(sys)
|
@@ -235,6 +237,11 @@ def forward_dynamics(
|
|
235
237
|
|
236
238
|
mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
|
237
239
|
else:
|
240
|
+
warnings.warn(
|
241
|
+
f"You are using `sys.mass_mat_iters`={sys.mass_mat_iters} which is >0. "
|
242
|
+
"This feature is currently not fully supported. See the local TODO."
|
243
|
+
)
|
244
|
+
mass_mat_inv = jnp.diag(jnp.ones((sys.qd_size(),)))
|
238
245
|
mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)
|
239
246
|
|
240
247
|
return mass_mat_inv @ qf_smooth, mass_mat_inv
|
@@ -254,9 +261,8 @@ def _strapdown_integration(
|
|
254
261
|
def _semi_implicit_euler_integration(
|
255
262
|
sys: base.System, state: base.State, taus: jax.Array
|
256
263
|
) -> base.State:
|
257
|
-
qdd, mass_mat_inv = forward_dynamics(
|
258
|
-
|
259
|
-
)
|
264
|
+
qdd, mass_mat_inv = forward_dynamics(sys, state.q, state.qd, taus)
|
265
|
+
del mass_mat_inv
|
260
266
|
qd_next = state.qd + sys.dt * qdd
|
261
267
|
|
262
268
|
q_next = []
|
@@ -277,7 +283,7 @@ def _semi_implicit_euler_integration(
|
|
277
283
|
sys.scan(q_integrate, "qdl", state.q, qd_next, sys.link_types)
|
278
284
|
q_next = jnp.concatenate(q_next)
|
279
285
|
|
280
|
-
state = state.replace(q=q_next, qd=qd_next
|
286
|
+
state = state.replace(q=q_next, qd=qd_next)
|
281
287
|
return state
|
282
288
|
|
283
289
|
|
@@ -4,6 +4,7 @@ import warnings
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
import tqdm
|
7
8
|
import tree_utils
|
8
9
|
|
9
10
|
from ring import base
|
@@ -83,10 +84,14 @@ class RCMG:
|
|
83
84
|
), "If `randomize_anchors`, then only one system is expected"
|
84
85
|
sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
|
85
86
|
|
86
|
-
zip_sys_config = False
|
87
87
|
if randomize_hz:
|
88
|
-
zip_sys_config = True
|
89
88
|
sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
|
89
|
+
else:
|
90
|
+
# create zip
|
91
|
+
N_sys = len(sys)
|
92
|
+
sys = sum([len(config) * [s] for s in sys], start=[])
|
93
|
+
config = N_sys * config
|
94
|
+
assert len(sys) == len(config)
|
90
95
|
|
91
96
|
if sys_ml is None:
|
92
97
|
# TODO
|
@@ -97,17 +102,10 @@ class RCMG:
|
|
97
102
|
sys_ml = sys[0]
|
98
103
|
|
99
104
|
self.gens = []
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
)
|
105
|
-
else:
|
106
|
-
for _sys in sys:
|
107
|
-
for _config in config:
|
108
|
-
self.gens.append(
|
109
|
-
partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
|
110
|
-
)
|
105
|
+
for _sys, _config in tqdm.tqdm(
|
106
|
+
zip(sys, config), desc="building generators", total=len(sys)
|
107
|
+
):
|
108
|
+
self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
|
111
109
|
|
112
110
|
def _to_data(self, sizes, seed):
|
113
111
|
return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
|
@@ -86,7 +86,11 @@ def batch_generators_eager_to_list(
|
|
86
86
|
|
87
87
|
key = jax.random.PRNGKey(seed)
|
88
88
|
data = []
|
89
|
-
for gen, size in tqdm(
|
89
|
+
for gen, size in tqdm(
|
90
|
+
zip(generators, sizes),
|
91
|
+
desc="executing generators",
|
92
|
+
total=len(sizes),
|
93
|
+
):
|
90
94
|
|
91
95
|
n_calls = _number_of_executions_required(size)
|
92
96
|
# decrease size by n_calls times
|
@@ -147,7 +151,7 @@ def _data_fn_from_paths(
|
|
147
151
|
paths = [utils.parse_path(p, mkdir=False) for p in paths]
|
148
152
|
|
149
153
|
extensions = list(set([Path(p).suffix for p in paths]))
|
150
|
-
assert len(extensions) == 1
|
154
|
+
assert len(extensions) == 1, f"{extensions}"
|
151
155
|
|
152
156
|
if extensions[0] == ".h5":
|
153
157
|
N = sum([utils.hdf5_load_length(p) for p in paths])
|
@@ -49,6 +49,7 @@ def inject_subsystems(
|
|
49
49
|
rotational_damp: float = 0.1,
|
50
50
|
translational_stif: float = 50.0,
|
51
51
|
translational_damp: float = 0.1,
|
52
|
+
disable_warning: bool = False,
|
52
53
|
**kwargs,
|
53
54
|
) -> base.System:
|
54
55
|
imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
|
@@ -92,10 +93,11 @@ def inject_subsystems(
|
|
92
93
|
# TODO set all joint_params to zeros; they can not be preserved anyways and
|
93
94
|
# otherwise many warnings will be raised
|
94
95
|
# instead warn explicitly once now and move on
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
96
|
+
if not disable_warning:
|
97
|
+
warnings.warn(
|
98
|
+
"`sys.links.joint_params` has been set to zero, this might lead to "
|
99
|
+
"unexpected behaviour unless you use `randomize_joint_params`"
|
100
|
+
)
|
99
101
|
joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
|
100
102
|
sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))
|
101
103
|
|
@@ -180,9 +182,13 @@ def setup_fn_randomize_damping_stiffness_factory(
|
|
180
182
|
link_spring_stiffness = link_spring_stiffness.at[slice].set(stif)
|
181
183
|
link_damping = link_damping.at[slice].set(damp)
|
182
184
|
|
183
|
-
assert len(imus_surely_rigid) == len(
|
185
|
+
assert len(imus_surely_rigid) == len(
|
186
|
+
triggered_surely_rigid
|
187
|
+
), f"{imus_surely_rigid}, {triggered_surely_rigid}"
|
184
188
|
for imu_surely_rigid in imus_surely_rigid:
|
185
|
-
assert
|
189
|
+
assert (
|
190
|
+
imu_surely_rigid in triggered_surely_rigid
|
191
|
+
), f"{imus_surely_rigid} not in {triggered_surely_rigid}"
|
186
192
|
|
187
193
|
return sys.replace(
|
188
194
|
link_damping=link_damping, link_spring_stiffness=link_spring_stiffness
|
@@ -4,6 +4,7 @@ from typing import Optional
|
|
4
4
|
from flax import struct
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
|
7
8
|
from ring import base
|
8
9
|
from ring.algorithms import dynamics
|
9
10
|
from ring.algorithms import jcalc
|
@@ -49,7 +50,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
|
|
49
50
|
assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
|
50
51
|
assert sys.qd_size() == P.size
|
51
52
|
if D is not None:
|
52
|
-
sys.qd_size() == D.size
|
53
|
+
assert sys.qd_size() == D.size
|
53
54
|
|
54
55
|
q_ref_as_dict = {}
|
55
56
|
qd_ref_as_dict = {}
|
@@ -997,13 +997,11 @@ class State(_Base):
|
|
997
997
|
q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
|
998
998
|
qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
|
999
999
|
x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
|
1000
|
-
mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
|
1001
1000
|
"""
|
1002
1001
|
|
1003
1002
|
q: jax.Array
|
1004
1003
|
qd: jax.Array
|
1005
1004
|
x: Transform
|
1006
|
-
mass_mat_inv: jax.Array
|
1007
1005
|
|
1008
1006
|
@classmethod
|
1009
1007
|
def create(
|
@@ -1057,4 +1055,4 @@ class State(_Base):
|
|
1057
1055
|
if x is None:
|
1058
1056
|
x = Transform.zero((sys.num_links(),))
|
1059
1057
|
|
1060
|
-
return cls(q, qd, x
|
1058
|
+
return cls(q, qd, x)
|
@@ -3,6 +3,7 @@ from . import callbacks
|
|
3
3
|
from . import ml_utils
|
4
4
|
from . import optimizer
|
5
5
|
from . import ringnet
|
6
|
+
from . import rnno_v1
|
6
7
|
from . import train
|
7
8
|
from . import training_loop
|
8
9
|
from .base import AbstractFilter
|
@@ -42,17 +43,28 @@ def RNNO(
|
|
42
43
|
params=None,
|
43
44
|
eval: bool = True,
|
44
45
|
samp_freq: float | None = None,
|
46
|
+
v1: bool = False,
|
45
47
|
**kwargs,
|
46
48
|
):
|
47
49
|
assert "message_dim" not in kwargs
|
48
50
|
assert "link_output_normalize" not in kwargs
|
49
51
|
assert "link_output_dim" not in kwargs
|
50
52
|
|
53
|
+
if v1:
|
54
|
+
kwargs.update(
|
55
|
+
dict(forward_factory=rnno_v1.rnno_v1_forward_factory, output_dim=output_dim)
|
56
|
+
)
|
57
|
+
else:
|
58
|
+
kwargs.update(
|
59
|
+
dict(
|
60
|
+
message_dim=0,
|
61
|
+
link_output_normalize=False,
|
62
|
+
link_output_dim=output_dim,
|
63
|
+
)
|
64
|
+
)
|
65
|
+
|
51
66
|
ringnet = RING( # noqa: F811
|
52
67
|
params=params,
|
53
|
-
message_dim=0,
|
54
|
-
link_output_normalize=False,
|
55
|
-
link_output_dim=output_dim,
|
56
68
|
**kwargs,
|
57
69
|
)
|
58
70
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
@@ -245,7 +245,8 @@ class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
|
|
245
245
|
else:
|
246
246
|
value = "{:.2f}".format(ele.value).replace(".", ",")
|
247
247
|
filename = parse_path(
|
248
|
-
self.path_to_file
|
248
|
+
str(Path(self.path_to_file).with_suffix(""))
|
249
|
+
+ f"_episode={ele.episode}_value={value}",
|
249
250
|
extension="pickle",
|
250
251
|
)
|
251
252
|
|
@@ -404,7 +405,7 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
|
|
404
405
|
# only checkpoint if run has been killed
|
405
406
|
if training_loop.recv_kill_run_signal():
|
406
407
|
path = parse_path(
|
407
|
-
"~/.
|
408
|
+
"~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
|
408
409
|
)
|
409
410
|
data = {"params": self.params, "opt_state": self.opt_state}
|
410
411
|
pickle_save(
|
@@ -191,8 +191,16 @@ class LSTM(hk.RNNCore):
|
|
191
191
|
|
192
192
|
|
193
193
|
class RING(ml_base.AbstractFilter):
|
194
|
-
def __init__(
|
195
|
-
self
|
194
|
+
def __init__(
|
195
|
+
self,
|
196
|
+
params=None,
|
197
|
+
lam=None,
|
198
|
+
jit: bool = True,
|
199
|
+
name=None,
|
200
|
+
forward_factory=make_ring,
|
201
|
+
**kwargs,
|
202
|
+
):
|
203
|
+
self.forward_lam_factory = partial(forward_factory, **kwargs)
|
196
204
|
self.params = self._load_params(params)
|
197
205
|
self.lam = lam
|
198
206
|
self._name = name
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Sequence
|
2
|
+
|
3
|
+
import haiku as hk
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
|
7
|
+
|
8
|
+
def rnno_v1_forward_factory(
|
9
|
+
output_dim: int,
|
10
|
+
rnn_layers: Sequence[int] = (400, 300),
|
11
|
+
linear_layers: Sequence[int] = (200, 100, 50, 50, 25, 25),
|
12
|
+
layernorm: bool = True,
|
13
|
+
act_fn_linear=jax.nn.relu,
|
14
|
+
act_fn_rnn=jax.nn.elu,
|
15
|
+
lam: Optional[tuple[int]] = None,
|
16
|
+
):
|
17
|
+
# unused
|
18
|
+
del lam
|
19
|
+
|
20
|
+
@hk.without_apply_rng
|
21
|
+
@hk.transform_with_state
|
22
|
+
def forward_fn(X):
|
23
|
+
assert X.shape[-2] == 1
|
24
|
+
|
25
|
+
for i, n_units in enumerate(rnn_layers):
|
26
|
+
state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
|
27
|
+
X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
|
28
|
+
hk.set_state(f"rnn_{i}", state)
|
29
|
+
|
30
|
+
if layernorm:
|
31
|
+
X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)
|
32
|
+
X = act_fn_rnn(X)
|
33
|
+
|
34
|
+
for n_units in linear_layers:
|
35
|
+
X = hk.Linear(n_units)(X)
|
36
|
+
X = act_fn_linear(X)
|
37
|
+
|
38
|
+
y = hk.Linear(output_dim)(X)
|
39
|
+
return y[..., None, :]
|
40
|
+
|
41
|
+
return forward_fn
|
@@ -15,7 +15,6 @@ from ring.ml import ml_utils
|
|
15
15
|
from ring.ml import training_loop
|
16
16
|
from ring.utils import distribute_batchsize
|
17
17
|
from ring.utils import expand_batchsize
|
18
|
-
from ring.utils import parse_path
|
19
18
|
from ring.utils import pickle_load
|
20
19
|
import wandb
|
21
20
|
|
@@ -217,7 +216,7 @@ def train_fn(
|
|
217
216
|
|
218
217
|
callbacks_all.append(
|
219
218
|
ml_callbacks.SaveParamsTrainingLoopCallback(
|
220
|
-
path_to_file=
|
219
|
+
path_to_file=callback_save_params,
|
221
220
|
last_n_params=3,
|
222
221
|
track_metrices=callback_save_params_track_metrices,
|
223
222
|
cleanup=False,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|