imt-ring 1.3.2__py3-none-any.whl → 1.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.2
3
+ Version: 1.3.4
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=gqdXejZ4E4liB5mZ6gPof3EDYTThlfro2MQs0bc5eOM,33530
3
+ ring/base.py,sha256=ZAoe9B1HbAX9NYiKaisssTBn-1VBXoJTsWgFAvlQoZw,33705
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
@@ -9,10 +9,10 @@ ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,1
9
9
  ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
12
- ring/algorithms/custom_joints/__init__.py,sha256=_kUyC4TbzjngTQrJVtS6JBKPzTMNbH27jVRJYXViepI,270
12
+ ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=3aFDfqdC2vUAhD30kkuQltgU_WZmYDyVhKPSoEotEYo,15292
15
+ ring/algorithms/custom_joints/suntay.py,sha256=CN0q6G2bjmufNMr7eAjKwIKwXpFp9qFjqrJCey-xIYE,15858
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
17
  ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
18
  ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
79
  ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
- imt_ring-1.3.2.dist-info/METADATA,sha256=OQmB5-CEpy-JWv2K9vYo9fUNQDW0Jg-fFU4kBkiVRGQ,3104
81
- imt_ring-1.3.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.2.dist-info/RECORD,,
80
+ imt_ring-1.3.4.dist-info/METADATA,sha256=oMI6d91TaCBndSOQJjoVBzqsYDg3oGtAxDzSmSamWmg,3104
81
+ imt_ring-1.3.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.4.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.4.dist-info/RECORD,,
@@ -1,5 +1,6 @@
1
1
  from .rr_imp_joint import register_rr_imp_joint
2
2
  from .rr_joint import register_rr_joint
3
+ from .suntay import ConstantValue_DrawnFnPair
3
4
  from .suntay import GP_DrawFnPair
4
5
  from .suntay import MLP_DrawnFnPair
5
6
  from .suntay import Polynomial_DrawnFnPair
@@ -295,14 +295,15 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
295
295
 
296
296
  def Polynomial_DrawnFnPair(
297
297
  order: int = 2,
298
- val: float = 2.0,
299
298
  center: bool = False,
300
299
  flexion_center: Optional[float] = None,
300
+ include_bias: bool = True,
301
301
  ) -> DrawnFnPairFactory:
302
- assert val >= 0.0
302
+ assert not (order == 0 and not include_bias)
303
303
 
304
304
  # because 0-th order is also counted
305
305
  order += 1
306
+ powers = jnp.arange(order) if include_bias else jnp.arange(1, order)
306
307
 
307
308
  def factory(xs, mn, mx):
308
309
  nonlocal flexion_center
@@ -311,7 +312,7 @@ def Polynomial_DrawnFnPair(
311
312
  flexion_mx = jnp.max(xs)
312
313
 
313
314
  def _apply_poly_factors(poly_factors, q):
314
- return poly_factors @ jnp.power(q, jnp.arange(order))
315
+ return poly_factors @ jnp.power(q, powers)
315
316
 
316
317
  if flexion_center is None:
317
318
  flexion_center = (flexion_mn + flexion_mx) / 2
@@ -319,16 +320,19 @@ def Polynomial_DrawnFnPair(
319
320
  flexion_center = jnp.array(flexion_center)
320
321
 
321
322
  def init(key):
322
- c1, c2 = jax.random.split(key)
323
+ c1, c2, c3 = jax.random.split(key, 3)
323
324
  poly_factors = jax.random.uniform(
324
- c1, shape=(order,), minval=-val, maxval=val
325
+ c1, shape=(len(powers),), minval=-1.0, maxval=1.0
325
326
  )
326
327
  q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
327
328
  values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
328
329
  poly_factors, xs - q0
329
330
  )
330
- amax = jnp.max(values)
331
- amin = jnp.min(values)
331
+ eps = 1e-6
332
+ amin, amax = jnp.min(values), jnp.max(values) + eps
333
+ delta = amax - amin
334
+ scale_delta = jnp.clip(jax.random.normal(c3) + 0.5, 1.0)
335
+ amax = amin + delta * scale_delta
332
336
  return amin, amax, poly_factors, q0
333
337
 
334
338
  def _apply(params, q):
@@ -350,6 +354,22 @@ def Polynomial_DrawnFnPair(
350
354
  return factory
351
355
 
352
356
 
357
+ def ConstantValue_DrawnFnPair(value: float) -> DrawnFnPairFactory:
358
+ value = jnp.array(value)
359
+
360
+ def factory(xs, mn, mx):
361
+
362
+ def init(key):
363
+ return {}
364
+
365
+ def apply(params, q):
366
+ return value
367
+
368
+ return DrawnFnPair(init, apply)
369
+
370
+ return factory
371
+
372
+
353
373
  def MLP_DrawnFnPair(
354
374
  center: bool = False, flexion_center: Optional[float] = None
355
375
  ) -> DrawnFnPairFactory:
ring/base.py CHANGED
@@ -929,16 +929,22 @@ def _parse_system_calculate_inertia(sys: System):
929
929
  def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = False):
930
930
  assert len(args) == len(in_types)
931
931
  for in_type, arg in zip(in_types, args):
932
- B = len(arg)
932
+
933
933
  if in_type == "l":
934
- assert B == sys.num_links()
934
+ required_length = sys.num_links()
935
935
  elif in_type == "q":
936
- assert B == sys.q_size()
936
+ required_length = sys.q_size()
937
937
  elif in_type == "d":
938
- assert B == sys.qd_size()
938
+ required_length = sys.qd_size()
939
939
  else:
940
940
  raise Exception("`in_types` must be one of `l` or `q` or `d`")
941
941
 
942
+ B = len(arg)
943
+ B_re = required_length
944
+ assert (
945
+ B == B_re
946
+ ), f"arg={arg} has a length of B={B} which isn't the required length={B_re}"
947
+
942
948
  order = range(sys.num_links())
943
949
  q_idx, qd_idx = 0, 0
944
950
  q_idxs, qd_idxs = {}, {}