imt-ring 1.3.0__tar.gz → 1.3.1__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (109)
  1. {imt_ring-1.3.0 → imt_ring-1.3.1}/PKG-INFO +1 -1
  2. {imt_ring-1.3.0 → imt_ring-1.3.1}/pyproject.toml +1 -1
  3. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/base.py +9 -15
  5. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/batch.py +33 -11
  6. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/base.py +6 -7
  7. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/__init__.py +4 -2
  8. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/rendering/base_render.py +11 -6
  9. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/__init__.py +1 -0
  10. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/utils.py +14 -0
  11. {imt_ring-1.3.0 → imt_ring-1.3.1}/readme.md +0 -0
  12. {imt_ring-1.3.0 → imt_ring-1.3.1}/setup.cfg +0 -0
  13. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/imt_ring.egg-info/SOURCES.txt +0 -0
  14. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  15. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/imt_ring.egg-info/requires.txt +0 -0
  16. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/imt_ring.egg-info/top_level.txt +0 -0
  17. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/__init__.py +0 -0
  18. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algebra.py +0 -0
  19. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/__init__.py +0 -0
  20. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/_random.py +0 -0
  21. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  22. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  23. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  24. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  25. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/dynamics.py +0 -0
  26. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/__init__.py +0 -0
  27. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  28. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/pd_control.py +0 -0
  29. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/randomize.py +0 -0
  30. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/transforms.py +0 -0
  31. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/generator/types.py +0 -0
  32. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/jcalc.py +0 -0
  33. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/kinematics.py +0 -0
  34. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/algorithms/sensors.py +0 -0
  35. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/__init__.py +0 -0
  36. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/branched.xml +0 -0
  37. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  38. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  39. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  40. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/inv_pendulum.xml +0 -0
  41. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  42. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/spherical_stiff.xml +0 -0
  43. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/symmetric.xml +0 -0
  44. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_all_1.xml +0 -0
  45. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_all_2.xml +0 -0
  46. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  47. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_control.xml +0 -0
  48. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  49. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_free.xml +0 -0
  50. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_kinematics.xml +0 -0
  51. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  52. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  53. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_randomize_position.xml +0 -0
  54. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_sensors.xml +0 -0
  55. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  56. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/examples.py +0 -0
  57. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/test_examples.py +0 -0
  58. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/xml/__init__.py +0 -0
  59. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/xml/abstract.py +0 -0
  60. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/xml/from_xml.py +0 -0
  61. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/xml/test_from_xml.py +0 -0
  62. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/xml/test_to_xml.py +0 -0
  63. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/io/xml/to_xml.py +0 -0
  64. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/maths.py +0 -0
  65. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/base.py +0 -0
  66. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/callbacks.py +0 -0
  67. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/ml_utils.py +0 -0
  68. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/optimizer.py +0 -0
  69. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  70. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/ringnet.py +0 -0
  71. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/train.py +0 -0
  72. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/ml/training_loop.py +0 -0
  73. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/rendering/__init__.py +0 -0
  74. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/rendering/mujoco_render.py +0 -0
  75. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/rendering/vispy_render.py +0 -0
  76. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/rendering/vispy_visuals.py +0 -0
  77. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/sim2real/__init__.py +0 -0
  78. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/sim2real/sim2real.py +0 -0
  79. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/spatial.py +0 -0
  80. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/sys_composer/__init__.py +0 -0
  81. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/sys_composer/delete_sys.py +0 -0
  82. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/sys_composer/inject_sys.py +0 -0
  83. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/sys_composer/morph_sys.py +0 -0
  84. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/batchsize.py +0 -0
  85. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/colab.py +0 -0
  86. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/hdf5.py +0 -0
  87. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/normalizer.py +0 -0
  88. {imt_ring-1.3.0 → imt_ring-1.3.1}/src/ring/utils/path.py +0 -0
  89. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_algebra.py +0 -0
  90. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_base.py +0 -0
  91. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_custom_joints.py +0 -0
  92. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_dynamics.py +0 -0
  93. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_generator.py +0 -0
  94. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_jcalc.py +0 -0
  95. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_jit.py +0 -0
  96. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_kinematics.py +0 -0
  97. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_maths.py +0 -0
  98. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_ml_utils.py +0 -0
  99. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_motion_artifacts.py +0 -0
  100. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_pd_control.py +0 -0
  101. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_random.py +0 -0
  102. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_randomize.py +0 -0
  103. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_rcmg.py +0 -0
  104. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_render.py +0 -0
  105. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_sensors.py +0 -0
  106. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_sim2real.py +0 -0
  107. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_sys_composer.py +0 -0
  108. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_train.py +0 -0
  109. {imt_ring-1.3.0 → imt_ring-1.3.1}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.0
