imt-ring 1.3.3__py3-none-any.whl → 1.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.3
3
+ Version: 1.3.5
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,21 +1,21 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=ZAoe9B1HbAX9NYiKaisssTBn-1VBXoJTsWgFAvlQoZw,33705
3
+ ring/base.py,sha256=hZZaH1kOmeV4vK2EyOQdxSKBFLZ1bluRHcj725OHo2I,33913
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
7
- ring/algorithms/_random.py,sha256=6EG0GHYe6tCq0qUt4Jes8W1EaqqaLa0sSZhnwBbEjCE,13340
7
+ ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
8
8
  ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
9
- ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
9
+ ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=0Oym3KQssj3QDOldnz9PTy5jPg9ZLk85mMK2YX1qvB4,15600
15
+ ring/algorithms/custom_joints/suntay.py,sha256=oJFU0EnT6yzlSkoA2CegPDLyb8jAk-bGZx4iuZO4nxM,16215
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
17
  ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
- ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
18
+ ring/algorithms/generator/batch.py,sha256=EOCX0vOxDwVOweArGsUneeeYysdSY2mFB55W052Wd9o,9161
19
19
  ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
20
  ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
21
21
  ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
79
  ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
- imt_ring-1.3.3.dist-info/METADATA,sha256=nTihurycKYmLCI61Cojd7VLrnb1gpd-H8nwUupysaC8,3104
81
- imt_ring-1.3.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.3.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.3.dist-info/RECORD,,
80
+ imt_ring-1.3.5.dist-info/METADATA,sha256=lfUIi30c7raML41nr1zCkIhhGbXjaQuNFFs7hBgAIZ0,3104
81
+ imt_ring-1.3.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.5.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.5.dist-info/RECORD,,
@@ -4,6 +4,7 @@ import warnings
4
4
  import jax
5
5
  from jax import random
6
6
  import jax.numpy as jnp
7
+
7
8
  from ring import maths
8
9
 
9
10
  Float = jax.Array
