imt-ring 1.4.1__tar.gz → 1.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. {imt_ring-1.4.1 → imt_ring-1.5.1}/PKG-INFO +1 -1
  2. {imt_ring-1.4.1 → imt_ring-1.5.1}/pyproject.toml +1 -1
  3. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/imt_ring.egg-info/SOURCES.txt +3 -2
  5. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/__init__.py +21 -10
  6. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/__init__.py +1 -11
  7. imt_ring-1.5.1/src/ring/algorithms/generator/__init__.py +11 -0
  8. imt_ring-1.5.1/src/ring/algorithms/generator/base.py +378 -0
  9. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/generator/batch.py +26 -109
  10. imt_ring-1.5.1/src/ring/algorithms/generator/finalize_fns.py +306 -0
  11. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/generator/motion_artifacts.py +31 -19
  12. imt_ring-1.5.1/src/ring/algorithms/generator/setup_fns.py +43 -0
  13. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/generator/types.py +3 -18
  14. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/jcalc.py +0 -9
  15. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/rendering/mujoco_render.py +2 -1
  16. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/__init__.py +3 -4
  17. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/batchsize.py +12 -4
  18. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/utils.py +6 -0
  19. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_custom_joints.py +18 -20
  20. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_generator.py +5 -6
  21. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_ml_utils.py +5 -6
  22. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_pd_control.py +10 -8
  23. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_randomize.py +3 -2
  24. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_rcmg.py +62 -18
  25. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_train.py +18 -3
  26. imt_ring-1.4.1/src/ring/algorithms/generator/__init__.py +0 -25
  27. imt_ring-1.4.1/src/ring/algorithms/generator/base.py +0 -409
  28. imt_ring-1.4.1/src/ring/algorithms/generator/transforms.py +0 -411
  29. {imt_ring-1.4.1 → imt_ring-1.5.1}/readme.md +0 -0
  30. {imt_ring-1.4.1 → imt_ring-1.5.1}/setup.cfg +0 -0
  31. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  32. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/imt_ring.egg-info/requires.txt +0 -0
  33. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/imt_ring.egg-info/top_level.txt +0 -0
  34. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algebra.py +0 -0
  35. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/_random.py +0 -0
  36. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  37. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  38. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  39. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  40. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/dynamics.py +0 -0
  41. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/generator/pd_control.py +0 -0
  42. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/kinematics.py +0 -0
  43. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/algorithms/sensors.py +0 -0
  44. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/base.py +0 -0
  45. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/__init__.py +0 -0
  46. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/branched.xml +0 -0
  47. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  48. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  49. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  50. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/inv_pendulum.xml +0 -0
  51. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  52. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/spherical_stiff.xml +0 -0
  53. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/symmetric.xml +0 -0
  54. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_all_1.xml +0 -0
  55. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_all_2.xml +0 -0
  56. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  57. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_control.xml +0 -0
  58. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  59. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_free.xml +0 -0
  60. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_kinematics.xml +0 -0
  61. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  62. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  63. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_randomize_position.xml +0 -0
  64. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_sensors.xml +0 -0
  65. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  66. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/examples.py +0 -0
  67. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/test_examples.py +0 -0
  68. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/xml/__init__.py +0 -0
  69. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/xml/abstract.py +0 -0
  70. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/xml/from_xml.py +0 -0
  71. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/xml/test_from_xml.py +0 -0
  72. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/xml/test_to_xml.py +0 -0
  73. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/io/xml/to_xml.py +0 -0
  74. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/maths.py +0 -0
  75. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/__init__.py +0 -0
  76. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/base.py +0 -0
  77. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/callbacks.py +0 -0
  78. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/ml_utils.py +0 -0
  79. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/optimizer.py +0 -0
  80. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  81. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  82. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/ringnet.py +0 -0
  83. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/rnno_v1.py +0 -0
  84. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/train.py +0 -0
  85. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/ml/training_loop.py +0 -0
  86. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/rendering/__init__.py +0 -0
  87. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/rendering/base_render.py +0 -0
  88. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/rendering/vispy_render.py +0 -0
  89. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/rendering/vispy_visuals.py +0 -0
  90. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/sim2real/__init__.py +0 -0
  91. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/sim2real/sim2real.py +0 -0
  92. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/spatial.py +0 -0
  93. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/sys_composer/__init__.py +0 -0
  94. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/sys_composer/delete_sys.py +0 -0
  95. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/sys_composer/inject_sys.py +0 -0
  96. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/sys_composer/morph_sys.py +0 -0
  97. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/backend.py +0 -0
  98. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/colab.py +0 -0
  99. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/hdf5.py +0 -0
  100. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/normalizer.py +0 -0
  101. {imt_ring-1.4.1 → imt_ring-1.5.1}/src/ring/utils/path.py +0 -0
  102. /imt_ring-1.4.1/src/ring/algorithms/generator/randomize.py → /imt_ring-1.5.1/src/ring/utils/randomize_sys.py +0 -0
  103. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_algebra.py +0 -0
  104. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_base.py +0 -0
  105. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_dynamics.py +0 -0
  106. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_jcalc.py +0 -0
  107. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_jit.py +0 -0
  108. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_kinematics.py +0 -0
  109. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_maths.py +0 -0
  110. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_motion_artifacts.py +0 -0
  111. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_quickstart_example.py +0 -0
  112. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_random.py +0 -0
  113. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_render.py +0 -0
  114. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_sensors.py +0 -0
  115. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_sim2real.py +0 -0
  116. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_sys_composer.py +0 -0
  117. {imt_ring-1.4.1 → imt_ring-1.5.1}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.4.1
