imt-ring 1.3.1__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.1
3
+ Version: 1.3.2
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -9,10 +9,10 @@ ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,1
9
9
  ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
12
- ring/algorithms/custom_joints/__init__.py,sha256=33WBnaBJMtq3vVcpMm7zmyeMrLY9PyV_8-wk5oSF65g,227
12
+ ring/algorithms/custom_joints/__init__.py,sha256=_kUyC4TbzjngTQrJVtS6JBKPzTMNbH27jVRJYXViepI,270
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=d0Z54tIXiepMixE40W5H8JKxrT5U6VskPm2L2kKnQPw,13680
15
+ ring/algorithms/custom_joints/suntay.py,sha256=3aFDfqdC2vUAhD30kkuQltgU_WZmYDyVhKPSoEotEYo,15292
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
17
  ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
18
  ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
@@ -50,11 +50,11 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=4eK8P-pjAe_TcURaXaHKykZ3IfTbmxnQyOaI-EGQzg4,1795
53
+ ring/ml/__init__.py,sha256=-bryExVoKJYSF_G_KYc5hI_GciIhj2xZ8WGi6TdRghw,1836
54
54
  ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
55
55
  ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
- ring/ml/optimizer.py,sha256=OP70P70YcX-2Z-cuoMluFk-L5Vhh_MmqiHdM9OZqyhI,4703
57
+ ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
59
59
  ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
60
60
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
79
  ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
- imt_ring-1.3.1.dist-info/METADATA,sha256=sCl08586u_XLy0LUsEuIhyIUPxj3R3pzmXtXgFuRw1c,3104
81
- imt_ring-1.3.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.1.dist-info/RECORD,,
80
+ imt_ring-1.3.2.dist-info/METADATA,sha256=OQmB5-CEpy-JWv2K9vYo9fUNQDW0Jg-fFU4kBkiVRGQ,3104
81
+ imt_ring-1.3.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.2.dist-info/RECORD,,
@@ -2,5 +2,6 @@ from .rr_imp_joint import register_rr_imp_joint
2
2
  from .rr_joint import register_rr_joint
3
3
  from .suntay import GP_DrawFnPair
4
4
  from .suntay import MLP_DrawnFnPair
5
+ from .suntay import Polynomial_DrawnFnPair
5
6
  from .suntay import register_suntay
6
7
  from .suntay import SuntayConfig
@@ -293,6 +293,63 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
293
293
  ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
294
294
 
295
295
 
296
+ def Polynomial_DrawnFnPair(
297
+ order: int = 2,
298
+ val: float = 2.0,
299
+ center: bool = False,
300
+ flexion_center: Optional[float] = None,
301
+ ) -> DrawnFnPairFactory:
302
+ assert val >= 0.0
303
+
304
+ # because 0-th order is also counted
305
+ order += 1
306
+
307
+ def factory(xs, mn, mx):
308
+ nonlocal flexion_center
309
+
310
+ flexion_mn = jnp.min(xs)
311
+ flexion_mx = jnp.max(xs)
312
+
313
+ def _apply_poly_factors(poly_factors, q):
314
+ return poly_factors @ jnp.power(q, jnp.arange(order))
315
+
316
+ if flexion_center is None:
317
+ flexion_center = (flexion_mn + flexion_mx) / 2
318
+ else:
319
+ flexion_center = jnp.array(flexion_center)
320
+
321
+ def init(key):
322
+ c1, c2 = jax.random.split(key)
323
+ poly_factors = jax.random.uniform(
324
+ c1, shape=(order,), minval=-val, maxval=val
325
+ )
326
+ q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
327
+ values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
328
+ poly_factors, xs - q0
329
+ )
330
+ amax = jnp.max(values)
331
+ amin = jnp.min(values)
332
+ return amin, amax, poly_factors, q0
333
+
334
+ def _apply(params, q):
335
+ amin, amax, poly_factors, q0 = params
336
+ q = q - q0
337
+ value = _apply_poly_factors(poly_factors, q)
338
+ return restrict(value, mn, mx, amin, amax)
339
+
340
+ if center:
341
+
342
+ def apply(params, q):
343
+ return _apply(params, q) - _apply(params, flexion_center)
344
+
345
+ else:
346
+ apply = _apply
347
+
348
+ return DrawnFnPair(init, apply)
349
+
350
+ return factory
351
+
352
+
296
353
  def MLP_DrawnFnPair(
297
354
  center: bool = False, flexion_center: Optional[float] = None
298
355
  ) -> DrawnFnPairFactory:
ring/ml/__init__.py CHANGED
@@ -41,6 +41,7 @@ def RNNO(
41
41
  return_quats: bool = False,
42
42
  params=None,
43
43
  eval: bool = True,
44
+ samp_freq: float | None = None,
44
45
  **kwargs,
45
46
  ):
46
47
  assert "message_dim" not in kwargs
@@ -57,7 +58,7 @@ def RNNO(
57
58
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
58
59
  ringnet = base.ScaleX_FilterWrapper(ringnet)
59
60
  if eval and return_quats:
60
- ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
61
+ ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=samp_freq)
61
62
  if return_quats:
62
63
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
63
64
  return ringnet
ring/ml/optimizer.py CHANGED
@@ -14,10 +14,10 @@ from optax._src.transform import AddNoiseState
14
14
  def make_optimizer(
15
15
  lr: float,
16
16
  n_episodes: int,
17
- n_steps_per_episode: int,
17
+ n_steps_per_episode: int = 6,
18
18
  adap_clip: Optional[float] = 0.1,
19
19
  glob_clip: Optional[float] = 0.2,
20
- skip_large_update_max_normsq: float = 5.0,
20
+ skip_large_update_max_normsq: float = 100.0,
21
21
  skip_large_update_warmup: int = 300,
22
22
  inner_opt=optax.lamb,
23
23
  cos_decay_twice: bool = False,