imt-ring 1.3.5__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.5
3
+ Version: 1.3.7
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=hZZaH1kOmeV4vK2EyOQdxSKBFLZ1bluRHcj725OHo2I,33913
3
+ ring/base.py,sha256=99GuspRH4QtRRJTAgyvS02FFxoaBptSsz_GPczX8vw0,33947
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
@@ -8,13 +8,13 @@ ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13
8
8
  ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
9
9
  ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
11
+ ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=oJFU0EnT6yzlSkoA2CegPDLyb8jAk-bGZx4iuZO4nxM,16215
15
+ ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
- ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
17
+ ring/algorithms/generator/base.py,sha256=AKH7GXEmRGV1kK8okiqa12uq0Ah9VYlqgdLw-99oFoQ,14840
18
18
  ring/algorithms/generator/batch.py,sha256=EOCX0vOxDwVOweArGsUneeeYysdSY2mFB55W052Wd9o,9161
19
19
  ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
20
  ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
@@ -56,7 +56,7 @@ ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
59
- ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
59
+ ring/ml/train.py,sha256=uDW6JMdbMcjUKr3wCL2drWzDUd0Pc3BoroUwLcYoUx4,10914
60
60
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
61
61
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
62
62
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
@@ -76,8 +76,8 @@ ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
76
76
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
77
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
78
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
- ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
- imt_ring-1.3.5.dist-info/METADATA,sha256=lfUIi30c7raML41nr1zCkIhhGbXjaQuNFFs7hBgAIZ0,3104
81
- imt_ring-1.3.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.5.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.5.dist-info/RECORD,,
79
+ ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
80
+ imt_ring-1.3.7.dist-info/METADATA,sha256=V6Oow_ZZwpBuHuIbyPIoKFtrhFboxMmuIPx1Rilq3-A,3104
81
+ imt_ring-1.3.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ imt_ring-1.3.7.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
+ imt_ring-1.3.7.dist-info/RECORD,,
@@ -293,12 +293,25 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
293
293
  ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
294
294
 
295
295
 
296
+ def _scale_delta(method: str, key, xs, mn, mx, amin, amax, **kwargs):
297
+ if method == "normal":
298
+ delta = jnp.clip(jax.random.normal(key) + 0.5, 1.0)
299
+ elif method == "uniform":
300
+ delta = 1 / (jax.random.uniform(key) + 1e-2)
301
+ else:
302
+ raise NotImplementedError
303
+
304
+ return delta
305
+
306
+
296
307
  def Polynomial_DrawnFnPair(
297
308
  order: int = 2,
298
309
  center: bool = False,
299
310
  flexion_center_deg: Optional[float] = None,
300
311
  include_bias: bool = True,
301
312
  enable_scale_delta: bool = True,
313
+ scale_delta_method: str = "normal",
314
+ scale_delta_kwargs: dict = dict(),
302
315
  ) -> DrawnFnPairFactory:
303
316
  assert not (order == 0 and not include_bias)
304
317
 
@@ -343,7 +356,9 @@ def Polynomial_DrawnFnPair(
343
356
  amin, amax = jnp.min(values), jnp.max(values) + eps
344
357
  if enable_scale_delta:
345
358
  delta = amax - amin
346
- scale_delta = jnp.clip(jax.random.normal(c3) + 0.5, 1.0)
359
+ scale_delta = _scale_delta(
360
+ scale_delta_method, c3, xs, mn, mx, amin, amax, **scale_delta_kwargs
361
+ )
347
362
  amax = amin + delta * scale_delta
348
363
  return amin, amax, poly_factors, q0
349
364
 
@@ -140,8 +140,11 @@ class RCMG:
140
140
  batchsize: int = 1,
141
141
  sizes: int | list[int] = 1,
142
142
  seed: int = 1,
143
+ shuffle: bool = True,
143
144
  ) -> types.BatchedGenerator:
144
- return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
145
+ return batch.batch_generators_eager(
146
+ self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
147
+ )
145
148
 
