imt-ring 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.0
3
+ Version: 1.3.2
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=lG1SAVi6VpJT20Xvdhv_NrObMb4008leEqPEaQ0anR8,33566
3
+ ring/base.py,sha256=gqdXejZ4E4liB5mZ6gPof3EDYTThlfro2MQs0bc5eOM,33530
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
@@ -9,13 +9,13 @@ ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,1
9
9
  ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
12
- ring/algorithms/custom_joints/__init__.py,sha256=33WBnaBJMtq3vVcpMm7zmyeMrLY9PyV_8-wk5oSF65g,227
12
+ ring/algorithms/custom_joints/__init__.py,sha256=_kUyC4TbzjngTQrJVtS6JBKPzTMNbH27jVRJYXViepI,270
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=d0Z54tIXiepMixE40W5H8JKxrT5U6VskPm2L2kKnQPw,13680
15
+ ring/algorithms/custom_joints/suntay.py,sha256=3aFDfqdC2vUAhD30kkuQltgU_WZmYDyVhKPSoEotEYo,15292
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
- ring/algorithms/generator/base.py,sha256=QDmzMAgtaK5M9WDl39qjXYfBa99d83vCPWEkYYmsplk,14952
18
- ring/algorithms/generator/batch.py,sha256=MZurZmQDH1vncoNbCspVNGNlfP0R87J6_HC7MMIqQ6A,8478
17
+ ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
+ ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
19
19
  ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
20
  ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
21
21
  ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
@@ -50,17 +50,17 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=hHQUeYAdqRoEFtpKW4zXyNkAdeH2cPh17vc29hXcWWw,1746
53
+ ring/ml/__init__.py,sha256=-bryExVoKJYSF_G_KYc5hI_GciIhj2xZ8WGi6TdRghw,1836
54
54
  ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
55
55
  ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
- ring/ml/optimizer.py,sha256=OP70P70YcX-2Z-cuoMluFk-L5Vhh_MmqiHdM9OZqyhI,4703
57
+ ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
59
59
  ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
60
60
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
61
61
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
62
62
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
63
- ring/rendering/base_render.py,sha256=L2x83oWC2KkneuA5Lubg2qoDInDFBfH5wYlATymP9-0,8764
63
+ ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
64
64
  ring/rendering/mujoco_render.py,sha256=aluzQJp3jrDdPfAyNmQuXIHRfgfBTCCZQqxKOx_0D2s,7770
65
65
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
66
66
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
@@ -70,14 +70,14 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
70
70
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
71
71
  ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
72
72
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
73
- ring/utils/__init__.py,sha256=6BWUMmbQ4E-Qwd-SNfRlpbzJ0UJ1DpEclstrgbLdDvk,773
73
+ ring/utils/__init__.py,sha256=rTvSA4RiJAVCY_A64FUMd8IJTv94LgoSA3Ps5X63_jA,799
74
74
  ring/utils/batchsize.py,sha256=mPFGD7AedFMycHtyIuZtNWCaAvKLLWSWaB7X6u54xvM,1358
75
75
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
76
76
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
- ring/utils/utils.py,sha256=AzOzR95oOyfdtJhjt5iIb35u611NlTb1Ds4QDKrGMOM,4967
80
- imt_ring-1.3.0.dist-info/METADATA,sha256=fYndGMXxYVbuBmGAqpzGezBcD6-ruTt3VOAMdLxTwDE,3104
81
- imt_ring-1.3.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.0.dist-info/RECORD,,
79
+ ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
+ imt_ring-1.3.2.dist-info/METADATA,sha256=OQmB5-CEpy-JWv2K9vYo9fUNQDW0Jg-fFU4kBkiVRGQ,3104
81
+ imt_ring-1.3.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.2.dist-info/RECORD,,
@@ -2,5 +2,6 @@ from .rr_imp_joint import register_rr_imp_joint
2
2
  from .rr_joint import register_rr_joint
3
3
  from .suntay import GP_DrawFnPair
4
4
  from .suntay import MLP_DrawnFnPair
5
+ from .suntay import Polynomial_DrawnFnPair
5
6
  from .suntay import register_suntay
6
7
  from .suntay import SuntayConfig
@@ -293,6 +293,63 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
293
293
  ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
294
294
 
295
295
 