3
+ Version: 1.3.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.3.0"
7
+ version = "1.3.1"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.0
3
+ Version: 1.3.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,6 +4,8 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import tree_utils
8
+
7
9
  from ring import base
8
10
  from ring import utils
9
11
  from ring.algorithms import jcalc
@@ -13,7 +15,6 @@ from ring.algorithms.generator import motion_artifacts
13
15
  from ring.algorithms.generator import randomize
14
16
  from ring.algorithms.generator import transforms
15
17
  from ring.algorithms.generator import types
16
- import tree_utils
17
18
 
18
19
 
19
20
  class RCMG:
@@ -108,23 +109,20 @@ class RCMG:
108
109
  partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
109
110
  )
110
111
 
111
- def _to_data(self, sizes, seed, jit):
112
- return batch.batch_generators_eager_to_list(
113
- self.gens, sizes, seed=seed, jit=jit
114
- )
112
+ def _to_data(self, sizes, seed):
113
+ return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
115
114
 
116
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):
117
- return self._to_data(sizes, seed, jit)
115
+ def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
116
+ return self._to_data(sizes, seed)
118
117
 
119
118
  def to_pickle(
120
119
  self,
121
120
  path: str,
122
121
  sizes: int | list[int] = 1,
123
122
  seed: int = 1,
124
- jit: bool = False,
125
123
  overwrite: bool = True,
126
124
  ) -> None:
127
- data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))
125
+ data = tree_utils.tree_batch(self._to_data(sizes, seed))
128
126
  utils.pickle_save(data, path, overwrite=overwrite)
129
127
 
130
128
  def to_hdf5(
@@ -132,10 +130,9 @@ class RCMG:
132
130
  path: str,
133
131
  sizes: int | list[int] = 1,
134
132
  seed: int = 1,
135
- jit: bool = False,
136
133
  overwrite: bool = True,
137
134
  ) -> None:
138
- data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))
135
+ data = tree_utils.tree_batch(self._to_data(sizes, seed))
139
136
  utils.hdf5_save(path, data, overwrite=overwrite)
140
137
 
141
138
  def to_eager_gen(
@@ -143,11 +140,8 @@ class RCMG:
143
140
  batchsize: int = 1,
144
141
  sizes: int | list[int] = 1,
145
142
  seed: int = 1,
146
- jit: bool = False,
147
143
  ) -> types.BatchedGenerator:
148
- return batch.batch_generators_eager(
149
- self.gens, sizes, batchsize, seed=seed, jit=jit
150
- )
144
+ return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
151
145
 