146
149
  def to_lazy_gen(
147
150
  self, sizes: int | list[int] = 1, jit: bool = True
@@ -3,6 +3,7 @@ from typing import Optional
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
+
6
7
  from ring import algebra
7
8
  from ring import algorithms
8
9
  from ring import base
@@ -445,22 +446,18 @@ def _joint_axes_from_xs(sys, xs, sys_xs):
445
446
 
446
447
  def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
447
448
  "`sys` should be `sys_noimu`. `N` is number of timesteps"
448
- xaxis = jnp.array([1.0, 0, 0])
449
- yaxis = jnp.array([0.0, 1, 0])
450
- zaxis = jnp.array([0.0, 0, 1])
451
- id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
452
449
  X = {}
453
450
 
454
451
  def f(_, __, name, link_type, link):
455
452
  joint_params = link.joint_params
456
453
  if link_type in ["rx", "ry", "rz"]:
457
- joint_axes = id_to_axis[link_type[1]]
454
+ joint_axes = maths.unit_vectors(link_type[1])
458
455
  elif link_type == "rr":
459
456
  joint_axes = joint_params["rr"]["joint_axes"]
460
457
  elif link_type[:6] == "rr_imp":
461
458
  joint_axes = joint_params[link_type]["joint_axes"]
462
459
  else:
463
- joint_axes = xaxis
460
+ joint_axes = maths.x_unit_vector
464
461
  X[name] = {"joint_axes": joint_axes}
465
462
 
466
463
  sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
ring/base.py CHANGED
@@ -571,6 +571,7 @@ class System(_Base):
571
571
  new_damp: Optional[jax.Array] = None,
572
572
  new_stif: Optional[jax.Array] = None,
573
573
  new_zero: Optional[jax.Array] = None,
574
+ seed: int = 1,
574
575
  ):
575
576
  "By default damping, stiffness are set to zero."
576
577
  from ring.algorithms import get_joint_model
@@ -600,7 +601,7 @@ class System(_Base):
600
601
 
601
602
  jm = get_joint_model(new_joint_type)
602
603
  if jm.init_joint_params is not None:
603
- sys = sys.from_str(sys.to_str())
604
+ sys = sys.from_str(sys.to_str(), seed=seed)
604
605
 
605
606
  return sys
606
607
 
ring/ml/train.py CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import optax
8
+ import tree_utils
9
+
8
10
  from ring import maths
9
11
  from ring.algorithms.generator import types
10
12
  from ring.ml import base as ml_base
@@ -15,8 +17,6 @@ from ring.utils import distribute_batchsize
15
17
  from ring.utils import expand_batchsize
16
18
  from ring.utils import parse_path
17
19
  from ring.utils import pickle_load
18
- import tree_utils
19
-
20
20
  import wandb
21
21
 
22
22
  # (T, N, F) -> Scalar
@@ -142,15 +142,17 @@ def train_fn(
142
142
  Wether or not the training run was killed by a callback.
143
143
  """
144
144
 
145
+ filter = filter.nojit()
146
+
145
147
  if checkpoint is not None:
146
148
  checkpoint = Path(checkpoint).with_suffix(".pickle")
147
149
  recv_checkpoint: dict = pickle_load(checkpoint)
148
- filter.params = recv_checkpoint["params"]
150
+ filter_params = recv_checkpoint["params"]
149
151
  opt_state = recv_checkpoint["opt_state"]
152
+ del recv_checkpoint
153
+ else:
154
+ filter_params = filter.search_attr("params")
150
155
 
151
- filter = filter.nojit()
152
-
153
- filter_params = filter.search_attr("params")
154
156
  if filter_params is None:
155
157
  X, _ = generator(jax.random.PRNGKey(1))
156
158
  filter_params, _ = filter.init(X=X, seed=seed_network)
ring/utils/utils.py CHANGED
@@ -122,6 +122,8 @@ def pytree_deepcopy(tree):
122
122
  return tuple(pytree_deepcopy(ele) for ele in tree)
123
123
  elif isinstance(tree, dict):
124
124
  return {key: pytree_deepcopy(value) for key, value in tree.items()}
125
+ elif isinstance(tree, _Base):
126
+ return tree
125
127
  else:
126
128
  raise NotImplementedError(f"Not implemented for type={type(tree)}")
127
129