296
+ def Polynomial_DrawnFnPair(
297
+ order: int = 2,
298
+ val: float = 2.0,
299
+ center: bool = False,
300
+ flexion_center: Optional[float] = None,
301
+ ) -> DrawnFnPairFactory:
302
+ assert val >= 0.0
303
+
304
+ # because 0-th order is also counted
305
+ order += 1
306
+
307
+ def factory(xs, mn, mx):
308
+ nonlocal flexion_center
309
+
310
+ flexion_mn = jnp.min(xs)
311
+ flexion_mx = jnp.max(xs)
312
+
313
+ def _apply_poly_factors(poly_factors, q):
314
+ return poly_factors @ jnp.power(q, jnp.arange(order))
315
+
316
+ if flexion_center is None:
317
+ flexion_center = (flexion_mn + flexion_mx) / 2
318
+ else:
319
+ flexion_center = jnp.array(flexion_center)
320
+
321
+ def init(key):
322
+ c1, c2 = jax.random.split(key)
323
+ poly_factors = jax.random.uniform(
324
+ c1, shape=(order,), minval=-val, maxval=val
325
+ )
326
+ q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
327
+ values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
328
+ poly_factors, xs - q0
329
+ )
330
+ amax = jnp.max(values)
331
+ amin = jnp.min(values)
332
+ return amin, amax, poly_factors, q0
333
+
334
+ def _apply(params, q):
335
+ amin, amax, poly_factors, q0 = params
336
+ q = q - q0
337
+ value = _apply_poly_factors(poly_factors, q)
338
+ return restrict(value, mn, mx, amin, amax)
339
+
340
+ if center:
341
+
342
+ def apply(params, q):
343
+ return _apply(params, q) - _apply(params, flexion_center)
344
+
345
+ else:
346
+ apply = _apply
347
+
348
+ return DrawnFnPair(init, apply)
349
+
350
+ return factory
351
+
352
+
296
353
  def MLP_DrawnFnPair(
297
354
  center: bool = False, flexion_center: Optional[float] = None
298
355
  ) -> DrawnFnPairFactory:
@@ -4,6 +4,8 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import tree_utils
8
+
7
9
  from ring import base
8
10
  from ring import utils
9
11
  from ring.algorithms import jcalc
@@ -13,7 +15,6 @@ from ring.algorithms.generator import motion_artifacts
13
15
  from ring.algorithms.generator import randomize
14
16
  from ring.algorithms.generator import transforms
15
17
  from ring.algorithms.generator import types
16
- import tree_utils
17
18
 
18
19
 
19
20
  class RCMG:
@@ -108,23 +109,20 @@ class RCMG:
108
109
  partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
109
110
  )
110
111
 
111
- def _to_data(self, sizes, seed, jit):
112
- return batch.batch_generators_eager_to_list(
113
- self.gens, sizes, seed=seed, jit=jit
114
- )
112
+ def _to_data(self, sizes, seed):
113
+ return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
115
114
 
116
- def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):
117
- return self._to_data(sizes, seed, jit)
115
+ def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
116
+ return self._to_data(sizes, seed)
118
117
 
119
118
  def to_pickle(
120
119
  self,
121
120
  path: str,
122
121
  sizes: int | list[int] = 1,
123
122
  seed: int = 1,
124
- jit: bool = False,
125
123
  overwrite: bool = True,
126
124
  ) -> None:
127
- data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))
125
+ data = tree_utils.tree_batch(self._to_data(sizes, seed))
128
126
  utils.pickle_save(data, path, overwrite=overwrite)
129
127
 
130
128
  def to_hdf5(
@@ -132,10 +130,9 @@ class RCMG:
132
130
  path: str,
133
131
  sizes: int | list[int] = 1,
134
132
  seed: int = 1,
135
- jit: bool = False,
136
133
  overwrite: bool = True,
137
134
  ) -> None:
138
- data = tree_utils.tree_batch(self._to_data(sizes, seed, jit))
135
+ data = tree_utils.tree_batch(self._to_data(sizes, seed))
139
136
  utils.hdf5_save(path, data, overwrite=overwrite)
140
137
 
141
138
  def to_eager_gen(
@@ -143,11 +140,8 @@ class RCMG:
143
140
  batchsize: int = 1,
144
141
  sizes: int | list[int] = 1,
145
142
  seed: int = 1,
146
- jit: bool = False,
147
143
  ) -> types.BatchedGenerator:
148
- return batch.batch_generators_eager(
149
- self.gens, sizes, batchsize, seed=seed, jit=jit
150
- )
144
+ return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
151
145
 
