imt-ring 1.4.1__tar.gz → 1.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {imt_ring-1.4.1 → imt_ring-1.5.0}/PKG-INFO +1 -1
  2. {imt_ring-1.4.1 → imt_ring-1.5.0}/pyproject.toml +1 -1
  3. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/SOURCES.txt +3 -2
  5. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/__init__.py +21 -10
  6. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/__init__.py +1 -11
  7. imt_ring-1.5.0/src/ring/algorithms/generator/__init__.py +11 -0
  8. imt_ring-1.5.0/src/ring/algorithms/generator/base.py +375 -0
  9. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/batch.py +26 -109
  10. imt_ring-1.5.0/src/ring/algorithms/generator/finalize_fns.py +306 -0
  11. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/motion_artifacts.py +17 -19
  12. imt_ring-1.5.0/src/ring/algorithms/generator/setup_fns.py +43 -0
  13. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/types.py +3 -18
  14. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/jcalc.py +0 -9
  15. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/mujoco_render.py +2 -1
  16. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/__init__.py +3 -4
  17. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/batchsize.py +12 -4
  18. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/utils.py +6 -0
  19. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_custom_joints.py +15 -17
  20. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_generator.py +5 -6
  21. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_ml_utils.py +5 -6
  22. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_pd_control.py +9 -7
  23. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_randomize.py +3 -2
  24. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_rcmg.py +62 -18
  25. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_train.py +18 -3
  26. imt_ring-1.4.1/src/ring/algorithms/generator/__init__.py +0 -25
  27. imt_ring-1.4.1/src/ring/algorithms/generator/base.py +0 -409
  28. imt_ring-1.4.1/src/ring/algorithms/generator/transforms.py +0 -411
  29. {imt_ring-1.4.1 → imt_ring-1.5.0}/readme.md +0 -0
  30. {imt_ring-1.4.1 → imt_ring-1.5.0}/setup.cfg +0 -0
  31. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  32. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/requires.txt +0 -0
  33. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/imt_ring.egg-info/top_level.txt +0 -0
  34. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algebra.py +0 -0
  35. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/_random.py +0 -0
  36. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  37. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  38. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  39. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  40. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/dynamics.py +0 -0
  41. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/generator/pd_control.py +0 -0
  42. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/kinematics.py +0 -0
  43. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/algorithms/sensors.py +0 -0
  44. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/base.py +0 -0
  45. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/__init__.py +0 -0
  46. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/branched.xml +0 -0
  47. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  48. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  49. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  50. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/inv_pendulum.xml +0 -0
  51. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  52. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/spherical_stiff.xml +0 -0
  53. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/symmetric.xml +0 -0
  54. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_all_1.xml +0 -0
  55. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_all_2.xml +0 -0
  56. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  57. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_control.xml +0 -0
  58. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  59. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_free.xml +0 -0
  60. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_kinematics.xml +0 -0
  61. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  62. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  63. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_randomize_position.xml +0 -0
  64. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_sensors.xml +0 -0
  65. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  66. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/examples.py +0 -0
  67. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/test_examples.py +0 -0
  68. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/__init__.py +0 -0
  69. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/abstract.py +0 -0
  70. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/from_xml.py +0 -0
  71. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/test_from_xml.py +0 -0
  72. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/test_to_xml.py +0 -0
  73. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/io/xml/to_xml.py +0 -0
  74. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/maths.py +0 -0
  75. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/__init__.py +0 -0
  76. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/base.py +0 -0
  77. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/callbacks.py +0 -0
  78. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/ml_utils.py +0 -0
  79. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/optimizer.py +0 -0
  80. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  81. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  82. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/ringnet.py +0 -0
  83. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/rnno_v1.py +0 -0
  84. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/train.py +0 -0
  85. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/ml/training_loop.py +0 -0
  86. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/__init__.py +0 -0
  87. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/base_render.py +0 -0
  88. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/vispy_render.py +0 -0
  89. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/rendering/vispy_visuals.py +0 -0
  90. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sim2real/__init__.py +0 -0
  91. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sim2real/sim2real.py +0 -0
  92. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/spatial.py +0 -0
  93. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/__init__.py +0 -0
  94. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/delete_sys.py +0 -0
  95. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/inject_sys.py +0 -0
  96. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/sys_composer/morph_sys.py +0 -0
  97. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/backend.py +0 -0
  98. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/colab.py +0 -0
  99. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/hdf5.py +0 -0
  100. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/normalizer.py +0 -0
  101. {imt_ring-1.4.1 → imt_ring-1.5.0}/src/ring/utils/path.py +0 -0
  102. /imt_ring-1.4.1/src/ring/algorithms/generator/randomize.py → /imt_ring-1.5.0/src/ring/utils/randomize_sys.py +0 -0
  103. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_algebra.py +0 -0
  104. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_base.py +0 -0
  105. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_dynamics.py +0 -0
  106. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_jcalc.py +0 -0
  107. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_jit.py +0 -0
  108. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_kinematics.py +0 -0
  109. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_maths.py +0 -0
  110. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_motion_artifacts.py +0 -0
  111. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_quickstart_example.py +0 -0
  112. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_random.py +0 -0
  113. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_render.py +0 -0
  114. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_sensors.py +0 -0
  115. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_sim2real.py +0 -0
  116. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_sys_composer.py +0 -0
  117. {imt_ring-1.4.1 → imt_ring-1.5.0}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.4.1