152
146
  def to_lazy_gen(
153
147
  self, sizes: int | list[int] = 1, jit: bool = True
@@ -6,12 +6,13 @@ import warnings
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
- from ring import utils
10
- from ring.algorithms.generator import types
11
9
  from tqdm import tqdm
12
10
  import tree_utils
13
11
  from tree_utils import tree_batch
14
12
 
13
+ from ring import utils
14
+ from ring.algorithms.generator import types
15
+
15
16
 
16
17
  def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
17
18
  arr = []
@@ -61,11 +62,24 @@ def batch_generators_lazy(
61
62
  return generator
62
63
 
63
64
 
65
+ def _number_of_executions_required(size: int) -> int:
66
+ vmap_threshold = 128
67
+ _, vmap = utils.distribute_batchsize(size)
68
+
69
+ primes = iter(utils.primes(vmap))
70
+ n_calls = 1
71
+ while vmap > vmap_threshold:
72
+ prime = next(primes)
73
+ n_calls *= prime
74
+ vmap /= prime
75
+
76
+ return n_calls
77
+
78
+
64
79
  def batch_generators_eager_to_list(
65
80
  generators: types.Generator | list[types.Generator],
66
81
  sizes: int | list[int],
67
82
  seed: int = 1,
68
- jit: bool = True,
69
83
  ) -> list[tree_utils.PyTree]:
70
84
  "Returns list of unbatched sequences as numpy arrays."
71
85
  generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)
@@ -73,11 +87,20 @@ def batch_generators_eager_to_list(
73
87
  key = jax.random.PRNGKey(seed)
74
88
  data = []
75
89
  for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
76
- key, consume = jax.random.split(key)
77
- sample = batch_generators_lazy(gen, size, jit=jit)(consume)
78
- # converts also to numpy
79
- sample = jax.device_get(sample)
80
- data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
90
+
91
+ n_calls = _number_of_executions_required(size)
92
+ # decrease size by n_calls times
93
+ size = int(size / n_calls)
94
+ jit = True if n_calls > 1 else False
95
+ gen_jit = batch_generators_lazy(gen, size, jit=jit)
96
+
97
+ for _ in range(n_calls):
98
+ key, consume = jax.random.split(key)
99
+ sample = gen_jit(consume)
100
+ # converts also to numpy
101
+ sample = jax.device_get(sample)
102
+ data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
103
+
81
104
  return data
82
105
 
83
106
 
@@ -243,12 +266,11 @@ def batch_generators_eager(
243
266
  shuffle: bool = True,
244
267
  drop_last: bool = True,
245
268
  seed: int = 1,
246
- jit: bool = True,
247
269
  ) -> types.BatchedGenerator:
248
270
  """Eagerly create a large precomputed generator by calling multiple generators
249
271
  and stacking their output."""
250
272
 
251
- data = batch_generators_eager_to_list(generators, sizes, seed=seed, jit=jit)
273
+ data = batch_generators_eager_to_list(generators, sizes, seed=seed)
252
274
  return batched_generator_from_list(data, batchsize, shuffle, drop_last)
253
275
 
254
276
 
@@ -270,7 +292,7 @@ def _process_sizes_batchsizes_generators(
270
292
 
271
293
  assert len(generators) == len(list_sizes)
272
294
 
273
- _WARN_SIZE = 4096
295
+ _WARN_SIZE = 1e6 # disable this warning
274
296
  for size in list_sizes:
275
297
  if size >= _WARN_SIZE:
276
298
  warnings.warn(
@@ -99,15 +99,15 @@ class _Base:
99
99
  def ndim(self):
100
100
  return tu.tree_ndim(self)
101
101
 
102
- def shape(self, axis=0) -> int:
103
- return tu.tree_shape(self, axis)
104
-
105
- def __len__(self) -> int:
106
- Bs = tree_map(lambda arr: arr.shape[0], self)
102
+ def shape(self, axis: int = 0) -> int:
103
+ Bs = tree_map(lambda arr: arr.shape[axis], self)
107
104
  Bs = set(jax.tree_util.tree_flatten(Bs)[0])
108
105
  assert len(Bs) == 1
109
106
  return list(Bs)[0]
110
107
 
108
+ def __len__(self) -> int:
109
+ return self.shape(axis=0)
110
+
111
111
 
112
112
  @struct.dataclass
113
113
  class Transform(_Base):
@@ -685,14 +685,13 @@ class System(_Base):
685
685
  self,
686
686
  xs: Transform | list[Transform],
687
687
  yhat: dict | jax.Array | np.ndarray,
688
- stepframe: int = 1,
689
688
  # by default we don't predict the global rotation
690
689
  transparent_segment_to_root: bool = True,
691
690
  **kwargs,
692
691
  ):
693
692
  "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
694
693
  return ring.rendering.render_prediction(
695
- self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs
694
+ self, xs, yhat, transparent_segment_to_root, **kwargs
696
695
  )
697
696
 
698
697
  def delete_system(self, link_name: str | list[str], strict: bool = True):
@@ -12,6 +12,8 @@ from .optimizer import make_optimizer
12
12
  from .ringnet import RING
13
13
  from .train import train_fn
14
14
 
15
+ _lpf_cutoff_freq = 10.0
16
+
15
17
 
16
18
  def RING_ICML24(params=None, eval: bool = True, **kwargs):
17
19
  """Create the RING network used in the icml24 paper.
@@ -29,7 +31,7 @@ def RING_ICML24(params=None, eval: bool = True, **kwargs):
29
31
  ringnet = RING(params=params, **kwargs) # noqa: F811
30
32
  ringnet = base.ScaleX_FilterWrapper(ringnet)
31
33
  if eval:
32
- ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)
34
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
33
35
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
34
36
  return ringnet
35
37
 
@@ -55,7 +57,7 @@ def RNNO(
55
57
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
56
58
  ringnet = base.ScaleX_FilterWrapper(ringnet)
57
59
  if eval and return_quats:
58
- ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)
60
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
59
61
  if return_quats:
60
62
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
61
63
  return ringnet
@@ -136,12 +136,15 @@ def render_prediction(
136
136
  sys: base.System,
137
137
  xs: base.Transform | list[base.Transform],
138
138
  yhat: dict | jax.Array | np.ndarray,
139
- stepframe: int = 1,
140
139
  # by default we don't predict the global rotation
141
140
  transparent_segment_to_root: bool = True,
142
141
  **kwargs,
143
142
  ):
144
143
  "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
+
145
+ offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
+ offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
+
145
148
  if isinstance(xs, list):
146
149
  # list -> batched Transform
147
150
  xs = xs[0].batch(*xs[1:])
@@ -180,18 +183,23 @@ def render_prediction(
180
183
 
181
184
  # swap time axis, and link axis
182
185
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
+
187
+ add_offset = lambda x, offset: algebra.transform_mul(
188
+ x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
189
+ )
190
+
183
191
  # create mapping from `name` -> Transform
184
192
  xs_dict = dict(
185
193
  zip(
186
194
  ["hat_" + name for name in sys_noimu.link_names],
187
- [xshat[i] for i in range(sys_noimu.num_links())],
195
+ [add_offset(xshat[i], offset_pred) for i in range(sys_noimu.num_links())],
188
196
  )
189
197
  )
190
198
  xs_dict.update(
191
199
  dict(
192
200
  zip(
193
201
  sys.link_names,
194
- [xs[i] for i in range(sys.num_links())],
202
+ [add_offset(xs[i], offset_truth) for i in range(sys.num_links())],
195
203
  )
196
204
  )
197
205
  )
@@ -202,11 +210,8 @@ def render_prediction(
202
210
  xs_render.append(xs_dict[name])
203
211
  xs_render = xs_render[0].batch(*xs_render[1:])
204
212
  xs_render = xs_render.transpose((1, 0, 2))
205
- N = xs_render.shape()
206
- xs_render = [xs_render[t] for t in range(0, N, stepframe)]
207
213
 
208
214
  frames = render(sys_render, xs_render, **kwargs)
209
-
210
215
  return frames
211
216
 
212
217
 
@@ -15,6 +15,7 @@ from .utils import dict_union
15
15
  from .utils import import_lib
16
16
  from .utils import pickle_load
17
17
  from .utils import pickle_save
18
+ from .utils import primes
18
19
  from .utils import pytree_deepcopy
19
20
  from .utils import sys_compare
20
21
  from .utils import to_list
@@ -159,3 +159,17 @@ def pickle_load(
159
159
  with open(path, "rb") as file:
160
160
  obj = pickle.load(file)
161
161
  return obj
162
+
163
+
164
+ def primes(n: int) -> list[int]:
165
+ "Primefactor decomposition in ascending order."
166
+ primfac = []
167
+ d = 2
168
+ while d * d <= n:
169
+ while (n % d) == 0:
170
+ primfac.append(d) # supposing you want multiple factors repeated
171
+ n //= d
172
+ d += 1
173
+ if n > 1:
174
+ primfac.append(n)
175
+ return primfac
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes