imt-ring 1.6.11__tar.gz → 1.6.12__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.6.11 → imt_ring-1.6.12}/PKG-INFO +1 -1
- {imt_ring-1.6.11 → imt_ring-1.6.12}/pyproject.toml +1 -1
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/SOURCES.txt +1 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/base.py +3 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/finalize_fns.py +17 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/jcalc.py +10 -1
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/sensors.py +1 -1
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/base.py +3 -1
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/ml_utils.py +14 -21
- imt_ring-1.6.12/src/ring/utils/dataloader.py +159 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/utils.py +13 -3
- {imt_ring-1.6.11 → imt_ring-1.6.12}/readme.md +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/setup.cfg +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algebra.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/suntay.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/batch.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/pd_control.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/setup_fns.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/maths.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/base.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/ringnet.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/train.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/mujoco_render.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/spatial.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/batchsize.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/hdf5.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/normalizer.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/path.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/randomize_sys.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/register_gym_envs/__init__.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/register_gym_envs/saddle.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_algebra.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_base.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_dynamics.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_generator.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_jcalc.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_jit.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_kinematics.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_maths.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_ml_utils.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_pd_control.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_quickstart_example.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_random.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_randomize.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_rcmg.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_render.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_sensors.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_sim2real.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_train.py +0 -0
- {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_utils.py +0 -0
@@ -321,6 +321,9 @@ def _build_mconfig_batched_generator(
|
|
321
321
|
"using the `randomize_motion_artifacts` flag, so it must be enabled."
|
322
322
|
)
|
323
323
|
|
324
|
+
if dynamic_simulation:
|
325
|
+
finalize_fns.DynamicalSimulation.assert_test_system(sys)
|
326
|
+
|
324
327
|
def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
|
325
328
|
pipe = []
|
326
329
|
if imu_motion_artifacts and randomize_motion_artifacts:
|
@@ -180,6 +180,23 @@ class DynamicalSimulation:
|
|
180
180
|
self.overwrite_q_ref = overwrite_q_ref
|
181
181
|
self.unroll_kwargs = unroll_kwargs
|
182
182
|
|
183
|
+
@staticmethod
def assert_test_system(sys: base.System) -> None:
    """Check that `sys` has no zero-mass bodies and no undamped joints.

    Dynamic simulation of such systems can produce NaNs, so this is meant to be
    called before enabling it.

    The checks are done with explicit `raise AssertionError` instead of
    `assert` statements so that they still run under `python -O`.

    Raises:
        AssertionError: If a link with a non-empty damping array has a
            non-positive mass, or has any damping value <= 0.
    """

    def f(_, __, n, m, d):
        # Links whose damping array is empty are exempt from both checks.
        if not (d.size == 0 or m > 0):
            raise AssertionError(
                "Dynamic simulation is set to `True` which requires masses > 0, "
                f"but found body `{n}` with mass={float(m[0])}. This can lead to NaNs."
            )

        if not (d.size == 0 or all(d > 0.0)):
            raise AssertionError(
                "Dynamic simulation is set to `True` which requires dampings > 0, "
                f"but found body `{n}` with damping={d}. This can lead to NaNs."
            )

    # NOTE(review): the "lld" type string presumably makes `scan` hand `f` one
    # link name, that link's mass and that link's damping slice per call —
    # confirm against `System.scan`.
    sys.scan(f, "lld", sys.link_names, sys.links.inertia.mass, sys.link_damping)
|
199
|
+
|
183
200
|
def __call__(
|
184
201
|
self, Xy: types.Xy, extras: types.OutputExtras
|
185
202
|
) -> tuple[types.Xy, types.OutputExtras]:
|
@@ -205,7 +205,7 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
|
|
205
205
|
return False
|
206
206
|
return True
|
207
207
|
|
208
|
-
|
208
|
+
cond1 = all(
|
209
209
|
[
|
210
210
|
dx_deltax_check(*args)
|
211
211
|
for args in zip(
|
@@ -217,6 +217,15 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
|
|
217
217
|
]
|
218
218
|
)
|
219
219
|
|
220
|
+
# this one tests that the initial value is inside the feasible value range
|
221
|
+
# so e.g. if you choose pos0_min=-10 then you can't choose pos_min=-1
|
222
|
+
def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
|
223
|
+
return (x0_min >= x_min) and (x0_max <= x_max)
|
224
|
+
|
225
|
+
cond2 = inside_box_checks(c.pos_min, c.pos_max, c.pos0_min, c.pos0_max)
|
226
|
+
|
227
|
+
return cond1 and cond2
|
228
|
+
|
220
229
|
|
221
230
|
def _find_interval(t: jax.Array, boundaries: jax.Array):
|
222
231
|
"""Find the interval of `boundaries` between which `t` lies.
|
@@ -131,7 +131,7 @@ def magnetometer(rot: jax.Array, magvec: jax.Array) -> jax.Array:
|
|
131
131
|
# - gyr: rad/s
|
132
132
|
# - mag: a.u.
|
133
133
|
NOISE_LEVELS = {"acc": 0.048, "gyr": jnp.deg2rad(0.7), "mag": 0.01}
|
134
|
-
BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3
|
134
|
+
BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3), "mag": 0.0}
|
135
135
|
|
136
136
|
|
137
137
|
def add_noise_bias(
|
@@ -690,7 +690,9 @@ class System(_Base):
|
|
690
690
|
transparent_segment_to_root: bool = True,
|
691
691
|
**kwargs,
|
692
692
|
):
|
693
|
-
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.
|
693
|
+
"""`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.
|
694
|
+
Note that the body in yhat that connects to -1, is parent-to-child!
|
695
|
+
"""
|
694
696
|
return ring.rendering.render_prediction(
|
695
697
|
self, xs, yhat, transparent_segment_to_root, **kwargs
|
696
698
|
)
|
@@ -12,7 +12,6 @@ import numpy as np
|
|
12
12
|
from tree_utils import PyTree
|
13
13
|
|
14
14
|
import ring
|
15
|
-
from ring.utils import import_lib
|
16
15
|
import wandb
|
17
16
|
|
18
17
|
# An arbitrarily nested dictionary with Array leaves; Or strings
|
@@ -190,36 +189,30 @@ def unique_id() -> str:
|
|
190
189
|
|
191
190
|
def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
192
191
|
from jax.experimental import jax2tf
|
192
|
+
import tensorflow as tf
|
193
193
|
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
signature = jax.tree_map(
|
198
|
-
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
199
|
-
)
|
194
|
+
signature = jax.tree_map(
|
195
|
+
lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
|
196
|
+
)
|
200
197
|
|
201
|
-
|
202
|
-
def __init__(self, jax_func):
|
203
|
-
super().__init__()
|
204
|
-
self.tf_func = jax2tf.convert(jax_func, with_gradient=False)
|
198
|
+
tf_func = jax2tf.convert(jax_func, with_gradient=False)
|
205
199
|
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
)
|
212
|
-
def __call__(self, *args):
|
213
|
-
return self.tf_func(*args)
|
200
|
+
class RingTFModule(tf.Module):
|
201
|
+
@partial(
|
202
|
+
tf.function, autograph=False, jit_compile=True, input_signature=signature
|
203
|
+
)
|
204
|
+
def __call__(self, *args):
|
205
|
+
return tf_func(*args)
|
214
206
|
|
215
|
-
|
207
|
+
model = RingTFModule()
|
216
208
|
|
217
|
-
model = _create_module(jax_func, input)
|
218
209
|
tf.saved_model.save(
|
219
210
|
model,
|
220
211
|
path,
|
221
212
|
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
|
213
|
+
signatures={"default": model.__call__},
|
222
214
|
)
|
215
|
+
|
223
216
|
if validate:
|
224
217
|
output_jax = jax_func(*input)
|
225
218
|
output_tf = tf.saved_model.load(path)(*input)
|
@@ -0,0 +1,159 @@
|
|
1
|
+
import os
|
2
|
+
import random
|
3
|
+
from typing import Callable, Optional
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import numpy as np
|
7
|
+
from ring.utils import parse_path
|
8
|
+
from ring.utils import pickle_load
|
9
|
+
import torch
|
10
|
+
from torch.utils.data import DataLoader
|
11
|
+
from torch.utils.data import Dataset
|
12
|
+
import tqdm
|
13
|
+
from tree_utils import PyTree
|
14
|
+
|
15
|
+
|
16
|
+
def make_generator(
    *paths,
    batch_size,
    transform,
    shuffle=True,
    seed: int = 1,
    backend: str = "eager",
    **kwargs,
):
    """Build a batch generator over the files stored in the directories `paths`.

    Dispatches to one of this module's backend-specific factories.

    Args:
        *paths: Directories whose files make up the dataset.
        batch_size: Number of samples per batch.
        transform: Forwarded unchanged to the selected backend factory.
        shuffle: Forwarded unchanged to the selected backend factory.
        seed: Forwarded unchanged to the selected backend factory.
        backend: One of "eager" (default), "torch" or "grain".
        **kwargs: Extra keyword arguments forwarded to the backend factory.
            Note that the "eager" backend's factory accepts no extra keyword
            arguments.

    Returns:
        Whatever the selected backend factory returns.

    Raises:
        NotImplementedError: If `backend` is not a supported backend name.
    """
    if backend == "grain":
        _make_gen = pygrain_generator
    elif backend == "torch":
        _make_gen = pytorch_generator
    elif backend == "eager":
        _make_gen = eager_generator
    else:
        # The exception type is kept for backward compatibility. The message
        # tells the caller what went wrong and which backends exist.
        raise NotImplementedError(
            f"Unknown backend `{backend}`; expected one of 'eager', 'torch', 'grain'."
        )

    return _make_gen(
        *paths,
        batch_size=batch_size,
        transform=transform,
        shuffle=shuffle,
        seed=seed,
        **kwargs,
    )
|
42
|
+
|
43
|
+
|
44
|
+
# Type alias used to annotate the `transform` callables in this module:
# a pytree whose leaves are numpy arrays.
T = PyTree[np.ndarray]
|
45
|
+
|
46
|
+
|
47
|
+
class _Dataset(Dataset):
    """Map-style dataset that pairs the i-th file of every directory in `paths`.

    Item `idx` is a list with one loaded object per directory, in the order the
    directories were given. If `transform` is not None, it is applied to that list.
    """

    def __init__(self, *paths, transform):
        # One file list per directory. All directories must hold the same
        # number of files.
        self.files = [self.listdir(path) for path in paths]
        Ns = {len(f) for f in self.files}
        assert len(Ns) == 1, f"{Ns}"

        self.P = len(self.files)  # number of directories
        self.N = list(Ns)[0]  # number of files per directory
        self.transform = transform

    def __len__(self):
        return self.N

    def __getitem__(self, idx: int):
        # NOTE(review): `pickle_load` presumably unpickles the file; only point
        # this dataset at trusted data.
        element = [pickle_load(self.files[p][idx]) for p in range(self.P)]
        if self.transform is not None:
            element = self.transform(element)
        return element

    @staticmethod
    def listdir(path: str) -> list:
        """Return the full paths of all entries of `path`, sorted by name.

        `os.listdir` returns entries in arbitrary order. Sorting makes index i
        refer to the same file name in every directory, which the pairing in
        `__getitem__` relies on. It also makes the order reproducible.
        """
        return [parse_path(path, file) for file in sorted(os.listdir(path))]

    def __call__(self, idx: int):
        # Lets the dataset be used wherever a callable `idx -> item` is expected.
        return self[idx]
|
73
|
+
|
74
|
+
|
75
|
+
class TransformTransform:
    """Adapt a two-argument `(element, rng)` transform into a one-argument callable.

    If the wrapped transform is None, the adapter returns its input unchanged.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, element):
        fn = self.transform
        if fn is not None:
            # A brand-new `np.random.Generator` (fresh OS entropy, no seed) is
            # created on every call.
            return fn(element, np.random.default_rng())
        return element
|
83
|
+
|
84
|
+
|
85
|
+
def pytorch_generator(
    *paths,
    batch_size: int,
    transform: Optional[Callable[[T], T]] = None,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Create an endless batch generator backed by a `torch.utils.data.DataLoader`.

    Args:
        *paths: Directories passed to `_Dataset`.
        batch_size: DataLoader batch size.
        transform: Optional `(element, rng) -> element` callable applied per item.
        shuffle: Whether the DataLoader shuffles each epoch.
        seed: Passed to `torch.manual_seed` before the DataLoader is built.
        **kwargs: Extra DataLoader keyword arguments. If
            `multiprocessing_context` is not given, "spawn" is used when
            `num_workers > 0`, otherwise None.

    Returns:
        A function that ignores its single argument and returns the next batch
        with every tensor leaf converted to a numpy array. When an epoch is
        exhausted, the DataLoader is restarted.
    """
    torch.manual_seed(seed)

    ds = _Dataset(*paths, transform=TransformTransform(transform))

    # Respect an explicit caller choice; passing it alongside a hard-coded
    # value previously raised "got multiple values for keyword argument".
    if "multiprocessing_context" not in kwargs:
        kwargs["multiprocessing_context"] = (
            "spawn" if kwargs.get("num_workers", 0) > 0 else None
        )

    dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: PyTree[torch.Tensor]):
        # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
        # stable spelling.
        return jax.tree_util.tree_map(lambda tensor: tensor.numpy(), tree)

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Epoch finished: start a new pass over the DataLoader.
            dl_iter = iter(dl)
            batch = next(dl_iter)
        return to_numpy(batch)

    return generator
|
117
|
+
|
118
|
+
|
119
|
+
def eager_generator(
    *paths,
    batch_size: int,
    transform: Optional[Callable[[T], T]] = None,
    shuffle=True,
    seed=1,
):
    """Load every sample into memory up front, then batch from that list.

    `seed` is used only to seed Python's global `random` module. Whether
    `RCMG.eager_gen_from_list` shuffles with `random` is not visible here.
    Transforms receive an unseeded numpy Generator via `TransformTransform`.
    """
    from ring import RCMG

    random.seed(seed)

    dataset = _Dataset(*paths, transform=TransformTransform(transform))
    n_items = len(dataset)
    samples = [dataset[idx] for idx in tqdm.tqdm(range(n_items), total=n_items)]
    return RCMG.eager_gen_from_list(samples, batch_size, shuffle=shuffle)
|
133
|
+
|
134
|
+
|
135
|
+
def pygrain_generator(
    *paths, batch_size: int, transform=None, shuffle=True, seed=1, **kwargs
):
    """Create a batch generator backed by `grain.python.load`.

    Args:
        *paths: Directories passed to `_Dataset`.
        batch_size: Grain batch size.
        transform: Optional `(element, rng) -> element` callable. When given,
            it is applied through a grain `RandomMapTransform`, which supplies
            the rng.
        shuffle: Forwarded to `pygrain.load`.
        seed: Forwarded to `pygrain.load`.
        **kwargs: Extra keyword arguments forwarded to `pygrain.load`.

    Returns:
        A function that ignores its single argument and returns the next batch.
    """
    import grain.python as pygrain  # type: ignore

    class _Transform(pygrain.RandomMapTransform):
        def random_map(self, element, rng: np.random.Generator):
            return transform(element, rng)

    # `transform` defaults to None. Registering `_Transform` unconditionally
    # would call `None(element, rng)` and crash, so only add it when a
    # transform was actually supplied.
    transformations = [] if transform is None else [_Transform()]

    ds = _Dataset(*paths, transform=None)
    dl = pygrain.load(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        transformations=transformations,
        **kwargs,
    )
    iter_dl = iter(dl)

    def generator(_):
        return next(iter_dl)

    return generator
|
@@ -3,6 +3,7 @@ import io
|
|
3
3
|
import pickle
|
4
4
|
import random
|
5
5
|
from typing import Optional
|
6
|
+
import warnings
|
6
7
|
|
7
8
|
import jax
|
8
9
|
import jax.numpy as jnp
|
@@ -195,7 +196,7 @@ def replace_elements_w_nans(
|
|
195
196
|
assert min(include_elements) >= 0
|
196
197
|
assert max(include_elements) < len(list_of_data)
|
197
198
|
|
198
|
-
def _is_nan(ele: tree_utils.PyTree, i: int):
|
199
|
+
def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool):
|
199
200
|
isnan = np.any(
|
200
201
|
[np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)]
|
201
202
|
)
|
@@ -205,13 +206,22 @@ def replace_elements_w_nans(
|
|
205
206
|
return True
|
206
207
|
return False
|
207
208
|
|
209
|
+
list_of_isnan = [int(_is_nan(e, 0, False)) for e in list_of_data]
|
210
|
+
perc_of_isnan = sum(list_of_isnan) / len(list_of_data)
|
211
|
+
|
212
|
+
if perc_of_isnan >= 0.02:
|
213
|
+
warnings.warn(
|
214
|
+
f"{perc_of_isnan * 100}% of {len(list_of_data)} datapoints are NaN"
|
215
|
+
)
|
216
|
+
assert perc_of_isnan != 1
|
217
|
+
|
208
218
|
list_of_data_nonan = []
|
209
219
|
for i, ele in enumerate(list_of_data):
|
210
|
-
if _is_nan(ele, i):
|
220
|
+
if _is_nan(ele, i, verbose):
|
211
221
|
while True:
|
212
222
|
j = random.choice(include_elements)
|
213
223
|
ele_j = list_of_data[j]
|
214
|
-
if not _is_nan(ele_j, j):
|
224
|
+
if not _is_nan(ele_j, j, verbose):
|
215
225
|
ele = pytree_deepcopy(ele_j)
|
216
226
|
break
|
217
227
|
list_of_data_nonan.append(ele)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml
RENAMED
File without changes
|
{imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|