imt-ring 1.4.1__tar.gz → 1.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.4.1 → imt_ring-1.5.0}/PKG-INFO +1 -1
- {imt_ring-1.4.1 → imt_ring-1.5.0}/pyproject.toml +1 -1
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/SOURCES.txt +3 -2
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/__init__.py +21 -10
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/__init__.py +1 -11
- imt_ring-1.5.0/src/ring/algorithms/generator/__init__.py +11 -0
- imt_ring-1.5.0/src/ring/algorithms/generator/base.py +375 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/batch.py +26 -109
- imt_ring-1.5.0/src/ring/algorithms/generator/finalize_fns.py +306 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/motion_artifacts.py +17 -19
- imt_ring-1.5.0/src/ring/algorithms/generator/setup_fns.py +43 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/types.py +3 -18
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/jcalc.py +0 -9
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/mujoco_render.py +2 -1
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/__init__.py +3 -4
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/batchsize.py +12 -4
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/utils.py +6 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_custom_joints.py +15 -17
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_generator.py +5 -6
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_ml_utils.py +5 -6
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_pd_control.py +9 -7
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_randomize.py +3 -2
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_rcmg.py +62 -18
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_train.py +18 -3
- imt_ring-1.4.1/src/ring/algorithms/generator/__init__.py +0 -25
- imt_ring-1.4.1/src/ring/algorithms/generator/base.py +0 -409
- imt_ring-1.4.1/src/ring/algorithms/generator/transforms.py +0 -411
- {imt_ring-1.4.1 → imt_ring-1.5.0}/readme.md +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/setup.cfg +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algebra.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/suntay.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/pd_control.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/sensors.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/base.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/maths.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/base.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/ml_utils.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/ringnet.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/train.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/spatial.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/hdf5.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/normalizer.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/path.py +0 -0
- /imt_ring-1.4.1/src/ring/algorithms/generator/randomize.py → /imt_ring-1.5.0/src/ring/utils/randomize_sys.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_algebra.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_base.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_dynamics.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_jcalc.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_jit.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_kinematics.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_maths.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_quickstart_example.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_random.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_render.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_sensors.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_sim2real.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_utils.py +0 -0
@@ -24,10 +24,10 @@ src/ring/algorithms/custom_joints/suntay.py
|
|
24
24
|
src/ring/algorithms/generator/__init__.py
|
25
25
|
src/ring/algorithms/generator/base.py
|
26
26
|
src/ring/algorithms/generator/batch.py
|
27
|
+
src/ring/algorithms/generator/finalize_fns.py
|
27
28
|
src/ring/algorithms/generator/motion_artifacts.py
|
28
29
|
src/ring/algorithms/generator/pd_control.py
|
29
|
-
src/ring/algorithms/generator/randomize.py
|
30
|
-
src/ring/algorithms/generator/transforms.py
|
30
|
+
src/ring/algorithms/generator/setup_fns.py
|
31
31
|
src/ring/algorithms/generator/types.py
|
32
32
|
src/ring/io/__init__.py
|
33
33
|
src/ring/io/examples.py
|
@@ -87,6 +87,7 @@ src/ring/utils/colab.py
|
|
87
87
|
src/ring/utils/hdf5.py
|
88
88
|
src/ring/utils/normalizer.py
|
89
89
|
src/ring/utils/path.py
|
90
|
+
src/ring/utils/randomize_sys.py
|
90
91
|
src/ring/utils/utils.py
|
91
92
|
tests/test_algebra.py
|
92
93
|
tests/test_base.py
|
@@ -20,11 +20,11 @@ from .base import System
|
|
20
20
|
from .base import Transform
|
21
21
|
|
22
22
|
|
23
|
-
def RING(lam: list[int], Ts: float | None):
|
23
|
+
def RING(lam: list[int] | None, Ts: float | None, **kwargs):
|
24
24
|
"""Creates the RING network.
|
25
25
|
|
26
26
|
Params:
|
27
|
-
lam: parent array
|
27
|
+
lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
|
28
28
|
Ts : sampling interval of IMU data; time delta in seconds
|
29
29
|
|
30
30
|
Usage:
|
@@ -55,6 +55,7 @@ def RING(lam: list[int], Ts: float | None):
|
|
55
55
|
>>>
|
56
56
|
>>> yhat, _ = ringnet.apply(X)
|
57
57
|
>>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
|
58
|
+
>>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
|
58
59
|
>>>
|
59
60
|
>>> # use `jax.jit` to compile the forward pass
|
60
61
|
>>> jit_apply = jax.jit(ringnet.apply)
|
@@ -69,13 +70,20 @@ def RING(lam: list[int], Ts: float | None):
|
|
69
70
|
from pathlib import Path
|
70
71
|
import warnings
|
71
72
|
|
73
|
+
config = dict(
|
74
|
+
use_100Hz_RING=True,
|
75
|
+
use_lpf=True,
|
76
|
+
lpf_cutoff_freq=ml._LPF_CUTOFF_FREQ,
|
77
|
+
)
|
78
|
+
config.update(kwargs)
|
79
|
+
|
72
80
|
if Ts is not None and (Ts > (1 / 40) or Ts < (1 / 200)):
|
73
81
|
warnings.warn(
|
74
82
|
"RING was only trained on sampling rates between 40 to 200 Hz "
|
75
83
|
f"but found {1 / Ts}Hz"
|
76
84
|
)
|
77
85
|
|
78
|
-
if Ts is not None and Ts == 0.01:
|
86
|
+
if Ts is not None and Ts == 0.01 and config["use_100Hz_RING"]:
|
79
87
|
# this set of parameters was trained exclusively on 100Hz data; it also
|
80
88
|
# expects F=9 features per node and not F=10 where the last features is
|
81
89
|
# the sampling interval Ts
|
@@ -86,14 +94,17 @@ def RING(lam: list[int], Ts: float | None):
|
|
86
94
|
params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")
|
87
95
|
add_Ts = True
|
88
96
|
|
89
|
-
ringnet = ml.RING(
|
90
|
-
|
91
|
-
ringnet = ml.base.LPF_FilterWrapper(
|
92
|
-
ringnet,
|
93
|
-
ml._LPF_CUTOFF_FREQ,
|
94
|
-
samp_freq=None if Ts is None else 1 / Ts,
|
95
|
-
quiet=True,
|
97
|
+
ringnet = ml.RING(
|
98
|
+
params=params, lam=None if lam is None else tuple(lam), jit=False, name="RING"
|
96
99
|
)
|
100
|
+
ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
|
101
|
+
if config["use_lpf"]:
|
102
|
+
ringnet = ml.base.LPF_FilterWrapper(
|
103
|
+
ringnet,
|
104
|
+
config["lpf_cutoff_freq"],
|
105
|
+
samp_freq=None if Ts is None else 1 / Ts,
|
106
|
+
quiet=True,
|
107
|
+
)
|
97
108
|
ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
|
98
109
|
if add_Ts:
|
99
110
|
ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
|
@@ -10,21 +10,11 @@ from .dynamics import compute_mass_matrix
|
|
10
10
|
from .dynamics import forward_dynamics
|
11
11
|
from .dynamics import inverse_dynamics
|
12
12
|
from .dynamics import step
|
13
|
-
from .generator import batch_generators_eager
|
14
|
-
from .generator import batch_generators_eager_to_list
|
15
|
-
from .generator import batch_generators_lazy
|
16
|
-
from .generator import batched_generator_from_list
|
17
|
-
from .generator import batched_generator_from_paths
|
18
13
|
from .generator import FINALIZE_FN
|
19
14
|
from .generator import Generator
|
20
|
-
from .generator import GeneratorPipe
|
21
|
-
from .generator import GeneratorTrafo
|
22
|
-
from .generator import GeneratorTrafoExpandFlatten
|
23
|
-
from .generator import GeneratorTrafoRandomizePositions
|
24
|
-
from .generator import GeneratorTrafoRemoveInputExtras
|
25
|
-
from .generator import GeneratorTrafoRemoveOutputExtras
|
26
15
|
from .generator import RCMG
|
27
16
|
from .generator import SETUP_FN
|
17
|
+
from .generator.finalize_fns import GeneratorTrafoExpandFlatten
|
28
18
|
from .jcalc import get_joint_model
|
29
19
|
from .jcalc import jcalc_motion
|
30
20
|
from .jcalc import jcalc_tau
|
@@ -0,0 +1,11 @@
|
|
1
|
+
from . import base
|
2
|
+
from . import batch
|
3
|
+
from . import finalize_fns
|
4
|
+
from . import motion_artifacts
|
5
|
+
from . import pd_control
|
6
|
+
from . import setup_fns
|
7
|
+
from . import types
|
8
|
+
from .base import RCMG
|
9
|
+
from .types import FINALIZE_FN
|
10
|
+
from .types import Generator
|
11
|
+
from .types import SETUP_FN
|
@@ -0,0 +1,375 @@
|
|
1
|
+
from typing import Callable, Optional
|
2
|
+
import warnings
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import tree_utils
|
7
|
+
|
8
|
+
from ring import base
|
9
|
+
from ring import utils
|
10
|
+
from ring.algorithms import jcalc
|
11
|
+
from ring.algorithms import kinematics
|
12
|
+
from ring.algorithms.generator import batch
|
13
|
+
from ring.algorithms.generator import finalize_fns
|
14
|
+
from ring.algorithms.generator import motion_artifacts
|
15
|
+
from ring.algorithms.generator import setup_fns
|
16
|
+
from ring.algorithms.generator import types
|
17
|
+
|
18
|
+
|
19
|
+
class RCMG:
    """Builds one motion-config-batched generator per system and serves their data.

    The per-system generators are stored in `self.gens`. They can be consumed as
    a lazy batched generator (`to_lazy_gen`), materialised into a list of
    sequences (`to_list`), saved to disk (`to_pickle`), or wrapped into an eager
    batched generator over pre-computed data (`to_eager_gen`).
    """

    # NOTE: several keyword arguments default to shared `dict` objects. The
    # defaults are kept for interface compatibility; they are forwarded to
    # `_build_mconfig_batched_generator`, whose `_copy_dicts` decorator copies
    # every dict argument before use.
    def __init__(
        self,
        sys: base.System | list[base.System],
        config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),
        setup_fn: Optional[types.SETUP_FN] = None,
        finalize_fn: Optional[types.FINALIZE_FN] = None,
        add_X_imus: bool = False,
        add_X_imus_kwargs: dict = dict(),
        add_X_jointaxes: bool = False,
        add_X_jointaxes_kwargs: dict = dict(),
        add_y_relpose: bool = False,
        add_y_rootincl: bool = False,
        sys_ml: Optional[base.System] = None,
        randomize_positions: bool = False,
        randomize_motion_artifacts: bool = False,
        randomize_joint_params: bool = False,
        imu_motion_artifacts: bool = False,
        imu_motion_artifacts_kwargs: dict = dict(hide_injected_bodies=True),
        dynamic_simulation: bool = False,
        dynamic_simulation_kwargs: dict = dict(),
        output_transform: Optional[Callable] = None,
        keep_output_extras: bool = False,
        use_link_number_in_Xy: bool = False,
        cor: bool = False,
        disable_tqdm: bool = False,
    ) -> None:
        """Create one generator per system in `sys`.

        Args:
            sys: a single system or a list of systems; one generator is built
                for each of them.
            config: a single motion configuration or a list of them; every
                generator covers all of them. Each one must report
                `is_feasible()`.
            sys_ml: system handed to the generator builder as `sys_ml`;
                defaults to the first entry of `sys` (taken before the `cor`
                replacement is applied).
            cor: if set, every system is replaced by the result of its
                `_replace_free_with_cor()` method.
            disable_tqdm: forwarded to `batch.generators_eager_to_list` when
                data is materialised in `to_list`.

        All remaining arguments are forwarded unchanged to
        `_build_mconfig_batched_generator`.
        """
        sys, config = utils.to_list(sys), utils.to_list(config)
        sys_ml = sys[0] if sys_ml is None else sys_ml

        for c in config:
            assert c.is_feasible()

        if cor:
            sys = [s._replace_free_with_cor() for s in sys]

        self.gens = [
            _build_mconfig_batched_generator(
                sys=_sys,
                config=config,
                setup_fn=setup_fn,
                finalize_fn=finalize_fn,
                add_X_imus=add_X_imus,
                add_X_imus_kwargs=add_X_imus_kwargs,
                add_X_jointaxes=add_X_jointaxes,
                add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
                add_y_relpose=add_y_relpose,
                add_y_rootincl=add_y_rootincl,
                sys_ml=sys_ml,
                randomize_positions=randomize_positions,
                randomize_motion_artifacts=randomize_motion_artifacts,
                randomize_joint_params=randomize_joint_params,
                imu_motion_artifacts=imu_motion_artifacts,
                imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
                dynamic_simulation=dynamic_simulation,
                dynamic_simulation_kwargs=dynamic_simulation_kwargs,
                output_transform=output_transform,
                keep_output_extras=keep_output_extras,
                use_link_number_in_Xy=use_link_number_in_Xy,
            )
            for _sys in sys
        ]

        # every generator yields one sequence per motion configuration
        self._n_mconfigs = len(config)
        self._size_of_generators = [self._n_mconfigs] * len(self.gens)

        self._disable_tqdm = disable_tqdm

    def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
        """Return how often each generator is repeated to create `sizes` sequences.

        Args:
            sizes: either a total size that is split evenly across all
                generators, or one size per generator.

        Raises:
            AssertionError: if a size is too small, is not divisible by the
                number of sequences a generator produces per call, or if the
                number of sizes differs from the number of generators.
        """
        S, L = sum(self._size_of_generators), len(self._size_of_generators)

        def assert_size(size: int):
            # Divisibility is what the message states and what the integer
            # branch below checks; membership in the prime factors of `size`
            # would not express it for composite numbers of motion configs.
            assert size % self._n_mconfigs == 0, (
                f"`size`={size} is not divisible by number of "
                + f"`mconfigs`={self._n_mconfigs}"
            )

        if isinstance(sizes, int):
            assert (sizes // S) > 0, f"Batchsize or size too small. {sizes} < {S}"
            assert sizes % S == 0, f"`size`={sizes} not divisible by {S}"
            repeats = L * [sizes // S]
        else:
            for size in sizes:
                assert_size(size)

            assert len(sizes) == len(
                self.gens
            ), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"

            repeats = [
                size // size_of_gen
                for size, size_of_gen in zip(sizes, self._size_of_generators)
            ]
        assert 0 not in repeats

        return repeats

    def to_lazy_gen(
        self, sizes: int | list[int] = 1, jit: bool = True
    ) -> types.BatchedGenerator:
        """Combine all generators into one lazy generator via `batch.generators_lazy`."""
        return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)

    @staticmethod
    def _number_of_executions_required(size: int) -> int:
        """Return the number of calls needed so that one call stays below the threshold.

        Prime factors of the second value returned by
        `utils.distribute_batchsize(size)` are moved into the call count until
        the remainder no longer exceeds `utils.batchsize_thresholds()[1]`.
        """
        _, vmap = utils.distribute_batchsize(size)

        eager_threshold = utils.batchsize_thresholds()[1]
        primes = iter(utils.primes(vmap))
        n_calls = 1
        while vmap > eager_threshold:
            prime = next(primes)
            n_calls *= prime
            # `prime` divides `vmap`; floor division keeps the count an integer
            vmap //= prime

        return n_calls

    def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
        "Returns list of unbatched sequences as numpy arrays."
        repeats = self._compute_repeats(sizes)
        # plain Python integers; no array round-trip is needed for this product
        sizes = [
            repeat * size_of_gen
            for repeat, size_of_gen in zip(repeats, self._size_of_generators)
        ]

        reduced_repeats = []
        n_calls = []
        for size, repeat in zip(sizes, repeats):
            n_call = self._number_of_executions_required(size)
            # only a number of calls that divides `repeat` can be factored out
            gcd = utils.gcd(n_call, repeat)
            n_calls.append(gcd)
            reduced_repeats.append(repeat // gcd)

        # a generator is only jitted when it is executed more than once
        gens = [
            batch.generators_lazy([gen], [reduced_repeat], n_call > 1)
            for gen, reduced_repeat, n_call in zip(self.gens, reduced_repeats, n_calls)
        ]

        return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)

    def to_pickle(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        overwrite: bool = True,
    ) -> None:
        """Generate the data of `to_list`, batch it, and save it with `utils.pickle_save`."""
        data = tree_utils.tree_batch(self.to_list(sizes, seed))
        utils.pickle_save(data, path, overwrite=overwrite)

    def to_eager_gen(
        self,
        batchsize: int = 1,
        sizes: int | list[int] = 1,
        seed: int = 1,
        shuffle: bool = True,
    ) -> types.BatchedGenerator:
        """Pre-compute all sequences and return a generator that batches them.

        Raises:
            AssertionError: if fewer sequences than `batchsize` were generated.
        """
        data = self.to_list(sizes, seed)
        assert len(data) >= batchsize

        def data_fn(indices: list[int]):
            return tree_utils.tree_batch([data[i] for i in indices])

        return batch.generator_from_data_fn(
            data_fn, list(range(len(data))), shuffle, batchsize
        )

    @staticmethod
    def eager_gen_from_paths(
        paths: str | list[str],
        batchsize: int,
        include_samples: Optional[list[int]] = None,
        shuffle: bool = True,
        load_all_into_memory: bool = False,
        tree_transform=None,
    ) -> tuple[types.BatchedGenerator, int]:
        """Build a batched generator from one or several paths.

        A single path is wrapped into a list; everything else is forwarded to
        `batch.generator_from_paths`.
        """
        paths = utils.to_list(paths)
        return batch.generator_from_paths(
            paths,
            batchsize,
            include_samples,
            shuffle,
            load_all_into_memory=load_all_into_memory,
            tree_transform=tree_transform,
        )
|
206
|
+
|
207
|
+
|
208
|
+
def _copy_dicts(f) -> dict:
|
209
|
+
def _f(*args, **kwargs):
|
210
|
+
_copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
|
211
|
+
args = tuple([_copy(ele) for ele in args])
|
212
|
+
kwargs = {k: _copy(v) for k, v in kwargs.items()}
|
213
|
+
return f(*args, **kwargs)
|
214
|
+
|
215
|
+
return _f
|
216
|
+
|
217
|
+
|
218
|
+
@_copy_dicts
def _build_mconfig_batched_generator(
    sys: base.System,
    config: list[jcalc.MotionConfig],
    setup_fn: types.SETUP_FN | None,
    finalize_fn: types.FINALIZE_FN | None,
    add_X_imus: bool,
    add_X_imus_kwargs: dict,
    add_X_jointaxes: bool,
    add_X_jointaxes_kwargs: dict,
    add_y_relpose: bool,
    add_y_rootincl: bool,
    sys_ml: base.System,
    randomize_positions: bool,
    randomize_motion_artifacts: bool,
    randomize_joint_params: bool,
    imu_motion_artifacts: bool,
    imu_motion_artifacts_kwargs: dict,
    dynamic_simulation: bool,
    dynamic_simulation_kwargs: dict,
    output_transform: Callable | None,
    keep_output_extras: bool,
    use_link_number_in_Xy: bool,
) -> types.BatchedGenerator:
    """Build a generator that produces one sequence per motion configuration.

    The returned function takes a PRNG key, draws one random joint trajectory
    per entry of `config`, applies the setup pipeline to `sys` once per
    configuration, and runs the finalize pipeline on every
    (key, trajectory, transforms, system) tuple.

    Dict arguments are copied by the `_copy_dicts` decorator, so adding the
    `unactuated_subsystems` entry below does not leak into the caller's dict.

    Returns:
        A function `key -> output` where `output` is `Xy`, or `(Xy, extras)`
        if `keep_output_extras` is set, passed through `output_transform` when
        one is given.

    Raises:
        AssertionError: if `imu_motion_artifacts` is enabled without
            `dynamic_simulation`, if `dynamic_simulation_kwargs` already
            contains `unactuated_subsystems`, or if `prob_rigid` is given
            without `randomize_motion_artifacts`.
    """
    # `Names2Indices` (enabled by `use_link_number_in_Xy`) consumes `sys_noimu`
    # as well, so that flag has to trigger its construction too; otherwise
    # enabling it alone fails with an unbound `sys_noimu`.
    sys_noimu = None
    if add_X_jointaxes or add_y_relpose or add_y_rootincl or use_link_number_in_Xy:
        if len(sys_ml.findall_imus()) > 0:
            # IMUs are removed from `sys_ml` without emitting a warning
            sys_noimu, _ = sys_ml.make_sys_noimu()
        else:
            sys_noimu = sys_ml

    unactuated_subsystems = []
    if imu_motion_artifacts:
        assert dynamic_simulation
        unactuated_subsystems = motion_artifacts.unactuated_subsystem(sys)
        sys = motion_artifacts.inject_subsystems(sys, **imu_motion_artifacts_kwargs)
        assert "unactuated_subsystems" not in dynamic_simulation_kwargs
        dynamic_simulation_kwargs["unactuated_subsystems"] = unactuated_subsystems

        if not randomize_motion_artifacts:
            warnings.warn(
                "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
            )

        if "prob_rigid" in imu_motion_artifacts_kwargs:
            assert randomize_motion_artifacts, (
                "`prob_rigid` works by overwriting damping and stiffness parameters "
                "using the `randomize_motion_artifacts` flag, so it must be enabled."
            )

    def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
        # system modifications, applied in this order with a fresh key each
        pipe = []
        if imu_motion_artifacts and randomize_motion_artifacts:
            pipe.append(
                motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
                    **imu_motion_artifacts_kwargs
                )
            )
        if randomize_positions:
            pipe.append(setup_fns._setup_fn_randomize_positions)
        if randomize_joint_params:
            pipe.append(jcalc._init_joint_params)
        if setup_fn is not None:
            pipe.append(setup_fn)

        for f in pipe:
            key, consume = jax.random.split(key)
            sys = f(consume, sys)
        return sys

    def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
        # post-processing steps, applied in this order to `(Xy, extras)`
        pipe = []
        if dynamic_simulation:
            pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
        if imu_motion_artifacts and imu_motion_artifacts_kwargs["hide_injected_bodies"]:
            pipe.append(motion_artifacts.HideInjectedBodies())
        if finalize_fn is not None:
            pipe.append(finalize_fns.FinalizeFn(finalize_fn))
        if add_X_imus:
            pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
        if add_X_jointaxes:
            pipe.append(
                finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
            )
        if add_y_relpose:
            pipe.append(finalize_fns.RelPose(sys_noimu))
        if add_y_rootincl:
            pipe.append(finalize_fns.RootIncl(sys_noimu))
        if use_link_number_in_Xy:
            pipe.append(finalize_fns.Names2Indices(sys_noimu))

        for f in pipe:
            Xy, extras = f(Xy, extras)
        return Xy, extras

    def _gen(key: types.PRNGKey):
        # one random trajectory per motion configuration, stacked on axis 0
        qs = []
        for _config in config:
            key, _q = draw_random_q(key, sys, _config)
            qs.append(_q)
        qs = jnp.stack(qs)

        # one independently set-up system per motion configuration
        key, *consume = jax.random.split(key, len(config) + 1)
        syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)

        @jax.vmap
        def _vmapped_context(key, q, sys):
            # inner vmap maps over the leading axis of `q` for a fixed system
            x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
            Xy, extras = ({}, {}), (key, q, x, sys)
            return _finalize_fn(Xy, extras)

        keys = jax.random.split(key, len(config))
        Xy, extras = _vmapped_context(keys, qs, syss)
        output = (Xy, extras) if keep_output_extras else Xy
        output = output if output_transform is None else output_transform(output)
        return output

    return _gen
|
337
|
+
|
338
|
+
|
339
|
+
def draw_random_q(
    key: types.PRNGKey,
    sys: base.System,
    config: jcalc.MotionConfig,
) -> tuple[types.PRNGKey, jax.Array]:
    """Draw a random trajectory of generalized coordinates for every link of `sys`.

    For each link the draw function registered for its joint type
    (`jcalc.get_joint_model(link_type).rcmg_draw_fn`) is called, and the
    per-link results are concatenated along axis 1.

    Args:
        key: PRNG key that seeds the draws.
        sys: system whose links, link types and `dt` are used.
        config: motion configuration passed on to every draw function.

    Returns:
        A tuple `(key, q)` of a PRNG key that was not consumed by any draw and
        the concatenated coordinates `q`.

    Raises:
        Exception: if a joint type has no draw function specified.
    """
    key_start = key
    # build generalized coordinates vector `q`
    q_list = []

    def draw_q(key, __, link_type, link):
        joint_params = link.joint_params
        # limit scope
        joint_params = (
            joint_params[link_type]
            if link_type in joint_params
            else joint_params["default"]
        )
        # no key has been handed over yet; start from the caller's key
        if key is None:
            key = key_start
        key, key_t, key_value = jax.random.split(key, 3)
        draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
        if draw_fn is None:
            raise Exception(f"The joint type {link_type} has no draw fn specified.")
        q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
        # even revolute and prismatic joints must be 2d arrays
        q_link = q_link if q_link.ndim == 2 else q_link[:, None]
        q_list.append(q_link)
        return key

    keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
    # stack of keys; only the last key is unused
    key = keys[-1]

    q = jnp.concatenate(q_list, axis=1)

    return key, q
|