3
+ Version: 1.5.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.4.1"
7
+ version = "1.5.0"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.4.1
3
+ Version: 1.5.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -24,10 +24,10 @@ src/ring/algorithms/custom_joints/suntay.py
24
24
  src/ring/algorithms/generator/__init__.py
25
25
  src/ring/algorithms/generator/base.py
26
26
  src/ring/algorithms/generator/batch.py
27
+ src/ring/algorithms/generator/finalize_fns.py
27
28
  src/ring/algorithms/generator/motion_artifacts.py
28
29
  src/ring/algorithms/generator/pd_control.py
29
- src/ring/algorithms/generator/randomize.py
30
- src/ring/algorithms/generator/transforms.py
30
+ src/ring/algorithms/generator/setup_fns.py
31
31
  src/ring/algorithms/generator/types.py
32
32
  src/ring/io/__init__.py
33
33
  src/ring/io/examples.py
@@ -87,6 +87,7 @@ src/ring/utils/colab.py
87
87
  src/ring/utils/hdf5.py
88
88
  src/ring/utils/normalizer.py
89
89
  src/ring/utils/path.py
90
+ src/ring/utils/randomize_sys.py
90
91
  src/ring/utils/utils.py
91
92
  tests/test_algebra.py
92
93
  tests/test_base.py
@@ -20,11 +20,11 @@ from .base import System
20
20
  from .base import Transform
21
21
 
22
22
 
