imt-ring 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.4
3
+ Version: 1.3.6
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,21 +1,21 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=ZAoe9B1HbAX9NYiKaisssTBn-1VBXoJTsWgFAvlQoZw,33705
3
+ ring/base.py,sha256=99GuspRH4QtRRJTAgyvS02FFxoaBptSsz_GPczX8vw0,33947
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
7
- ring/algorithms/_random.py,sha256=6EG0GHYe6tCq0qUt4Jes8W1EaqqaLa0sSZhnwBbEjCE,13340
7
+ ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
8
8
  ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
9
- ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
9
+ ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
11
+ ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=CN0q6G2bjmufNMr7eAjKwIKwXpFp9qFjqrJCey-xIYE,15858
15
+ ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
- ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
- ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
17
+ ring/algorithms/generator/base.py,sha256=AKH7GXEmRGV1kK8okiqa12uq0Ah9VYlqgdLw-99oFoQ,14840
18
+ ring/algorithms/generator/batch.py,sha256=EOCX0vOxDwVOweArGsUneeeYysdSY2mFB55W052Wd9o,9161
19
19
  ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
20
  ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
21
21
  ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
@@ -76,8 +76,8 @@ ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
76
76
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
- ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
- imt_ring-1.3.4.dist-info/METADATA,sha256=oMI6d91TaCBndSOQJjoVBzqsYDg3oGtAxDzSmSamWmg,3104
81
- imt_ring-1.3.4.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.4.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.4.dist-info/RECORD,,
79
+ ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
80
+ imt_ring-1.3.6.dist-info/METADATA,sha256=E5mVtL-2o6-U-Ov56yd4M0RVQs0VJoLSezHpBWGtleg,3104
81
+ imt_ring-1.3.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.6.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.6.dist-info/RECORD,,
@@ -4,6 +4,7 @@ import warnings
4
4
  import jax
5
5
  from jax import random
6
6
  import jax.numpy as jnp
7
+
7
8
  from ring import maths
8
9
 
9
10
  Float = jax.Array