3
+ Version: 1.5.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.4.1"
7
+ version = "1.5.1"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.4.1
3
+ Version: 1.5.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -24,10 +24,10 @@ src/ring/algorithms/custom_joints/suntay.py
24
24
  src/ring/algorithms/generator/__init__.py
25
25
  src/ring/algorithms/generator/base.py
26
26
  src/ring/algorithms/generator/batch.py
27
+ src/ring/algorithms/generator/finalize_fns.py
27
28
  src/ring/algorithms/generator/motion_artifacts.py
28
29
  src/ring/algorithms/generator/pd_control.py
29
- src/ring/algorithms/generator/randomize.py
30
- src/ring/algorithms/generator/transforms.py
30
+ src/ring/algorithms/generator/setup_fns.py
31
31
  src/ring/algorithms/generator/types.py
32
32
  src/ring/io/__init__.py
33
33
  src/ring/io/examples.py
@@ -87,6 +87,7 @@ src/ring/utils/colab.py
87
87
  src/ring/utils/hdf5.py
88
88
  src/ring/utils/normalizer.py
89
89
  src/ring/utils/path.py
90
+ src/ring/utils/randomize_sys.py
90
91
  src/ring/utils/utils.py
91
92
  tests/test_algebra.py
92
93
  tests/test_base.py
@@ -20,11 +20,11 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int], Ts: float | None):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs):
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
- lam: parent array
27
+ lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
30
  Usage:
@@ -55,6 +55,7 @@ def RING(lam: list[int], Ts: float | None):
55
55
  >>>
56
56
  >>> yhat, _ = ringnet.apply(X)
57
57
  >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
58
59
  >>>
59
60
  >>> # use `jax.jit` to compile the forward pass
60
61
  >>> jit_apply = jax.jit(ringnet.apply)
@@ -69,13 +70,20 @@ def RING(lam: list[int], Ts: float | None):
69
70
  from pathlib import Path
70
71
  import warnings
71
72
 
73
+ config = dict(
74
+ use_100Hz_RING=True,
75
+ use_lpf=True,
76
+ lpf_cutoff_freq=ml._LPF_CUTOFF_FREQ,
77
+ )
78
+ config.update(kwargs)
79
+
72
80
  if Ts is not None and (Ts > (1 / 40) or Ts < (1 / 200)):