152
146
  def to_lazy_gen(
153
147
  self, sizes: int | list[int] = 1, jit: bool = True
@@ -6,12 +6,13 @@ import warnings
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
- from ring import utils
10
- from ring.algorithms.generator import types
11
9
  from tqdm import tqdm
12
10
  import tree_utils
13
11
  from tree_utils import tree_batch
14
12
 
13
+ from ring import utils
14
+ from ring.algorithms.generator import types
15
+
15
16
 
16
17
  def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
17
18
  arr = []
@@ -61,11 +62,24 @@ def batch_generators_lazy(
61
62
  return generator
62
63
 
63
64
 
65
+ def _number_of_executions_required(size: int) -> int:
66
+ vmap_threshold = 128
67
+ _, vmap = utils.distribute_batchsize(size)
68
+
69
+ primes = iter(utils.primes(vmap))
70
+ n_calls = 1
71
+ while vmap > vmap_threshold:
72
+ prime = next(primes)
73
+ n_calls *= prime
74
+ vmap /= prime
75
+
76
+ return n_calls
77
+
78
+
64
79
  def batch_generators_eager_to_list(
65
80
  generators: types.Generator | list[types.Generator],
66
81
  sizes: int | list[int],
67
82
  seed: int = 1,
68
- jit: bool = True,
69
83
  ) -> list[tree_utils.PyTree]:
70
84
  "Returns list of unbatched sequences as numpy arrays."
71
85
  generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)
@@ -73,11 +87,20 @@ def batch_generators_eager_to_list(
73
87
  key = jax.random.PRNGKey(seed)
74
88
  data = []
75
89
  for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
76
- key, consume = jax.random.split(key)
77
- sample = batch_generators_lazy(gen, size, jit=jit)(consume)
78
- # converts also to numpy
79
- sample = jax.device_get(sample)
80
- data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
90
+
91
+ n_calls = _number_of_executions_required(size)
92
+ # decrease size by n_calls times
93
+ size = int(size / n_calls)
94
+ jit = True if n_calls > 1 else False
95
+ gen_jit = batch_generators_lazy(gen, size, jit=jit)
96
+
97
+ for _ in range(n_calls):
98
+ key, consume = jax.random.split(key)
99
+ sample = gen_jit(consume)
100
+ # converts also to numpy
101
+ sample = jax.device_get(sample)
102
+ data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
103
+
81
104
  return data
82
105
 
83
106
 
@@ -243,12 +266,11 @@ def batch_generators_eager(
243
266
  shuffle: bool = True,
244
267
  drop_last: bool = True,
245
268
  seed: int = 1,
246
- jit: bool = True,
247
269
  ) -> types.BatchedGenerator:
248
270
  """Eagerly create a large precomputed generator by calling multiple generators
249
271
  and stacking their output."""
250
272
 
251
- data = batch_generators_eager_to_list(generators, sizes, seed=seed, jit=jit)
273
+ data = batch_generators_eager_to_list(generators, sizes, seed=seed)
252
274
  return batched_generator_from_list(data, batchsize, shuffle, drop_last)
253
275
 
254
276
 
@@ -270,7 +292,7 @@ def _process_sizes_batchsizes_generators(
270
292
 
271
293
  assert len(generators) == len(list_sizes)
272
294
 
273
- _WARN_SIZE = 4096
295
+ _WARN_SIZE = 1e6 # disable this warning
274
296
  for size in list_sizes:
275
297
  if size >= _WARN_SIZE:
276
298
  warnings.warn(
ring/base.py CHANGED
@@ -99,15 +99,15 @@ class _Base:
99
99
  def ndim(self):
100
100
  return tu.tree_ndim(self)
101
101
 
102
- def shape(self, axis=0) -> int:
103
- return tu.tree_shape(self, axis)
104
-
105
- def __len__(self) -> int:
106
- Bs = tree_map(lambda arr: arr.shape[0], self)
102
+ def shape(self, axis: int = 0) -> int:
103
+ Bs = tree_map(lambda arr: arr.shape[axis], self)
107
104
  Bs = set(jax.tree_util.tree_flatten(Bs)[0])
108
105
  assert len(Bs) == 1
109
106
  return list(Bs)[0]
110
107
 
108
+ def __len__(self) -> int:
109
+ return self.shape(axis=0)
110
+
111
111
 
112
112
  @struct.dataclass
113
113
  class Transform(_Base):
@@ -685,14 +685,13 @@ class System(_Base):
685
685
  self,
686
686
  xs: Transform | list[Transform],
687
687
  yhat: dict | jax.Array | np.ndarray,
688
- stepframe: int = 1,
689
688
  # by default we don't predict the global rotation
690
689
  transparent_segment_to_root: bool = True,
691
690
  **kwargs,
692
691
  ):
693
692
  "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
694
693
  return ring.rendering.render_prediction(
695
- self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs
694
+ self, xs, yhat, transparent_segment_to_root, **kwargs
696
695
  )
697
696
 
698
697
  def delete_system(self, link_name: str | list[str], strict: bool = True):
ring/ml/__init__.py CHANGED
@@ -12,6 +12,8 @@ from .optimizer import make_optimizer
12
12
  from .ringnet import RING
13
13
  from .train import train_fn
14
14
 
15
+ _lpf_cutoff_freq = 10.0
16
+
15
17
 
16
18
  def RING_ICML24(params=None, eval: bool = True, **kwargs):
17
19
  """Create the RING network used in the icml24 paper.
@@ -29,7 +31,7 @@ def RING_ICML24(params=None, eval: bool = True, **kwargs):
29
31
  ringnet = RING(params=params, **kwargs) # noqa: F811
30
32
  ringnet = base.ScaleX_FilterWrapper(ringnet)
31
33
  if eval:
32
- ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)
34
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
33
35
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
34
36
  return ringnet
35
37
 
@@ -39,6 +41,7 @@ def RNNO(
39
41
  return_quats: bool = False,
40
42
  params=None,
41
43
  eval: bool = True,
44
+ samp_freq: float | None = None,
42
45
  **kwargs,
43
46
  ):
44
47
  assert "message_dim" not in kwargs
@@ -55,7 +58,7 @@ def RNNO(
55
58
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
56
59
  ringnet = base.ScaleX_FilterWrapper(ringnet)
57
60
  if eval and return_quats:
58
- ringnet = base.LPF_FilterWrapper(ringnet, 10.0, samp_freq=None)
61
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=samp_freq)
59
62
  if return_quats:
60
63
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
61
64
  return ringnet
ring/ml/optimizer.py CHANGED
@@ -14,10 +14,10 @@ from optax._src.transform import AddNoiseState
14
14
  def make_optimizer(
15
15
  lr: float,
16
16
  n_episodes: int,
17
- n_steps_per_episode: int,
17
+ n_steps_per_episode: int = 6,
18
18
  adap_clip: Optional[float] = 0.1,
19
19
  glob_clip: Optional[float] = 0.2,
20
- skip_large_update_max_normsq: float = 5.0,
20
+ skip_large_update_max_normsq: float = 100.0,
21
21
  skip_large_update_warmup: int = 300,
22
22
  inner_opt=optax.lamb,
23
23
  cos_decay_twice: bool = False,
@@ -136,12 +136,15 @@ def render_prediction(
136
136
  sys: base.System,
137
137
  xs: base.Transform | list[base.Transform],
138
138
  yhat: dict | jax.Array | np.ndarray,
139
- stepframe: int = 1,
140
139
  # by default we don't predict the global rotation
141
140
  transparent_segment_to_root: bool = True,
142
141
  **kwargs,
143
142
  ):
144
143
  "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
+
145
+ offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
+ offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
+
145
148
  if isinstance(xs, list):
146
149
  # list -> batched Transform
147
150
  xs = xs[0].batch(*xs[1:])
@@ -180,18 +183,23 @@ def render_prediction(
180
183
 
181
184
  # swap time axis, and link axis
182
185
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
+
187
+ add_offset = lambda x, offset: algebra.transform_mul(
188
+ x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
189
+ )
190
+
183
191
  # create mapping from `name` -> Transform
184
192
  xs_dict = dict(
185
193
  zip(
186
194
  ["hat_" + name for name in sys_noimu.link_names],
187
- [xshat[i] for i in range(sys_noimu.num_links())],
195
+ [add_offset(xshat[i], offset_pred) for i in range(sys_noimu.num_links())],
188
196
  )
189
197
  )
190
198
  xs_dict.update(
191
199
  dict(
192
200
  zip(
193
201
  sys.link_names,
194
- [xs[i] for i in range(sys.num_links())],
202
+ [add_offset(xs[i], offset_truth) for i in range(sys.num_links())],
195
203
  )
196
204
  )
197
205
  )
@@ -202,11 +210,8 @@ def render_prediction(
202
210
  xs_render.append(xs_dict[name])
203
211
  xs_render = xs_render[0].batch(*xs_render[1:])
204
212
  xs_render = xs_render.transpose((1, 0, 2))
205
- N = xs_render.shape()
206
- xs_render = [xs_render[t] for t in range(0, N, stepframe)]
207
213
 
208
214
  frames = render(sys_render, xs_render, **kwargs)
209
-
210
215
  return frames
211
216
 
212
217
 
ring/utils/__init__.py CHANGED
@@ -15,6 +15,7 @@ from .utils import dict_union
15
15
  from .utils import import_lib
16
16
  from .utils import pickle_load
17
17
  from .utils import pickle_save
18
+ from .utils import primes
18
19
  from .utils import pytree_deepcopy
19
20
  from .utils import sys_compare
20
21
  from .utils import to_list
ring/utils/utils.py CHANGED
@@ -159,3 +159,17 @@ def pickle_load(
159
159
  with open(path, "rb") as file:
160
160
  obj = pickle.load(file)
161
161
  return obj
162
+
163
+
164
+ def primes(n: int) -> list[int]:
165
+ "Primefactor decomposition in ascending order."
166
+ primfac = []
167
+ d = 2
168
+ while d * d <= n:
169
+ while (n % d) == 0:
170
+ primfac.append(d) # supposing you want multiple factors repeated
171
+ n //= d
172
+ d += 1
173
+ if n > 1:
174
+ primfac.append(n)
175
+ return primfac