@@ -40,20 +41,20 @@ def random_angle_over_time(
40
41
  def body_fn_outer(val):
41
42
  i, t, phi, key_t, key_ang, ANG = val
42
43
 
43
- key_t, consume = random.split(key_t)
44
- dt = random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t))
45
-
46
- key_ang, consume = random.split(key_ang)
47
- phi = _resolve_range_of_motion(
44
+ key_t, consume_t = random.split(key_t)
45
+ key_ang, consume_ang = random.split(key_ang)
46
+ dt, phi = _resolve_range_of_motion(
48
47
  range_of_motion,
49
48
  range_of_motion_method,
50
49
  _to_float(dang_min, t),
51
50
  _to_float(dang_max, t),
52
51
  _to_float(delta_ang_min, t),
53
52
  _to_float(delta_ang_max, t),
54
- dt,
53
+ t_min,
54
+ _to_float(t_max, t),
55
55
  phi,
56
- consume,
56
+ consume_t,
57
+ consume_ang,
57
58
  max_iter,
58
59
  )
59
60
  t += dt
@@ -246,12 +247,14 @@ def _resolve_range_of_motion(
246
247
  dang_max,
247
248
  delta_ang_min,
248
249
  delta_ang_max,
249
- dt,
250
+ t_min,
251
+ t_max,
250
252
  prev_phi,
251
- key,
253
+ key_t,
254
+ key_ang,
252
255
  max_iter,
253
256
  ):
254
- def _next_phi(key):
257
+ def _next_phi(key, dt):
255
258
  key, consume = random.split(key)
256
259
 
257
260
  if range_of_motion:
@@ -294,21 +297,33 @@ def _resolve_range_of_motion(
294
297
  return prev_phi + sign * dphi
295
298
 
296
299
  def body_fn(val):
297
- key, _, i = val
298
- key, consume = jax.random.split(key)
299
- next_phi = _next_phi(consume)
300
- return key, next_phi, i + 1
300
+ key_t, key_ang, _, _, i = val
301
+
302
+ key_t, consume_t = jax.random.split(key_t)
303
+ dt = jax.random.uniform(consume_t, minval=t_min, maxval=t_max)
304
+
305
+ key_ang, consume_ang = jax.random.split(key_ang)
306
+ next_phi = _next_phi(consume_ang, dt)
307
+
308
+ return key_t, key_ang, dt, next_phi, i + 1
301
309
 
302
310
  def cond_fn(val):
303
- _, next_phi, i = val
311
+ *_, dt, next_phi, i = val
304
312
  delta_phi = jnp.abs(next_phi - prev_phi)
305
- # delta is in bounds
306
- break_if_true1 = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
313
+ # delta_ang is in bounds
314
+ cond_delta_ang = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
315
+ # dang is in bounds
316
+ dang = delta_phi / dt
317
+ cond_dang = (dang >= dang_min) & (dang <= dang_max)
318
+
319
+ break_if_true1 = jnp.logical_and(cond_delta_ang, cond_dang)
320
+ # break out of loop
307
321
  break_if_true2 = i > max_iter
308
322
  return (i == 0) | (jnp.logical_not(break_if_true1 | break_if_true2))
309
323
 
310
- # the `prev_phi` here is unused
311
- return jax.lax.while_loop(cond_fn, body_fn, (key, prev_phi, 0))[1]
324
+ init_val = (key_t, key_ang, 1.0, prev_phi, 0)
325
+ *_, dt, next_phi, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)
326
+ return dt, next_phi
312
327
 
313
328
 
314
329
  def cosInterpolate(x, xp, fp):
@@ -293,14 +293,33 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
293
293
  ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
294
294
 
295
295
 
296
+ def _scale_delta(method: str, key, xs, mn, mx, amin, amax, **kwargs):
297
+ if method == "normal":
298
+ delta = jnp.clip(jax.random.normal(key) + 0.5, 1.0)
299
+ elif method == "uniform":
300
+ delta = 1 / (jax.random.uniform(key) + 1e-2)
301
+ else:
302
+ raise NotImplementedError
303
+
304
+ return delta
305
+
306
+
296
307
  def Polynomial_DrawnFnPair(
297
308
  order: int = 2,
298
309
  center: bool = False,
299
- flexion_center: Optional[float] = None,
310
+ flexion_center_deg: Optional[float] = None,
300
311
  include_bias: bool = True,
312
+ enable_scale_delta: bool = True,
313
+ scale_delta_method: str = "normal",
314
+ scale_delta_kwargs: dict = dict(),
301
315
  ) -> DrawnFnPairFactory:
302
316
  assert not (order == 0 and not include_bias)
303
317
 
318
+ flexion_center = (
319
+ jnp.deg2rad(flexion_center_deg) if flexion_center_deg is not None else None
320
+ )
321
+ del flexion_center_deg
322
+
304
323
  # because 0-th order is also counted
305
324
  order += 1
306
325
  powers = jnp.arange(order) if include_bias else jnp.arange(1, order)
@@ -316,13 +335,18 @@ def Polynomial_DrawnFnPair(
316
335
 
317
336
  if flexion_center is None:
318
337
  flexion_center = (flexion_mn + flexion_mx) / 2
338
+
339
+ if (order - 1) == 0:
340
+ method = "clip"
341
+ minval, maxval = mn, mx
319
342
  else:
320
- flexion_center = jnp.array(flexion_center)
343
+ method = "minmax"
344
+ minval, maxval = -1.0, 1.0
321
345
 
322
346
  def init(key):
323
347
  c1, c2, c3 = jax.random.split(key, 3)
324
348
  poly_factors = jax.random.uniform(
325
- c1, shape=(len(powers),), minval=-1.0, maxval=1.0
349
+ c1, shape=(len(powers),), minval=minval, maxval=maxval
326
350
  )
327
351
  q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
328
352
  values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
@@ -330,16 +354,19 @@ def Polynomial_DrawnFnPair(
330
354
  )
331
355
  eps = 1e-6
332
356
  amin, amax = jnp.min(values), jnp.max(values) + eps
333
- delta = amax - amin
334
- scale_delta = jnp.clip(jax.random.normal(c3) + 0.5, 1.0)
335
- amax = amin + delta * scale_delta
357
+ if enable_scale_delta:
358
+ delta = amax - amin
359
+ scale_delta = _scale_delta(
360
+ scale_delta_method, c3, xs, mn, mx, amin, amax, **scale_delta_kwargs
361
+ )
362
+ amax = amin + delta * scale_delta
336
363
  return amin, amax, poly_factors, q0
337
364
 
338
365
  def _apply(params, q):
339
366
  amin, amax, poly_factors, q0 = params
340
367
  q = q - q0
341
368
  value = _apply_poly_factors(poly_factors, q)
342
- return restrict(value, mn, mx, amin, amax)
369
+ return restrict(value, mn, mx, amin, amax, method=method)
343
370
 
344
371
  if center:
345
372
 
@@ -140,8 +140,11 @@ class RCMG:
140
140
  batchsize: int = 1,
141
141
  sizes: int | list[int] = 1,
142
142
  seed: int = 1,
143
+ shuffle: bool = True,
143
144
  ) -> types.BatchedGenerator:
144
- return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
145
+ return batch.batch_generators_eager(
146
+ self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
147
+ )
145
148
 
146
149
  def to_lazy_gen(
147
150
  self, sizes: int | list[int] = 1, jit: bool = True
@@ -97,8 +97,10 @@ def batch_generators_eager_to_list(
97
97
  for _ in range(n_calls):
98
98
  key, consume = jax.random.split(key)
99
99
  sample = gen_jit(consume)
100
- # converts also to numpy
100
+ # converts also to numpy; but with np.array.flags.writeable = False
101
101
  sample = jax.device_get(sample)
102
+ # this then sets this flag to True
103
+ sample = jax.tree_map(np.array, sample)
102
104
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
103
105
 
104
106
  return data
ring/algorithms/jcalc.py CHANGED
@@ -424,14 +424,14 @@ def _draw_rxyz(
424
424
  # TODO, delete these args and pass a modifified `config` with `replace` instead
425
425
  enable_range_of_motion: bool = True,
426
426
  free_spherical: bool = False,
427
+ # how often it should try to fullfill the dang_min/max and delta_ang_min/max conds
428
+ max_iter: int = 5,
427
429
  ) -> jax.Array:
428
430
  key_value, consume = jax.random.split(key_value)
429
431
  ANG_0 = jax.random.uniform(consume, minval=config.ang0_min, maxval=config.ang0_max)
430
432
  # `random_angle_over_time` always returns wrapped angles, thus it would be
431
433
  # inconsistent to allow an initial value that is not wrapped
432
434
  ANG_0 = maths.wrap_to_pi(ANG_0)
433
- # only used for `delta_ang_min_max` logic
434
- max_iter = 5
435
435
  return _random.random_angle_over_time(
436
436
  key_t,
437
437
  key_value,
@@ -3,6 +3,7 @@ from typing import Optional
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
+
6
7
  from ring import algebra
7
8
  from ring import algorithms
8
9
  from ring import base
@@ -445,22 +446,18 @@ def _joint_axes_from_xs(sys, xs, sys_xs):
445
446
 
446
447
  def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
447
448
  "`sys` should be `sys_noimu`. `N` is number of timesteps"
448
- xaxis = jnp.array([1.0, 0, 0])
449
- yaxis = jnp.array([0.0, 1, 0])
450
- zaxis = jnp.array([0.0, 0, 1])
451
- id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
452
449
  X = {}
453
450
 
454
451
  def f(_, __, name, link_type, link):
455
452
  joint_params = link.joint_params
456
453
  if link_type in ["rx", "ry", "rz"]:
457
- joint_axes = id_to_axis[link_type[1]]
454
+ joint_axes = maths.unit_vectors(link_type[1])
458
455
  elif link_type == "rr":
459
456
  joint_axes = joint_params["rr"]["joint_axes"]
460
457
  elif link_type[:6] == "rr_imp":
461
458
  joint_axes = joint_params[link_type]["joint_axes"]
462
459
  else:
463
- joint_axes = xaxis
460
+ joint_axes = maths.x_unit_vector
464
461
  X[name] = {"joint_axes": joint_axes}
465
462
 
466
463
  sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
ring/base.py CHANGED
@@ -571,8 +571,11 @@ class System(_Base):
571
571
  new_damp: Optional[jax.Array] = None,
572
572
  new_stif: Optional[jax.Array] = None,
573
573
  new_zero: Optional[jax.Array] = None,
574
+ seed: int = 1,
574
575
  ):
575
576
  "By default damping, stiffness are set to zero."
577
+ from ring.algorithms import get_joint_model
578
+
576
579
  q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
577
580
 
578
581
  def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):
@@ -594,7 +597,13 @@ class System(_Base):
594
597
 
595
598
  return nlt, nla, nld, nls, nlz
596
599
 
597
- return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
600
+ sys = _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
601
+
602
+ jm = get_joint_model(new_joint_type)
603
+ if jm.init_joint_params is not None:
604
+ sys = sys.from_str(sys.to_str(), seed=seed)
605
+
606
+ return sys
598
607
 
599
608
  def findall_imus(self) -> list[str]:
600
609
  return [name for name in self.link_names if name[:3] == "imu"]
ring/utils/utils.py CHANGED
@@ -122,6 +122,8 @@ def pytree_deepcopy(tree):
122
122
  return tuple(pytree_deepcopy(ele) for ele in tree)
123
123
  elif isinstance(tree, dict):
124
124
  return {key: pytree_deepcopy(value) for key, value in tree.items()}
125
+ elif isinstance(tree, _Base):
126
+ return tree
125
127
  else:
126
128
  raise NotImplementedError(f"Not implemented for type={type(tree)}")
127
129