23
- def RING(lam: list[int], Ts: float | None):
23
+ def RING(lam: list[int] | None, Ts: float | None, **kwargs):
24
24
  """Creates the RING network.
25
25
 
26
26
  Params:
27
- lam: parent array
27
+ lam: parent array, if `None` must be given via `ringnet.apply(..., lam=lam)`
28
28
  Ts : sampling interval of IMU data; time delta in seconds
29
29
 
30
30
  Usage:
@@ -55,6 +55,7 @@ def RING(lam: list[int], Ts: float | None):
55
55
  >>>
56
56
  >>> yhat, _ = ringnet.apply(X)
57
57
  >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
+ >>> # yhat[b, :, i] is the orientation from body `i` to parent body `lam[i]`
58
59
  >>>
59
60
  >>> # use `jax.jit` to compile the forward pass
60
61
  >>> jit_apply = jax.jit(ringnet.apply)
@@ -69,13 +70,20 @@ def RING(lam: list[int], Ts: float | None):
69
70
  from pathlib import Path
70
71
  import warnings
71
72
 
73
+ config = dict(
74
+ use_100Hz_RING=True,
75
+ use_lpf=True,
76
+ lpf_cutoff_freq=ml._LPF_CUTOFF_FREQ,
77
+ )
78
+ config.update(kwargs)
79
+
72
80
  if Ts is not None and (Ts > (1 / 40) or Ts < (1 / 200)):
73
81
  warnings.warn(
74
82
  "RING was only trained on sampling rates between 40 to 200 Hz "
75
83
  f"but found {1 / Ts}Hz"
76
84
  )
77
85
 
78
- if Ts is not None and Ts == 0.01:
86
+ if Ts is not None and Ts == 0.01 and config["use_100Hz_RING"]:
79
87
  # this set of parameters was trained exclusively on 100Hz data; it also
80
88
  # expects F=9 features per node and not F=10 where the last features is
81
89
  # the sampling interval Ts
@@ -86,14 +94,17 @@ def RING(lam: list[int], Ts: float | None):
86
94
  params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")
87
95
  add_Ts = True
88
96
 
89
- ringnet = ml.RING(params=params, lam=tuple(lam), jit=False, name="RING")
90
- ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
91
- ringnet = ml.base.LPF_FilterWrapper(
92
- ringnet,
93
- ml._LPF_CUTOFF_FREQ,
94
- samp_freq=None if Ts is None else 1 / Ts,
95
- quiet=True,
97
+ ringnet = ml.RING(
98
+ params=params, lam=None if lam is None else tuple(lam), jit=False, name="RING"
96
99
  )
100
+ ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
101
+ if config["use_lpf"]:
102
+ ringnet = ml.base.LPF_FilterWrapper(
103
+ ringnet,
104
+ config["lpf_cutoff_freq"],
105
+ samp_freq=None if Ts is None else 1 / Ts,
106
+ quiet=True,
107
+ )
97
108
  ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
98
109
  if add_Ts:
99
110
  ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
@@ -10,21 +10,11 @@ from .dynamics import compute_mass_matrix
10
10
  from .dynamics import forward_dynamics
11
11
  from .dynamics import inverse_dynamics
12
12
  from .dynamics import step
13
- from .generator import batch_generators_eager
14
- from .generator import batch_generators_eager_to_list
15
- from .generator import batch_generators_lazy
16
- from .generator import batched_generator_from_list
17
- from .generator import batched_generator_from_paths
18
13
  from .generator import FINALIZE_FN
19
14
  from .generator import Generator
20
- from .generator import GeneratorPipe
21
- from .generator import GeneratorTrafo
22
- from .generator import GeneratorTrafoExpandFlatten
23
- from .generator import GeneratorTrafoRandomizePositions
24
- from .generator import GeneratorTrafoRemoveInputExtras
25
- from .generator import GeneratorTrafoRemoveOutputExtras
26
15
  from .generator import RCMG
27
16
  from .generator import SETUP_FN
17
+ from .generator.finalize_fns import GeneratorTrafoExpandFlatten
28
18
  from .jcalc import get_joint_model
29
19
  from .jcalc import jcalc_motion
30
20
  from .jcalc import jcalc_tau
@@ -0,0 +1,11 @@
1
+ from . import base
2
+ from . import batch
3
+ from . import finalize_fns
4
+ from . import motion_artifacts
5
+ from . import pd_control
6
+ from . import setup_fns
7
+ from . import types
8
+ from .base import RCMG
9
+ from .types import FINALIZE_FN
10
+ from .types import Generator
11
+ from .types import SETUP_FN
@@ -0,0 +1,375 @@
1
+ from typing import Callable, Optional
2
+ import warnings
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import tree_utils
7
+
8
+ from ring import base
9
+ from ring import utils
10
+ from ring.algorithms import jcalc
11
+ from ring.algorithms import kinematics
12
+ from ring.algorithms.generator import batch
13
+ from ring.algorithms.generator import finalize_fns
14
+ from ring.algorithms.generator import motion_artifacts
15
+ from ring.algorithms.generator import setup_fns
16
+ from ring.algorithms.generator import types
17
+
18
+
19
class RCMG:
    """Random motion generator for one or more systems.

    For every system in `sys` one generator is built with
    `_build_mconfig_batched_generator`; a single call of such a generator
    yields one sample per `MotionConfig` in `config`. The methods below repeat
    these generators to produce the requested number of samples, either lazily
    (`to_lazy_gen`) or eagerly (`to_list`, `to_pickle`, `to_eager_gen`).
    """

    def __init__(
        self,
        sys: base.System | list[base.System],
        config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),
        setup_fn: Optional[types.SETUP_FN] = None,
        finalize_fn: Optional[types.FINALIZE_FN] = None,
        add_X_imus: bool = False,
        add_X_imus_kwargs: dict = dict(),
        add_X_jointaxes: bool = False,
        add_X_jointaxes_kwargs: dict = dict(),
        add_y_relpose: bool = False,
        add_y_rootincl: bool = False,
        sys_ml: Optional[base.System] = None,
        randomize_positions: bool = False,
        randomize_motion_artifacts: bool = False,
        randomize_joint_params: bool = False,
        imu_motion_artifacts: bool = False,
        imu_motion_artifacts_kwargs: dict = dict(hide_injected_bodies=True),
        dynamic_simulation: bool = False,
        dynamic_simulation_kwargs: dict = dict(),
        output_transform: Optional[Callable] = None,
        keep_output_extras: bool = False,
        use_link_number_in_Xy: bool = False,
        cor: bool = False,
        disable_tqdm: bool = False,
    ) -> None:
        """Build one generator per system.

        Params:
            sys: a single system or a list of systems; one generator is built
                for each.
            config: one or more motion configs; every generator call yields
                one sample per config. Each config must be feasible.
            sys_ml: system used by the jointaxes/relpose/rootincl/link-number
                finalize steps; defaults to the first system in `sys`.
            cor: if set, free joints of every system are replaced via
                `System._replace_free_with_cor`.
            disable_tqdm: forwarded to the eager list conversion in `to_list`.
            All remaining arguments are forwarded unchanged to
            `_build_mconfig_batched_generator`.
        """
        # NOTE: the dict defaults above are shared between calls; they stay
        # untouched because `_build_mconfig_batched_generator` is wrapped by
        # `_copy_dicts` and therefore only mutates shallow copies.
        sys, config = utils.to_list(sys), utils.to_list(config)
        sys_ml = sys[0] if sys_ml is None else sys_ml

        for c in config:
            assert c.is_feasible()

        if cor:
            sys = [s._replace_free_with_cor() for s in sys]

        self.gens = []
        for _sys in sys:
            self.gens.append(
                _build_mconfig_batched_generator(
                    sys=_sys,
                    config=config,
                    setup_fn=setup_fn,
                    finalize_fn=finalize_fn,
                    add_X_imus=add_X_imus,
                    add_X_imus_kwargs=add_X_imus_kwargs,
                    add_X_jointaxes=add_X_jointaxes,
                    add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
                    add_y_relpose=add_y_relpose,
                    add_y_rootincl=add_y_rootincl,
                    sys_ml=sys_ml,
                    randomize_positions=randomize_positions,
                    randomize_motion_artifacts=randomize_motion_artifacts,
                    randomize_joint_params=randomize_joint_params,
                    imu_motion_artifacts=imu_motion_artifacts,
                    imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
                    dynamic_simulation=dynamic_simulation,
                    dynamic_simulation_kwargs=dynamic_simulation_kwargs,
                    output_transform=output_transform,
                    keep_output_extras=keep_output_extras,
                    use_link_number_in_Xy=use_link_number_in_Xy,
                )
            )

        # every generator call yields `len(config)` samples
        self._n_mconfigs = len(config)
        self._size_of_generators = [self._n_mconfigs] * len(self.gens)

        self._disable_tqdm = disable_tqdm

    def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
        """How many times the generators are repeated to create a batch of `sizes`.

        Params:
            sizes: either the total number of samples, split evenly across all
                generators, or one number of samples per generator.

        Returns:
            One repeat count per generator in `self.gens`.
        """
        S, L = sum(self._size_of_generators), len(self._size_of_generators)

        def assert_size(size: int):
            # a generator call yields `_n_mconfigs` samples, so a per-generator
            # size must be a multiple of it (a prime-factor membership test
            # would wrongly reject e.g. 4 mconfigs with size 8, or 1 mconfig)
            assert size % self._n_mconfigs == 0, (
                f"`size`={size} is not divisible by number of "
                + f"`mconfigs`={self._n_mconfigs}"
            )

        if isinstance(sizes, int):
            assert (sizes // S) > 0, f"Batchsize or size too small. {sizes} < {S}"
            assert sizes % S == 0, f"`size`={sizes} not divisible by {S}"
            repeats = L * [sizes // S]
        else:
            for size in sizes:
                assert_size(size)

            assert len(sizes) == len(
                self.gens
            ), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"

            repeats = [
                size // size_of_gen
                for size, size_of_gen in zip(sizes, self._size_of_generators)
            ]
            assert 0 not in repeats

        return repeats

    def to_lazy_gen(
        self, sizes: int | list[int] = 1, jit: bool = True
    ) -> types.BatchedGenerator:
        """Return a lazy batched generator built from all generators.

        `sizes` is interpreted as in `_compute_repeats`; `jit` is forwarded to
        `batch.generators_lazy`.
        """
        return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)

    @staticmethod
    def _number_of_executions_required(size: int) -> int:
        """Number of sequential calls so the vmapped part stays below the
        eager threshold `utils.batchsize_thresholds()[1]`.

        Prime factors of the vmap size (as returned by `utils.primes`) are
        moved into the number of calls until the vmap size no longer exceeds
        the threshold.
        """
        _, vmap = utils.distribute_batchsize(size)

        eager_threshold = utils.batchsize_thresholds()[1]
        primes = iter(utils.primes(vmap))
        n_calls = 1
        while vmap > eager_threshold:
            prime = next(primes)
            n_calls *= prime
            vmap /= prime

        return n_calls

    def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
        """Returns list of unbatched sequences as numpy arrays.

        Params:
            sizes: interpreted as in `_compute_repeats`.
            seed: forwarded to `batch.generators_eager_to_list`.
        """
        repeats = self._compute_repeats(sizes)
        # plain Python ints suffice for this bookkeeping; no jax arrays needed
        sizes = [
            repeat * size_of_gen
            for repeat, size_of_gen in zip(repeats, self._size_of_generators)
        ]

        # split each generator's `repeat` into `gcd` sequential eager calls
        # of `repeat // gcd` lazy repeats each
        reduced_repeats = []
        n_calls = []
        for size, repeat in zip(sizes, repeats):
            n_call = self._number_of_executions_required(size)
            gcd = utils.gcd(n_call, repeat)
            n_calls.append(gcd)
            reduced_repeats.append(repeat // gcd)
        # only jit when the lazy generator is called more than once
        jits = [N > 1 for N in n_calls]

        gens = [
            batch.generators_lazy([gen], [reduced_repeat], use_jit)
            for gen, reduced_repeat, use_jit in zip(self.gens, reduced_repeats, jits)
        ]

        return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)

    def to_pickle(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        overwrite: bool = True,
    ) -> None:
        """Generate data via `to_list`, batch it and save it to `path`."""
        data = tree_utils.tree_batch(self.to_list(sizes, seed))
        utils.pickle_save(data, path, overwrite=overwrite)

    def to_eager_gen(
        self,
        batchsize: int = 1,
        sizes: int | list[int] = 1,
        seed: int = 1,
        shuffle: bool = True,
    ) -> types.BatchedGenerator:
        """Materialize all data via `to_list` and serve it in batches.

        Raises:
            AssertionError: if fewer samples than `batchsize` are generated.
        """
        data = self.to_list(sizes, seed)
        assert len(data) >= batchsize

        def data_fn(indices: list[int]):
            return tree_utils.tree_batch([data[i] for i in indices])

        return batch.generator_from_data_fn(
            data_fn, list(range(len(data))), shuffle, batchsize
        )

    @staticmethod
    def eager_gen_from_paths(
        paths: str | list[str],
        batchsize: int,
        include_samples: Optional[list[int]] = None,
        shuffle: bool = True,
        load_all_into_memory: bool = False,
        tree_transform=None,
    ) -> tuple[types.BatchedGenerator, int]:
        """Create a batched generator from one or more saved data files.

        Thin wrapper around `batch.generator_from_paths`; a single path is
        promoted to a list first.
        """
        paths = utils.to_list(paths)
        return batch.generator_from_paths(
            paths,
            batchsize,
            include_samples,
            shuffle,
            load_all_into_memory=load_all_into_memory,
            tree_transform=tree_transform,
        )
206
+
207
+
208
+ def _copy_dicts(f) -> dict:
209
+ def _f(*args, **kwargs):
210
+ _copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
211
+ args = tuple([_copy(ele) for ele in args])
212
+ kwargs = {k: _copy(v) for k, v in kwargs.items()}
213
+ return f(*args, **kwargs)
214
+
215
+ return _f
216
+
217
+
218
@_copy_dicts
def _build_mconfig_batched_generator(
    sys: base.System,
    config: list[jcalc.MotionConfig],
    setup_fn: types.SETUP_FN | None,
    finalize_fn: types.FINALIZE_FN | None,
    add_X_imus: bool,
    add_X_imus_kwargs: dict,
    add_X_jointaxes: bool,
    add_X_jointaxes_kwargs: dict,
    add_y_relpose: bool,
    add_y_rootincl: bool,
    sys_ml: base.System,
    randomize_positions: bool,
    randomize_motion_artifacts: bool,
    randomize_joint_params: bool,
    imu_motion_artifacts: bool,
    imu_motion_artifacts_kwargs: dict,
    dynamic_simulation: bool,
    dynamic_simulation_kwargs: dict,
    output_transform: Callable | None,
    keep_output_extras: bool,
    use_link_number_in_Xy: bool,
) -> types.BatchedGenerator:
    """Build a generator that yields one sample per motion config for `sys`.

    The returned `_gen(key)` draws a random `q` for every entry of `config`,
    applies the assembled setup pipeline to `sys` (vmapped over configs), runs
    forward kinematics and then the assembled finalize pipeline. Dict
    arguments arrive as shallow copies (see `_copy_dicts`), so the in-place
    edit of `dynamic_simulation_kwargs` below does not leak to the caller.

    Returns:
        `_gen`, mapping a PRNG key to `Xy`, or to `(Xy, extras)` if
        `keep_output_extras`, optionally passed through `output_transform`.

    Raises:
        AssertionError: if `imu_motion_artifacts` is set without
            `dynamic_simulation`, if `dynamic_simulation_kwargs` already has
            `"unactuated_subsystems"`, or if `prob_rigid` is given without
            `randomize_motion_artifacts`.
    """
    # `sys_noimu` is required by every finalize step built from it, including
    # `Names2Indices` for `use_link_number_in_Xy`; without this flag in the
    # condition, `use_link_number_in_Xy` alone raised a NameError.
    if (
        add_X_jointaxes
        or add_y_relpose
        or add_y_rootincl
        or use_link_number_in_Xy
    ):
        if len(sys_ml.findall_imus()) > 0:
            # IMUs are removed from `sys_ml` silently
            sys_noimu, _ = sys_ml.make_sys_noimu()
        else:
            sys_noimu = sys_ml

    unactuated_subsystems = []
    if imu_motion_artifacts:
        assert dynamic_simulation
        unactuated_subsystems = motion_artifacts.unactuated_subsystem(sys)
        sys = motion_artifacts.inject_subsystems(sys, **imu_motion_artifacts_kwargs)
        assert "unactuated_subsystems" not in dynamic_simulation_kwargs
        dynamic_simulation_kwargs["unactuated_subsystems"] = unactuated_subsystems

        if not randomize_motion_artifacts:
            warnings.warn(
                "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
            )

        if "prob_rigid" in imu_motion_artifacts_kwargs:
            assert randomize_motion_artifacts, (
                "`prob_rigid` works by overwriting damping and stiffness parameters "
                "using the `randomize_motion_artifacts` flag, so it must be enabled."
            )

    def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
        # apply each enabled setup step with its own fresh PRNG key
        pipe = []
        if imu_motion_artifacts and randomize_motion_artifacts:
            pipe.append(
                motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
                    **imu_motion_artifacts_kwargs
                )
            )
        if randomize_positions:
            pipe.append(setup_fns._setup_fn_randomize_positions)
        if randomize_joint_params:
            pipe.append(jcalc._init_joint_params)
        if setup_fn is not None:
            pipe.append(setup_fn)

        for f in pipe:
            key, consume = jax.random.split(key)
            sys = f(consume, sys)
        return sys

    def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
        # the order of the steps matters: later steps see earlier outputs
        pipe = []
        if dynamic_simulation:
            pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
        # `True` mirrors the default of `RCMG(imu_motion_artifacts_kwargs=...)`,
        # so a user-supplied kwargs dict without this key no longer raises
        if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
            "hide_injected_bodies", True
        ):
            pipe.append(motion_artifacts.HideInjectedBodies())
        if finalize_fn is not None:
            pipe.append(finalize_fns.FinalizeFn(finalize_fn))
        if add_X_imus:
            pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
        if add_X_jointaxes:
            pipe.append(
                finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
            )
        if add_y_relpose:
            pipe.append(finalize_fns.RelPose(sys_noimu))
        if add_y_rootincl:
            pipe.append(finalize_fns.RootIncl(sys_noimu))
        if use_link_number_in_Xy:
            pipe.append(finalize_fns.Names2Indices(sys_noimu))

        for f in pipe:
            Xy, extras = f(Xy, extras)
        return Xy, extras

    def _gen(key: types.PRNGKey):
        # one random trajectory per motion config, stacked along axis 0
        qs = []
        for _config in config:
            key, _q = draw_random_q(key, sys, _config)
            qs.append(_q)
        qs = jnp.stack(qs)

        # one (independently randomized) system per motion config
        key, *consume = jax.random.split(key, len(config) + 1)
        syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)

        @jax.vmap
        def _vmapped_context(key, q, sys):
            x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
            Xy, extras = ({}, {}), (key, q, x, sys)
            return _finalize_fn(Xy, extras)

        keys = jax.random.split(key, len(config))
        Xy, extras = _vmapped_context(keys, qs, syss)
        output = (Xy, extras) if keep_output_extras else Xy
        output = output if output_transform is None else output_transform(output)
        return output

    return _gen
337
+
338
+
339
def draw_random_q(
    key: types.PRNGKey,
    sys: base.System,
    config: jcalc.MotionConfig,
) -> tuple[types.PRNGKey, jax.Array]:
    """Draw a random generalized-coordinates trajectory `q` for `sys`.

    For every link visited by `sys.scan`, the joint model's `rcmg_draw_fn` is
    called with `config`, two fresh PRNG keys, `sys.dt` and the link's joint
    parameters for its joint type (falling back to `"default"`). The per-link
    draws are made at least 2D (1D draws get a trailing axis) and concatenated
    along axis 1.

    Returns:
        `(key, q)`: an unused PRNG key for the caller to continue with, and
        the concatenated draws.

    Raises:
        Exception: if a link's joint type has no `rcmg_draw_fn`.
    """
    key_start = key
    # build generalized coordinates vector `q`
    q_list = []

    def draw_q(key, __, link_type, link):
        # NOTE(review): assumes `sys.scan` threads the returned key into the
        # next call and passes `None` initially; then `key_start` is used.
        joint_params = link.joint_params
        # limit scope to the parameters of this joint type
        joint_params = (
            joint_params[link_type]
            if link_type in joint_params
            else joint_params["default"]
        )
        if key is None:
            key = key_start
        key, key_t, key_value = jax.random.split(key, 3)
        draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
        if draw_fn is None:
            raise Exception(f"The joint type {link_type} has no draw fn specified.")
        q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
        # even revolute and prismatic joints must be 2d arrays
        q_link = q_link if q_link.ndim == 2 else q_link[:, None]
        q_list.append(q_link)
        return key

    keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
    # stack of keys; only the last key is unused
    key = keys[-1]

    q = jnp.concatenate(q_list, axis=1)

    return key, q