73
81
  warnings.warn(
74
82
  "RING was only trained on sampling rates between 40 to 200 Hz "
75
83
  f"but found {1 / Ts}Hz"
76
84
  )
77
85
 
78
- if Ts is not None and Ts == 0.01:
86
+ if Ts is not None and Ts == 0.01 and config["use_100Hz_RING"]:
79
87
  # this set of parameters was trained exclusively on 100Hz data; it also
80
88
  # expects F=9 features per node and not F=10 where the last features is
81
89
  # the sampling interval Ts
@@ -86,14 +94,17 @@ def RING(lam: list[int], Ts: float | None):
86
94
  params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")
87
95
  add_Ts = True
88
96
 
89
- ringnet = ml.RING(params=params, lam=tuple(lam), jit=False, name="RING")
90
- ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
91
- ringnet = ml.base.LPF_FilterWrapper(
92
- ringnet,
93
- ml._LPF_CUTOFF_FREQ,
94
- samp_freq=None if Ts is None else 1 / Ts,
95
- quiet=True,
97
+ ringnet = ml.RING(
98
+ params=params, lam=None if lam is None else tuple(lam), jit=False, name="RING"
96
99
  )
100
+ ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
101
+ if config["use_lpf"]:
102
+ ringnet = ml.base.LPF_FilterWrapper(
103
+ ringnet,
104
+ config["lpf_cutoff_freq"],
105
+ samp_freq=None if Ts is None else 1 / Ts,
106
+ quiet=True,
107
+ )
97
108
  ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
98
109
  if add_Ts:
99
110
  ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
@@ -10,21 +10,11 @@ from .dynamics import compute_mass_matrix
10
10
  from .dynamics import forward_dynamics
11
11
  from .dynamics import inverse_dynamics
12
12
  from .dynamics import step
13
- from .generator import batch_generators_eager
14
- from .generator import batch_generators_eager_to_list
15
- from .generator import batch_generators_lazy
16
- from .generator import batched_generator_from_list
17
- from .generator import batched_generator_from_paths
18
13
  from .generator import FINALIZE_FN
19
14
  from .generator import Generator
20
- from .generator import GeneratorPipe
21
- from .generator import GeneratorTrafo
22
- from .generator import GeneratorTrafoExpandFlatten
23
- from .generator import GeneratorTrafoRandomizePositions
24
- from .generator import GeneratorTrafoRemoveInputExtras
25
- from .generator import GeneratorTrafoRemoveOutputExtras
26
15
  from .generator import RCMG
27
16
  from .generator import SETUP_FN
17
+ from .generator.finalize_fns import GeneratorTrafoExpandFlatten
28
18
  from .jcalc import get_joint_model
29
19
  from .jcalc import jcalc_motion
30
20
  from .jcalc import jcalc_tau
@@ -0,0 +1,11 @@
1
+ from . import base
2
+ from . import batch
3
+ from . import finalize_fns
4
+ from . import motion_artifacts
5
+ from . import pd_control
6
+ from . import setup_fns
7
+ from . import types
8
+ from .base import RCMG
9
+ from .types import FINALIZE_FN
10
+ from .types import Generator
11
+ from .types import SETUP_FN
@@ -0,0 +1,378 @@
1
+ from typing import Callable, Optional
2
+ import warnings
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import tree_utils
7
+
8
+ from ring import base
9
+ from ring import utils
10
+ from ring.algorithms import jcalc
11
+ from ring.algorithms import kinematics
12
+ from ring.algorithms.generator import batch
13
+ from ring.algorithms.generator import finalize_fns
14
+ from ring.algorithms.generator import motion_artifacts
15
+ from ring.algorithms.generator import setup_fns
16
+ from ring.algorithms.generator import types
17
+
18
+
19
class RCMG:
    """Random Chain Motion Generator.

    Holds one generator function per system in `sys`. One call of such a
    generator function produces one sequence per motion config in `config`,
    so every generator has a "size" of `len(config)`.

    The `to_*` methods turn these generator functions into a lazy generator,
    an eager generator, a list of unbatched sequences, or a pickle file.
    """

    def __init__(
        self,
        sys: base.System | list[base.System],
        config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),
        setup_fn: Optional[types.SETUP_FN] = None,
        finalize_fn: Optional[types.FINALIZE_FN] = None,
        add_X_imus: bool = False,
        add_X_imus_kwargs: dict = dict(),
        add_X_jointaxes: bool = False,
        add_X_jointaxes_kwargs: dict = dict(),
        add_y_relpose: bool = False,
        add_y_rootincl: bool = False,
        sys_ml: Optional[base.System] = None,
        randomize_positions: bool = False,
        randomize_motion_artifacts: bool = False,
        randomize_joint_params: bool = False,
        imu_motion_artifacts: bool = False,
        imu_motion_artifacts_kwargs: dict = dict(),
        dynamic_simulation: bool = False,
        dynamic_simulation_kwargs: dict = dict(),
        output_transform: Optional[Callable] = None,
        keep_output_extras: bool = False,
        use_link_number_in_Xy: bool = False,
        cor: bool = False,
        disable_tqdm: bool = False,
    ) -> None:
        """Build one generator function per system.

        Params:
            sys: a system or a list of systems; one generator is built for each.
            config: one or more motion configs; each must be feasible.
            sys_ml: system from which the IMU-free system used by the joint-axes,
                relative-pose, root-inclination and link-number outputs is
                derived; defaults to the first entry of `sys`.
            disable_tqdm: forwarded to the eager list conversion in `to_list`.

        All remaining flags and `*_kwargs` are forwarded unchanged to
        `_build_mconfig_batched_generator`.
        """
        # NOTE: the `dict()` defaults above are shared between calls; this is
        # safe because `_build_mconfig_batched_generator` is wrapped in
        # `_copy_dicts` and only ever mutates copies of them.
        sys, config = utils.to_list(sys), utils.to_list(config)
        sys_ml = sys[0] if sys_ml is None else sys_ml

        for c in config:
            assert c.is_feasible()

        self.gens = [
            _build_mconfig_batched_generator(
                sys=_sys,
                config=config,
                setup_fn=setup_fn,
                finalize_fn=finalize_fn,
                add_X_imus=add_X_imus,
                add_X_imus_kwargs=add_X_imus_kwargs,
                add_X_jointaxes=add_X_jointaxes,
                add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
                add_y_relpose=add_y_relpose,
                add_y_rootincl=add_y_rootincl,
                sys_ml=sys_ml,
                randomize_positions=randomize_positions,
                randomize_motion_artifacts=randomize_motion_artifacts,
                randomize_joint_params=randomize_joint_params,
                imu_motion_artifacts=imu_motion_artifacts,
                imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
                dynamic_simulation=dynamic_simulation,
                dynamic_simulation_kwargs=dynamic_simulation_kwargs,
                output_transform=output_transform,
                keep_output_extras=keep_output_extras,
                use_link_number_in_Xy=use_link_number_in_Xy,
                cor=cor,
            )
            for _sys in sys
        ]

        self._n_mconfigs = len(config)
        # every generator yields one sequence per motion config per call
        self._size_of_generators = [self._n_mconfigs] * len(self.gens)

        self._disable_tqdm = disable_tqdm

    def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
        """How many times each generator is repeated to create a batch of `sizes`.

        Params:
            sizes: either the total number of sequences (split evenly across all
                generators) or one number of sequences per generator.

        Returns:
            One repeat count per generator.
        """
        total = sum(self._size_of_generators)
        n_gens = len(self._size_of_generators)

        if isinstance(sizes, int):
            assert (sizes // total) > 0, f"Batchsize or size too small. {sizes} < {total}"
            assert sizes % total == 0, f"`size`={sizes} not divisible by {total}"
            return n_gens * [sizes // total]

        for size in sizes:
            # a plain divisibility test; checking membership in the prime
            # factors of `size` wrongly rejected composite config counts
            assert size % self._n_mconfigs == 0, (
                f"`size`={size} is not divisible by number of "
                + f"`mconfigs`={self._n_mconfigs}"
            )

        assert len(sizes) == len(
            self.gens
        ), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"

        repeats = [
            size // size_of_gen
            for size, size_of_gen in zip(sizes, self._size_of_generators)
        ]
        # a size of 0 passes the divisibility check but yields no sequences
        assert 0 not in repeats

        return repeats

    def to_lazy_gen(
        self, sizes: int | list[int] = 1, jit: bool = True
    ) -> types.BatchedGenerator:
        """Return a lazy generator that runs the generator functions on demand.

        `sizes` is interpreted as in `_compute_repeats`; `jit` is forwarded to
        `batch.generators_lazy`.
        """
        return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)

    @staticmethod
    def _number_of_executions_required(size: int) -> int:
        """Number of sequential executions used to produce `size` sequences.

        Takes the vmap part of `utils.distribute_batchsize(size)` and divides
        it by the entries of `utils.primes(vmap)`, in order, until it no longer
        exceeds the eager threshold. Returns the product of the entries used.
        """
        _, vmap = utils.distribute_batchsize(size)

        eager_threshold = utils.batchsize_thresholds()[1]
        primes = iter(utils.primes(vmap))
        n_calls = 1
        while vmap > eager_threshold:
            prime = next(primes)
            n_calls *= prime
            vmap /= prime

        return n_calls

    def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
        "Returns list of unbatched sequences as numpy arrays."
        repeats = self._compute_repeats(sizes)
        # total number of sequences per generator, as plain Python ints
        sizes = [
            repeat * size_of_gen
            for repeat, size_of_gen in zip(repeats, self._size_of_generators)
        ]

        reduced_repeats, n_calls = [], []
        for size, repeat in zip(sizes, repeats):
            n_call = self._number_of_executions_required(size)
            # split `repeat` into `gcd` sequential calls, each producing
            # `repeat // gcd` repeats
            gcd = utils.gcd(n_call, repeat)
            n_calls.append(gcd)
            reduced_repeats.append(repeat // gcd)
        # jit only the generators that are executed more than once
        jits = [n_call > 1 for n_call in n_calls]

        gens = [
            batch.generators_lazy([gen], [reduced_repeat], jit)
            for gen, reduced_repeat, jit in zip(self.gens, reduced_repeats, jits)
        ]

        return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)

    def to_pickle(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        overwrite: bool = True,
    ) -> None:
        """Generate the sequences of `to_list`, batch them, and pickle them to `path`."""
        data = tree_utils.tree_batch(self.to_list(sizes, seed))
        utils.pickle_save(data, path, overwrite=overwrite)

    def to_eager_gen(
        self,
        batchsize: int = 1,
        sizes: int | list[int] = 1,
        seed: int = 1,
        shuffle: bool = True,
    ) -> types.BatchedGenerator:
        """Pre-compute all sequences and return a generator over batches of them.

        Requires at least `batchsize` sequences in total.
        """
        data = self.to_list(sizes, seed)
        assert len(data) >= batchsize

        def data_fn(indices: list[int]):
            return tree_utils.tree_batch([data[i] for i in indices])

        return batch.generator_from_data_fn(
            data_fn, list(range(len(data))), shuffle, batchsize
        )

    @staticmethod
    def eager_gen_from_paths(
        paths: str | list[str],
        batchsize: int,
        include_samples: Optional[list[int]] = None,
        shuffle: bool = True,
        load_all_into_memory: bool = False,
        tree_transform=None,
    ) -> tuple[types.BatchedGenerator, int]:
        """Build a batched generator from data files; see `batch.generator_from_paths`."""
        paths = utils.to_list(paths)
        return batch.generator_from_paths(
            paths,
            batchsize,
            include_samples,
            shuffle,
            load_all_into_memory=load_all_into_memory,
            tree_transform=tree_transform,
        )
204
+
205
+
206
+ def _copy_dicts(f) -> dict:
207
+ def _f(*args, **kwargs):
208
+ _copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
209
+ args = tuple([_copy(ele) for ele in args])
210
+ kwargs = {k: _copy(v) for k, v in kwargs.items()}
211
+ return f(*args, **kwargs)
212
+
213
+ return _f
214
+
215
+
216
@_copy_dicts
def _build_mconfig_batched_generator(
    sys: base.System,
    config: list[jcalc.MotionConfig],
    setup_fn: types.SETUP_FN | None,
    finalize_fn: types.FINALIZE_FN | None,
    add_X_imus: bool,
    add_X_imus_kwargs: dict,
    add_X_jointaxes: bool,
    add_X_jointaxes_kwargs: dict,
    add_y_relpose: bool,
    add_y_rootincl: bool,
    sys_ml: base.System,
    randomize_positions: bool,
    randomize_motion_artifacts: bool,
    randomize_joint_params: bool,
    imu_motion_artifacts: bool,
    imu_motion_artifacts_kwargs: dict,
    dynamic_simulation: bool,
    dynamic_simulation_kwargs: dict,
    output_transform: Callable | None,
    keep_output_extras: bool,
    use_link_number_in_Xy: bool,
    cor: bool,
) -> types.BatchedGenerator:
    """Build the generator function for one system and all motion configs.

    The returned function maps a PRNG key to `Xy`, or to `(Xy, extras)` if
    `keep_output_extras` is set. It then applies `output_transform` if one is
    given. Before that transform, every output has a leading axis of length
    `len(config)`, with one entry per motion config.

    The function is wrapped in `_copy_dicts` because `dynamic_simulation_kwargs`
    is mutated below. The copy keeps the caller's dict untouched.
    """

    # `sys_noimu` is read by every finalize step derived from `sys_ml`,
    # including `Names2Indices` for `use_link_number_in_Xy`. It must be bound
    # whenever any of them is enabled, or `_finalize_fn` raises a NameError.
    needs_sys_noimu = (
        add_X_jointaxes or add_y_relpose or add_y_rootincl or use_link_number_in_Xy
    )
    if needs_sys_noimu:
        if len(sys_ml.findall_imus()) > 0:
            sys_noimu, _ = sys_ml.make_sys_noimu()
        else:
            sys_noimu = sys_ml

    if imu_motion_artifacts:
        # motion artifacts are produced by simulating the injected bodies
        assert dynamic_simulation
        unactuated_subsystems = motion_artifacts.unactuated_subsystem(sys)
        sys = motion_artifacts.inject_subsystems(sys, **imu_motion_artifacts_kwargs)
        assert "unactuated_subsystems" not in dynamic_simulation_kwargs
        # safe to mutate: this dict is a copy made by `_copy_dicts`
        dynamic_simulation_kwargs["unactuated_subsystems"] = unactuated_subsystems

        if not randomize_motion_artifacts:
            warnings.warn(
                "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
            )

        if "prob_rigid" in imu_motion_artifacts_kwargs:
            assert randomize_motion_artifacts, (
                "`prob_rigid` works by overwriting damping and stiffness parameters "
                "using the `randomize_motion_artifacts` flag, so it must be enabled."
            )

    def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
        """Apply the enabled system randomizations, each with its own sub-key."""
        pipe = []
        if imu_motion_artifacts and randomize_motion_artifacts:
            pipe.append(
                motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
                    **imu_motion_artifacts_kwargs
                )
            )
        if randomize_positions:
            pipe.append(setup_fns._setup_fn_randomize_positions)
        if randomize_joint_params:
            pipe.append(jcalc._init_joint_params)
        if setup_fn is not None:
            pipe.append(setup_fn)

        for f in pipe:
            key, consume = jax.random.split(key)
            sys = f(consume, sys)
        if cor:
            sys = sys._replace_free_with_cor()
        return sys

    def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
        """Run the enabled finalize steps in order, threading `(Xy, extras)`."""
        pipe = []
        if dynamic_simulation:
            pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
        if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
            "hide_injected_bodies", True
        ):
            pipe.append(motion_artifacts.HideInjectedBodies())
        if finalize_fn is not None:
            pipe.append(finalize_fns.FinalizeFn(finalize_fn))
        if add_X_imus:
            pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
        if add_X_jointaxes:
            pipe.append(
                finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
            )
        if add_y_relpose:
            pipe.append(finalize_fns.RelPose(sys_noimu))
        if add_y_rootincl:
            pipe.append(finalize_fns.RootIncl(sys_noimu))
        if use_link_number_in_Xy:
            pipe.append(finalize_fns.Names2Indices(sys_noimu))

        for f in pipe:
            Xy, extras = f(Xy, extras)
        return Xy, extras

    def _gen(key: types.PRNGKey):
        # one randomized system per motion config
        key, *consume = jax.random.split(key, len(config) + 1)
        syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)

        qs = []
        for i, _config in enumerate(config):
            key, _q = draw_random_q(key, syss[i], _config)
            qs.append(_q)
        qs = jnp.stack(qs)

        @jax.vmap
        def _vmapped_context(key, q, sys):
            # forward kinematics for every time step of `q`
            x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
            Xy, extras = ({}, {}), (key, q, x, sys)
            return _finalize_fn(Xy, extras)

        keys = jax.random.split(key, len(config))
        Xy, extras = _vmapped_context(keys, qs, syss)
        output = (Xy, extras) if keep_output_extras else Xy
        output = output if output_transform is None else output_transform(output)
        return output

    return _gen