@@ -40,20 +41,20 @@ def random_angle_over_time(
40
41
  def body_fn_outer(val):
41
42
  i, t, phi, key_t, key_ang, ANG = val
42
43
 
43
- key_t, consume = random.split(key_t)
44
- dt = random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t))
45
-
46
- key_ang, consume = random.split(key_ang)
47
- phi = _resolve_range_of_motion(
44
+ key_t, consume_t = random.split(key_t)
45
+ key_ang, consume_ang = random.split(key_ang)
46
+ dt, phi = _resolve_range_of_motion(
48
47
  range_of_motion,
49
48
  range_of_motion_method,
50
49
  _to_float(dang_min, t),
51
50
  _to_float(dang_max, t),
52
51
  _to_float(delta_ang_min, t),
53
52
  _to_float(delta_ang_max, t),
54
- dt,
53
+ t_min,
54
+ _to_float(t_max, t),
55
55
  phi,
56
- consume,
56
+ consume_t,
57
+ consume_ang,
57
58
  max_iter,
58
59
  )
59
60
  t += dt
@@ -246,12 +247,14 @@ def _resolve_range_of_motion(
246
247
  dang_max,
247
248
  delta_ang_min,
248
249
  delta_ang_max,
249
- dt,
250
+ t_min,
251
+ t_max,
250
252
  prev_phi,
251
- key,
253
+ key_t,
254
+ key_ang,
252
255
  max_iter,
253
256
  ):
254
- def _next_phi(key):
257
+ def _next_phi(key, dt):
255
258
  key, consume = random.split(key)
256
259
 
257
260
  if range_of_motion:
@@ -294,21 +297,33 @@ def _resolve_range_of_motion(
294
297
  return prev_phi + sign * dphi
295
298
 
296
299
  def body_fn(val):
297
- key, _, i = val
298
- key, consume = jax.random.split(key)
299
- next_phi = _next_phi(consume)
300
- return key, next_phi, i + 1
300
+ key_t, key_ang, _, _, i = val
301
+
302
+ key_t, consume_t = jax.random.split(key_t)
303
+ dt = jax.random.uniform(consume_t, minval=t_min, maxval=t_max)
304
+
305
+ key_ang, consume_ang = jax.random.split(key_ang)
306
+ next_phi = _next_phi(consume_ang, dt)
307
+
308
+ return key_t, key_ang, dt, next_phi, i + 1
301
309
 
302
310
  def cond_fn(val):
303
- _, next_phi, i = val
311
+ *_, dt, next_phi, i = val
304
312
  delta_phi = jnp.abs(next_phi - prev_phi)
305
- # delta is in bounds
306
- break_if_true1 = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
313
+ # delta_ang is in bounds
314
+ cond_delta_ang = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
315
+ # dang is in bounds
316
+ dang = delta_phi / dt
317
+ cond_dang = (dang >= dang_min) & (dang <= dang_max)
318
+
319
+ break_if_true1 = jnp.logical_and(cond_delta_ang, cond_dang)
320
+ # break out of loop
307
321
  break_if_true2 = i > max_iter
308
322
  return (i == 0) | (jnp.logical_not(break_if_true1 | break_if_true2))
309
323
 
310
- # the `prev_phi` here is unused
311
- return jax.lax.while_loop(cond_fn, body_fn, (key, prev_phi, 0))[1]
324
+ init_val = (key_t, key_ang, 1.0, prev_phi, 0)
325
+ *_, dt, next_phi, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)
326
+ return dt, next_phi
312
327
 
313
328
 
314
329
  def cosInterpolate(x, xp, fp):
@@ -295,14 +295,21 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
295
295
 
296
296
  def Polynomial_DrawnFnPair(
297
297
  order: int = 2,
298
- val: float = 2.0,
299
298
  center: bool = False,
300
- flexion_center: Optional[float] = None,
299
+ flexion_center_deg: Optional[float] = None,
300
+ include_bias: bool = True,
301
+ enable_scale_delta: bool = True,
301
302
  ) -> DrawnFnPairFactory:
302
- assert val >= 0.0
303
+ assert not (order == 0 and not include_bias)
304
+
305
+ flexion_center = (
306
+ jnp.deg2rad(flexion_center_deg) if flexion_center_deg is not None else None
307
+ )
308
+ del flexion_center_deg
303
309
 
304
310
  # because 0-th order is also counted
305
311
  order += 1
312
+ powers = jnp.arange(order) if include_bias else jnp.arange(1, order)
306
313
 
307
314
  def factory(xs, mn, mx):
308
315
  nonlocal flexion_center
@@ -311,17 +318,22 @@ def Polynomial_DrawnFnPair(
311
318
  flexion_mx = jnp.max(xs)
312
319
 
313
320
  def _apply_poly_factors(poly_factors, q):
314
- return poly_factors @ jnp.power(q, jnp.arange(order))
321
+ return poly_factors @ jnp.power(q, powers)
315
322
 
316
323
  if flexion_center is None:
317
324
  flexion_center = (flexion_mn + flexion_mx) / 2
325
+
326
+ if (order - 1) == 0:
327
+ method = "clip"
328
+ minval, maxval = mn, mx
318
329
  else:
319
- flexion_center = jnp.array(flexion_center)
330
+ method = "minmax"
331
+ minval, maxval = -1.0, 1.0
320
332
 
321
333
  def init(key):
322
- c1, c2 = jax.random.split(key)
334
+ c1, c2, c3 = jax.random.split(key, 3)
323
335
  poly_factors = jax.random.uniform(
324
- c1, shape=(order,), minval=-val, maxval=val
336
+ c1, shape=(len(powers),), minval=minval, maxval=maxval
325
337
  )
326
338
  q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
327
339
  values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
@@ -329,13 +341,17 @@ def Polynomial_DrawnFnPair(
329
341
  )
330
342
  eps = 1e-6
331
343
  amin, amax = jnp.min(values), jnp.max(values) + eps
344
+ if enable_scale_delta:
345
+ delta = amax - amin
346
+ scale_delta = jnp.clip(jax.random.normal(c3) + 0.5, 1.0)
347
+ amax = amin + delta * scale_delta
332
348
  return amin, amax, poly_factors, q0
333
349
 
334
350
  def _apply(params, q):
335
351
  amin, amax, poly_factors, q0 = params
336
352
  q = q - q0
337
353
  value = _apply_poly_factors(poly_factors, q)
338
- return restrict(value, mn, mx, amin, amax)
354
+ return restrict(value, mn, mx, amin, amax, method=method)
339
355
 
340
356
  if center:
341
357
 
@@ -97,8 +97,10 @@ def batch_generators_eager_to_list(
97
97
  for _ in range(n_calls):
98
98
  key, consume = jax.random.split(key)
99
99
  sample = gen_jit(consume)
100
- # converts also to numpy
100
+ # converts also to numpy; but with np.array.flags.writeable = False
101
101
  sample = jax.device_get(sample)
102
+ # this then sets this flag to True
103
+ sample = jax.tree_map(np.array, sample)
102
104
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
103
105
 
104
106
  return data
ring/algorithms/jcalc.py CHANGED
@@ -424,14 +424,14 @@ def _draw_rxyz(
424
424
  # TODO, delete these args and pass a modifified `config` with `replace` instead
425
425
  enable_range_of_motion: bool = True,
426
426
  free_spherical: bool = False,
427
+ # how often it should try to fullfill the dang_min/max and delta_ang_min/max conds
428
+ max_iter: int = 5,
427
429
  ) -> jax.Array:
428
430
  key_value, consume = jax.random.split(key_value)
429
431
  ANG_0 = jax.random.uniform(consume, minval=config.ang0_min, maxval=config.ang0_max)
430
432
  # `random_angle_over_time` always returns wrapped angles, thus it would be
431
433
  # inconsistent to allow an initial value that is not wrapped
432
434
  ANG_0 = maths.wrap_to_pi(ANG_0)
433
- # only used for `delta_ang_min_max` logic
434
- max_iter = 5
435
435
  return _random.random_angle_over_time(
436
436
  key_t,
437
437
  key_value,
ring/base.py CHANGED
@@ -573,6 +573,8 @@ class System(_Base):
573
573
  new_zero: Optional[jax.Array] = None,
574
574
  ):
575
575
  "By default damping, stiffness are set to zero."
576
+ from ring.algorithms import get_joint_model
577
+
576
578
  q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
577
579
 
578
580
  def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):
@@ -594,7 +596,13 @@ class System(_Base):
594
596
 
595
597
  return nlt, nla, nld, nls, nlz
596
598
 
597
- return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
599
+ sys = _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
600
+
601
+ jm = get_joint_model(new_joint_type)
602
+ if jm.init_joint_params is not None:
603
+ sys = sys.from_str(sys.to_str())
604
+
605
+ return sys
598
606
 
599
607
  def findall_imus(self) -> list[str]:
600
608
  return [name for name in self.link_names if name[:3] == "imu"]