340
+
341
+
342
def draw_random_q(
    key: types.PRNGKey,
    sys: base.System,
    config: jcalc.MotionConfig,
) -> tuple[types.PRNGKey, jnp.ndarray]:
    """Draw random generalized coordinates `q` for every link of `sys`.

    Each link's joint model supplies an `rcmg_draw_fn`. It is called with the
    motion `config`, two PRNG keys, `sys.dt`, and the link's joint params: the
    entry for its joint type if present, otherwise the `"default"` entry.

    Returns:
        A tuple `(key, q)`. `key` is the key emitted for the last link visited
        by `sys.scan`, which is not consumed by any draw. `q` is all per-link
        draws concatenated along axis 1.

    Raises:
        Exception: if a joint type has no `rcmg_draw_fn`.
    """

    key_start = key
    # build generalized coordinates vector `q`, one 2d array per link
    q_list = []

    def draw_q(key, __, link_type, link):
        joint_params = link.joint_params
        # limit scope to this joint type, falling back to "default"
        joint_params = (
            joint_params[link_type]
            if link_type in joint_params
            else joint_params["default"]
        )
        # NOTE(review): presumably `key` is None for links without a parent
        # link, so every such link restarts from `key_start` and would get the
        # same sub-keys. Confirm this is intended.
        if key is None:
            key = key_start
        key, key_t, key_value = jax.random.split(key, 3)
        draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
        if draw_fn is None:
            raise Exception(f"The joint type {link_type} has no draw fn specified.")
        q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
        # even revolute and prismatic joints must be 2d arrays
        q_link = q_link if q_link.ndim == 2 else q_link[:, None]
        q_list.append(q_link)
        # the returned key is handed on as `key` by `sys.scan`
        return key

    keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
    # stack of keys; only the last key is unused
    key = keys[-1]

    q = jnp.concatenate(q_list, axis=1)

    return key, q