imt-ring 1.3.3__py3-none-any.whl → 1.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.3
3
+ Version: 1.3.11
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,23 +1,23 @@
1
1
  ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=ZAoe9B1HbAX9NYiKaisssTBn-1VBXoJTsWgFAvlQoZw,33705
3
+ ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
7
- ring/algorithms/_random.py,sha256=6EG0GHYe6tCq0qUt4Jes8W1EaqqaLa0sSZhnwBbEjCE,13340
8
- ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
9
- ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
7
+ ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
8
+ ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
9
+ ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
11
+ ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=0Oym3KQssj3QDOldnz9PTy5jPg9ZLk85mMK2YX1qvB4,15600
15
+ ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
16
16
  ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
17
- ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
18
- ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
19
- ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
20
- ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
17
+ ring/algorithms/generator/base.py,sha256=sr-YZkjd8pZJAI5vFG_IqOO4AEeiEYtXr8uUsPMS6Q4,14779
18
+ ring/algorithms/generator/batch.py,sha256=bslFSN2Gs_aX9cNwFooExhKUwevc70q3bspEMTwygm4,9256
19
+ ring/algorithms/generator/motion_artifacts.py,sha256=_kiAl1VHoX1fW5AUlXOtPBWyHIIFof_M78AP-m9f1ME,8790
20
+ ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
21
21
  ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
22
22
  ring/algorithms/generator/transforms.py,sha256=nvNDvM20tEw9Zd0ra0TxA25uf01L40Y2UKvtQmOrlGo,12782
23
23
  ring/algorithms/generator/types.py,sha256=CAhvDK5qiHnrGtkCVccB07doiz_D6lHJ35B7sW0pyZA,1110
@@ -50,13 +50,14 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=-bryExVoKJYSF_G_KYc5hI_GciIhj2xZ8WGi6TdRghw,1836
53
+ ring/ml/__init__.py,sha256=52LpEjni5lG-ov5-3ocodH-vKZxNcFMU7W9XfjDicp0,2113
54
54
  ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
55
- ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
55
+ ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
- ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
59
- ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
58
+ ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
59
+ ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
60
+ ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
60
61
  ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
61
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
62
63
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
@@ -70,14 +71,15 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
70
71
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
71
72
  ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
72
73
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
73
- ring/utils/__init__.py,sha256=rTvSA4RiJAVCY_A64FUMd8IJTv94LgoSA3Ps5X63_jA,799
74
- ring/utils/batchsize.py,sha256=mPFGD7AedFMycHtyIuZtNWCaAvKLLWSWaB7X6u54xvM,1358
74
+ ring/utils/__init__.py,sha256=FZ9ziQrWlx16QIpQ8RdLKrvN_17CAdvnZMNNodxWY0o,812
75
+ ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
76
+ ring/utils/batchsize.py,sha256=ByXGX7bw2gwrVirEsazm2JXnwNPGnqgEirzziYoSUS0,1553
75
77
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
76
78
  ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
77
79
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
78
80
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
79
- ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
80
- imt_ring-1.3.3.dist-info/METADATA,sha256=nTihurycKYmLCI61Cojd7VLrnb1gpd-H8nwUupysaC8,3104
81
- imt_ring-1.3.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
- imt_ring-1.3.3.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
83
- imt_ring-1.3.3.dist-info/RECORD,,
81
+ ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
82
+ imt_ring-1.3.11.dist-info/METADATA,sha256=5-bAEaCMi5JiwNMOEshFJlV1ISU3840hImimdj7l2CU,3105
83
+ imt_ring-1.3.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
84
+ imt_ring-1.3.11.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
85
+ imt_ring-1.3.11.dist-info/RECORD,,
@@ -4,6 +4,7 @@ import warnings
4
4
  import jax
5
5
  from jax import random
6
6
  import jax.numpy as jnp
7
+
7
8
  from ring import maths
8
9
 
9
10
  Float = jax.Array
@@ -40,20 +41,20 @@ def random_angle_over_time(
40
41
  def body_fn_outer(val):
41
42
  i, t, phi, key_t, key_ang, ANG = val
42
43
 
43
- key_t, consume = random.split(key_t)
44
- dt = random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t))
45
-
46
- key_ang, consume = random.split(key_ang)
47
- phi = _resolve_range_of_motion(
44
+ key_t, consume_t = random.split(key_t)
45
+ key_ang, consume_ang = random.split(key_ang)
46
+ dt, phi = _resolve_range_of_motion(
48
47
  range_of_motion,
49
48
  range_of_motion_method,
50
49
  _to_float(dang_min, t),
51
50
  _to_float(dang_max, t),
52
51
  _to_float(delta_ang_min, t),
53
52
  _to_float(delta_ang_max, t),
54
- dt,
53
+ t_min,
54
+ _to_float(t_max, t),
55
55
  phi,
56
- consume,
56
+ consume_t,
57
+ consume_ang,
57
58
  max_iter,
58
59
  )
59
60
  t += dt
@@ -246,12 +247,14 @@ def _resolve_range_of_motion(
246
247
  dang_max,
247
248
  delta_ang_min,
248
249
  delta_ang_max,
249
- dt,
250
+ t_min,
251
+ t_max,
250
252
  prev_phi,
251
- key,
253
+ key_t,
254
+ key_ang,
252
255
  max_iter,
253
256
  ):
254
- def _next_phi(key):
257
+ def _next_phi(key, dt):
255
258
  key, consume = random.split(key)
256
259
 
257
260
  if range_of_motion:
@@ -294,21 +297,33 @@ def _resolve_range_of_motion(
294
297
  return prev_phi + sign * dphi
295
298
 
296
299
  def body_fn(val):
297
- key, _, i = val
298
- key, consume = jax.random.split(key)
299
- next_phi = _next_phi(consume)
300
- return key, next_phi, i + 1
300
+ key_t, key_ang, _, _, i = val
301
+
302
+ key_t, consume_t = jax.random.split(key_t)
303
+ dt = jax.random.uniform(consume_t, minval=t_min, maxval=t_max)
304
+
305
+ key_ang, consume_ang = jax.random.split(key_ang)
306
+ next_phi = _next_phi(consume_ang, dt)
307
+
308
+ return key_t, key_ang, dt, next_phi, i + 1
301
309
 
302
310
  def cond_fn(val):
303
- _, next_phi, i = val
311
+ *_, dt, next_phi, i = val
304
312
  delta_phi = jnp.abs(next_phi - prev_phi)
305
- # delta is in bounds
306
- break_if_true1 = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
313
+ # delta_ang is in bounds
314
+ cond_delta_ang = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
315
+ # dang is in bounds
316
+ dang = delta_phi / dt
317
+ cond_dang = (dang >= dang_min) & (dang <= dang_max)
318
+
319
+ break_if_true1 = jnp.logical_and(cond_delta_ang, cond_dang)
320
+ # break out of loop
307
321
  break_if_true2 = i > max_iter
308
322
  return (i == 0) | (jnp.logical_not(break_if_true1 | break_if_true2))
309
323
 
310
- # the `prev_phi` here is unused
311
- return jax.lax.while_loop(cond_fn, body_fn, (key, prev_phi, 0))[1]
324
+ init_val = (key_t, key_ang, 1.0, prev_phi, 0)
325
+ *_, dt, next_phi, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)
326
+ return dt, next_phi
312
327
 
313
328
 
314
329
  def cosInterpolate(x, xp, fp):
@@ -293,16 +293,36 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
293
293
  ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
294
294
 
295
295
 
296
+ def _scale_delta(method: str, key, xs, mn, mx, amin, amax, **kwargs):
297
+ if method == "normal":
298
+ delta = jnp.clip(jax.random.normal(key) + 0.5, 1.0)
299
+ elif method == "uniform":
300
+ delta = 1 / (jax.random.uniform(key) + 1e-2)
301
+ else:
302
+ raise NotImplementedError
303
+
304
+ return delta
305
+
306
+
296
307
  def Polynomial_DrawnFnPair(
297
308
  order: int = 2,
298
- val: float = 2.0,
299
309
  center: bool = False,
300
- flexion_center: Optional[float] = None,
310
+ flexion_center_deg: Optional[float] = None,
311
+ include_bias: bool = True,
312
+ enable_scale_delta: bool = True,
313
+ scale_delta_method: str = "normal",
314
+ scale_delta_kwargs: dict = dict(),
301
315
  ) -> DrawnFnPairFactory:
302
- assert val >= 0.0
316
+ assert not (order == 0 and not include_bias)
317
+
318
+ flexion_center = (
319
+ jnp.deg2rad(flexion_center_deg) if flexion_center_deg is not None else None
320
+ )
321
+ del flexion_center_deg
303
322
 
304
323
  # because 0-th order is also counted
305
324
  order += 1
325
+ powers = jnp.arange(order) if include_bias else jnp.arange(1, order)
306
326
 
307
327
  def factory(xs, mn, mx):
308
328
  nonlocal flexion_center
@@ -311,17 +331,22 @@ def Polynomial_DrawnFnPair(
311
331
  flexion_mx = jnp.max(xs)
312
332
 
313
333
  def _apply_poly_factors(poly_factors, q):
314
- return poly_factors @ jnp.power(q, jnp.arange(order))
334
+ return poly_factors @ jnp.power(q, powers)
315
335
 
316
336
  if flexion_center is None:
317
337
  flexion_center = (flexion_mn + flexion_mx) / 2
338
+
339
+ if (order - 1) == 0:
340
+ method = "clip"
341
+ minval, maxval = mn, mx
318
342
  else:
319
- flexion_center = jnp.array(flexion_center)
343
+ method = "minmax"
344
+ minval, maxval = -1.0, 1.0
320
345
 
321
346
  def init(key):
322
- c1, c2 = jax.random.split(key)
347
+ c1, c2, c3 = jax.random.split(key, 3)
323
348
  poly_factors = jax.random.uniform(
324
- c1, shape=(order,), minval=-val, maxval=val
349
+ c1, shape=(len(powers),), minval=minval, maxval=maxval
325
350
  )
326
351
  q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
327
352
  values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
@@ -329,13 +354,19 @@ def Polynomial_DrawnFnPair(
329
354
  )
330
355
  eps = 1e-6
331
356
  amin, amax = jnp.min(values), jnp.max(values) + eps
357
+ if enable_scale_delta:
358
+ delta = amax - amin
359
+ scale_delta = _scale_delta(
360
+ scale_delta_method, c3, xs, mn, mx, amin, amax, **scale_delta_kwargs
361
+ )
362
+ amax = amin + delta * scale_delta
332
363
  return amin, amax, poly_factors, q0
333
364
 
334
365
  def _apply(params, q):
335
366
  amin, amax, poly_factors, q0 = params
336
367
  q = q - q0
337
368
  value = _apply_poly_factors(poly_factors, q)
338
- return restrict(value, mn, mx, amin, amax)
369
+ return restrict(value, mn, mx, amin, amax, method=method)
339
370
 
340
371
  if center:
341
372
 
@@ -1,7 +1,9 @@
1
1
  from typing import Optional, Tuple
2
+ import warnings
2
3
 
3
4
  import jax
4
5
  import jax.numpy as jnp
6
+
5
7
  from ring import algebra
6
8
  from ring import base
7
9
  from ring import maths
@@ -213,7 +215,7 @@ def forward_dynamics(
213
215
  q: jax.Array,
214
216
  qd: jax.Array,
215
217
  tau: jax.Array,
216
- mass_mat_inv: jax.Array,
218
+ # mass_mat_inv: jax.Array,
217
219
  ) -> Tuple[jax.Array, jax.Array]:
218
220
  C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
219
221
  mass_matrix = compute_mass_matrix(sys)
@@ -235,6 +237,11 @@ def forward_dynamics(
235
237
 
236
238
  mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
237
239
  else:
240
+ warnings.warn(
241
+ f"You are using `sys.mass_mat_iters`={sys.mass_mat_iters} which is >0. "
242
+ "This feature is currently not fully supported. See the local TODO."
243
+ )
244
+ mass_mat_inv = jnp.diag(jnp.ones((sys.qd_size(),)))
238
245
  mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)
239
246
 
240
247
  return mass_mat_inv @ qf_smooth, mass_mat_inv
@@ -254,9 +261,8 @@ def _strapdown_integration(
254
261
  def _semi_implicit_euler_integration(
255
262
  sys: base.System, state: base.State, taus: jax.Array
256
263
  ) -> base.State:
257
- qdd, mass_mat_inv = forward_dynamics(
258
- sys, state.q, state.qd, taus, state.mass_mat_inv
259
- )
264
+ qdd, mass_mat_inv = forward_dynamics(sys, state.q, state.qd, taus)
265
+ del mass_mat_inv
260
266
  qd_next = state.qd + sys.dt * qdd
261
267
 
262
268
  q_next = []
@@ -277,7 +283,7 @@ def _semi_implicit_euler_integration(
277
283
  sys.scan(q_integrate, "qdl", state.q, qd_next, sys.link_types)
278
284
  q_next = jnp.concatenate(q_next)
279
285
 
280
- state = state.replace(q=q_next, qd=qd_next, mass_mat_inv=mass_mat_inv)
286
+ state = state.replace(q=q_next, qd=qd_next)
281
287
  return state
282
288
 
283
289
 
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import tqdm
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -83,10 +84,14 @@ class RCMG:
83
84
  ), "If `randomize_anchors`, then only one system is expected"
84
85
  sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
85
86
 
86
- zip_sys_config = False
87
87
  if randomize_hz:
88
- zip_sys_config = True
89
88
  sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
89
+ else:
90
+ # create zip
91
+ N_sys = len(sys)
92
+ sys = sum([len(config) * [s] for s in sys], start=[])
93
+ config = N_sys * config
94
+ assert len(sys) == len(config)
90
95
 
91
96
  if sys_ml is None:
92
97
  # TODO
@@ -97,17 +102,10 @@ class RCMG:
97
102
  sys_ml = sys[0]
98
103
 
99
104
  self.gens = []
100
- if zip_sys_config:
101
- for _sys, _config in zip(sys, config):
102
- self.gens.append(
103
- partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
104
- )
105
- else:
106
- for _sys in sys:
107
- for _config in config:
108
- self.gens.append(
109
- partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
110
- )
105
+ for _sys, _config in tqdm.tqdm(
106
+ zip(sys, config), desc="building generators", total=len(sys)
107
+ ):
108
+ self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
111
109
 
112
110
  def _to_data(self, sizes, seed):
113
111
  return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
@@ -140,8 +138,11 @@ class RCMG:
140
138
  batchsize: int = 1,
141
139
  sizes: int | list[int] = 1,
142
140
  seed: int = 1,
141
+ shuffle: bool = True,
143
142
  ) -> types.BatchedGenerator:
144
- return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
143
+ return batch.batch_generators_eager(
144
+ self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
145
+ )
145
146
 
146
147
  def to_lazy_gen(
147
148
  self, sizes: int | list[int] = 1, jit: bool = True
@@ -63,12 +63,12 @@ def batch_generators_lazy(
63
63
 
64
64
 
65
65
  def _number_of_executions_required(size: int) -> int:
66
- vmap_threshold = 128
67
66
  _, vmap = utils.distribute_batchsize(size)
68
67
 
68
+ eager_threshold = utils.batchsize_thresholds()[1]
69
69
  primes = iter(utils.primes(vmap))
70
70
  n_calls = 1
71
- while vmap > vmap_threshold:
71
+ while vmap > eager_threshold:
72
72
  prime = next(primes)
73
73
  n_calls *= prime
74
74
  vmap /= prime
@@ -86,7 +86,11 @@ def batch_generators_eager_to_list(
86
86
 
87
87
  key = jax.random.PRNGKey(seed)
88
88
  data = []
89
- for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
89
+ for gen, size in tqdm(
90
+ zip(generators, sizes),
91
+ desc="executing generators",
92
+ total=len(sizes),
93
+ ):
90
94
 
91
95
  n_calls = _number_of_executions_required(size)
92
96
  # decrease size by n_calls times
@@ -97,8 +101,10 @@ def batch_generators_eager_to_list(
97
101
  for _ in range(n_calls):
98
102
  key, consume = jax.random.split(key)
99
103
  sample = gen_jit(consume)
100
- # converts also to numpy
104
+ # converts also to numpy; but with np.array.flags.writeable = False
101
105
  sample = jax.device_get(sample)
106
+ # this then sets this flag to True
107
+ sample = jax.tree_map(np.array, sample)
102
108
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
103
109
 
104
110
  return data
@@ -145,7 +151,7 @@ def _data_fn_from_paths(
145
151
  paths = [utils.parse_path(p, mkdir=False) for p in paths]
146
152
 
147
153
  extensions = list(set([Path(p).suffix for p in paths]))
148
- assert len(extensions) == 1
154
+ assert len(extensions) == 1, f"{extensions}"
149
155
 
150
156
  if extensions[0] == ".h5":
151
157
  N = sum([utils.hdf5_load_length(p) for p in paths])
@@ -49,6 +49,7 @@ def inject_subsystems(
49
49
  rotational_damp: float = 0.1,
50
50
  translational_stif: float = 50.0,
51
51
  translational_damp: float = 0.1,
52
+ disable_warning: bool = False,
52
53
  **kwargs,
53
54
  ) -> base.System:
54
55
  imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
@@ -92,10 +93,11 @@ def inject_subsystems(
92
93
  # TODO set all joint_params to zeros; they can not be preserved anyways and
93
94
  # otherwise many warnings will be rose
94
95
  # instead warn explicitly once now and move on
95
- warnings.warn(
96
- "`sys.links.joint_params` has been set to zero, this might lead to "
97
- "unexpected behaviour unless you use `randomize_joint_params`"
98
- )
96
+ if not disable_warning:
97
+ warnings.warn(
98
+ "`sys.links.joint_params` has been set to zero, this might lead to "
99
+ "unexpected behaviour unless you use `randomize_joint_params`"
100
+ )
99
101
  joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
100
102
  sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))
101
103
 
@@ -180,9 +182,13 @@ def setup_fn_randomize_damping_stiffness_factory(
180
182
  link_spring_stiffness = link_spring_stiffness.at[slice].set(stif)
181
183
  link_damping = link_damping.at[slice].set(damp)
182
184
 
183
- assert len(imus_surely_rigid) == len(triggered_surely_rigid)
185
+ assert len(imus_surely_rigid) == len(
186
+ triggered_surely_rigid
187
+ ), f"{imus_surely_rigid}, {triggered_surely_rigid}"
184
188
  for imu_surely_rigid in imus_surely_rigid:
185
- assert imu_surely_rigid in triggered_surely_rigid
189
+ assert (
190
+ imu_surely_rigid in triggered_surely_rigid
191
+ ), f"{imus_surely_rigid} not in {triggered_surely_rigid}"
186
192
 
187
193
  return sys.replace(
188
194
  link_damping=link_damping, link_spring_stiffness=link_spring_stiffness
@@ -4,6 +4,7 @@ from typing import Optional
4
4
  from flax import struct
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+
7
8
  from ring import base
8
9
  from ring.algorithms import dynamics
9
10
  from ring.algorithms import jcalc
@@ -49,7 +50,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
49
50
  assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
50
51
  assert sys.qd_size() == P.size
51
52
  if D is not None:
52
- sys.qd_size() == D.size
53
+ assert sys.qd_size() == D.size
53
54
 
54
55
  q_ref_as_dict = {}
55
56
  qd_ref_as_dict = {}
ring/algorithms/jcalc.py CHANGED
@@ -424,14 +424,14 @@ def _draw_rxyz(
424
424
  # TODO, delete these args and pass a modifified `config` with `replace` instead
425
425
  enable_range_of_motion: bool = True,
426
426
  free_spherical: bool = False,
427
+ # how often it should try to fullfill the dang_min/max and delta_ang_min/max conds
428
+ max_iter: int = 5,
427
429
  ) -> jax.Array:
428
430
  key_value, consume = jax.random.split(key_value)
429
431
  ANG_0 = jax.random.uniform(consume, minval=config.ang0_min, maxval=config.ang0_max)
430
432
  # `random_angle_over_time` always returns wrapped angles, thus it would be
431
433
  # inconsistent to allow an initial value that is not wrapped
432
434
  ANG_0 = maths.wrap_to_pi(ANG_0)
433
- # only used for `delta_ang_min_max` logic
434
- max_iter = 5
435
435
  return _random.random_angle_over_time(
436
436
  key_t,
437
437
  key_value,
@@ -3,6 +3,7 @@ from typing import Optional
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
+
6
7
  from ring import algebra
7
8
  from ring import algorithms
8
9
  from ring import base
@@ -445,22 +446,18 @@ def _joint_axes_from_xs(sys, xs, sys_xs):
445
446
 
446
447
  def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
447
448
  "`sys` should be `sys_noimu`. `N` is number of timesteps"
448
- xaxis = jnp.array([1.0, 0, 0])
449
- yaxis = jnp.array([0.0, 1, 0])
450
- zaxis = jnp.array([0.0, 0, 1])
451
- id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
452
449
  X = {}
453
450
 
454
451
  def f(_, __, name, link_type, link):
455
452
  joint_params = link.joint_params
456
453
  if link_type in ["rx", "ry", "rz"]:
457
- joint_axes = id_to_axis[link_type[1]]
454
+ joint_axes = maths.unit_vectors(link_type[1])
458
455
  elif link_type == "rr":
459
456
  joint_axes = joint_params["rr"]["joint_axes"]
460
457
  elif link_type[:6] == "rr_imp":
461
458
  joint_axes = joint_params[link_type]["joint_axes"]
462
459
  else:
463
- joint_axes = xaxis
460
+ joint_axes = maths.x_unit_vector
464
461
  X[name] = {"joint_axes": joint_axes}
465
462
 
466
463
  sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
ring/base.py CHANGED
@@ -571,8 +571,11 @@ class System(_Base):
571
571
  new_damp: Optional[jax.Array] = None,
572
572
  new_stif: Optional[jax.Array] = None,
573
573
  new_zero: Optional[jax.Array] = None,
574
+ seed: int = 1,
574
575
  ):
575
576
  "By default damping, stiffness are set to zero."
577
+ from ring.algorithms import get_joint_model
578
+
576
579
  q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
577
580
 
578
581
  def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):
@@ -594,7 +597,13 @@ class System(_Base):
594
597
 
595
598
  return nlt, nla, nld, nls, nlz
596
599
 
597
- return _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
600
+ sys = _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
601
+
602
+ jm = get_joint_model(new_joint_type)
603
+ if jm.init_joint_params is not None:
604
+ sys = sys.from_str(sys.to_str(), seed=seed)
605
+
606
+ return sys
598
607
 
599
608
  def findall_imus(self) -> list[str]:
600
609
  return [name for name in self.link_names if name[:3] == "imu"]
@@ -988,13 +997,11 @@ class State(_Base):
988
997
  q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
989
998
  qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
990
999
  x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
991
- mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
992
1000
  """
993
1001
 
994
1002
  q: jax.Array
995
1003
  qd: jax.Array
996
1004
  x: Transform
997
- mass_mat_inv: jax.Array
998
1005
 
999
1006
  @classmethod
1000
1007
  def create(
@@ -1048,4 +1055,4 @@ class State(_Base):
1048
1055
  if x is None:
1049
1056
  x = Transform.zero((sys.num_links(),))
1050
1057
 
1051
- return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))
1058
+ return cls(q, qd, x)
ring/ml/__init__.py CHANGED
@@ -3,6 +3,7 @@ from . import callbacks
3
3
  from . import ml_utils
4
4
  from . import optimizer
5
5
  from . import ringnet
6
+ from . import rnno_v1
6
7
  from . import train
7
8
  from . import training_loop
8
9
  from .base import AbstractFilter
@@ -42,17 +43,28 @@ def RNNO(
42
43
  params=None,
43
44
  eval: bool = True,
44
45
  samp_freq: float | None = None,
46
+ v1: bool = False,
45
47
  **kwargs,
46
48
  ):
47
49
  assert "message_dim" not in kwargs
48
50
  assert "link_output_normalize" not in kwargs
49
51
  assert "link_output_dim" not in kwargs
50
52
 
53
+ if v1:
54
+ kwargs.update(
55
+ dict(forward_factory=rnno_v1.rnno_v1_forward_factory, output_dim=output_dim)
56
+ )
57
+ else:
58
+ kwargs.update(
59
+ dict(
60
+ message_dim=0,
61
+ link_output_normalize=False,
62
+ link_output_dim=output_dim,
63
+ )
64
+ )
65
+
51
66
  ringnet = RING( # noqa: F811
52
67
  params=params,
53
- message_dim=0,
54
- link_output_normalize=False,
55
- link_output_dim=output_dim,
56
68
  **kwargs,
57
69
  )
58
70
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
ring/ml/callbacks.py CHANGED
@@ -245,7 +245,8 @@ class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
245
245
  else:
246
246
  value = "{:.2f}".format(ele.value).replace(".", ",")
247
247
  filename = parse_path(
248
- self.path_to_file + f"_episode={ele.episode}_value={value}",
248
+ str(Path(self.path_to_file).with_suffix(""))
249
+ + f"_episode={ele.episode}_value={value}",
249
250
  extension="pickle",
250
251
  )
251
252
 
@@ -404,7 +405,7 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
404
405
  # only checkpoint if run has been killed
405
406
  if training_loop.recv_kill_run_signal():
406
407
  path = parse_path(
407
- "~/.xxy_checkpoints", ml_utils.unique_id(), extension="pickle"
408
+ "~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
408
409
  )
409
410
  data = {"params": self.params, "opt_state": self.opt_state}
410
411
  pickle_save(
ring/ml/ringnet.py CHANGED
@@ -191,8 +191,16 @@ class LSTM(hk.RNNCore):
191
191
 
192
192
 
193
193
  class RING(ml_base.AbstractFilter):
194
- def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):
195
- self.forward_lam_factory = partial(make_ring, **kwargs)
194
+ def __init__(
195
+ self,
196
+ params=None,
197
+ lam=None,
198
+ jit: bool = True,
199
+ name=None,
200
+ forward_factory=make_ring,
201
+ **kwargs,
202
+ ):
203
+ self.forward_lam_factory = partial(forward_factory, **kwargs)
196
204
  self.params = self._load_params(params)
197
205
  self.lam = lam
198
206
  self._name = name
ring/ml/rnno_v1.py ADDED
@@ -0,0 +1,41 @@
1
+ from typing import Optional, Sequence
2
+
3
+ import haiku as hk
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+
8
+ def rnno_v1_forward_factory(
9
+ output_dim: int,
10
+ rnn_layers: Sequence[int] = (400, 300),
11
+ linear_layers: Sequence[int] = (200, 100, 50, 50, 25, 25),
12
+ layernorm: bool = True,
13
+ act_fn_linear=jax.nn.relu,
14
+ act_fn_rnn=jax.nn.elu,
15
+ lam: Optional[tuple[int]] = None,
16
+ ):
17
+ # unused
18
+ del lam
19
+
20
+ @hk.without_apply_rng
21
+ @hk.transform_with_state
22
+ def forward_fn(X):
23
+ assert X.shape[-2] == 1
24
+
25
+ for i, n_units in enumerate(rnn_layers):
26
+ state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
27
+ X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
28
+ hk.set_state(f"rnn_{i}", state)
29
+
30
+ if layernorm:
31
+ X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)
32
+ X = act_fn_rnn(X)
33
+
34
+ for n_units in linear_layers:
35
+ X = hk.Linear(n_units)(X)
36
+ X = act_fn_linear(X)
37
+
38
+ y = hk.Linear(output_dim)(X)
39
+ return y[..., None, :]
40
+
41
+ return forward_fn
ring/ml/train.py CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import optax
8
+ import tree_utils
9
+
8
10
  from ring import maths
9
11
  from ring.algorithms.generator import types
10
12
  from ring.ml import base as ml_base
@@ -13,10 +15,7 @@ from ring.ml import ml_utils
13
15
  from ring.ml import training_loop
14
16
  from ring.utils import distribute_batchsize
15
17
  from ring.utils import expand_batchsize
16
- from ring.utils import parse_path
17
18
  from ring.utils import pickle_load
18
- import tree_utils
19
-
20
19
  import wandb
21
20
 
22
21
  # (T, N, F) -> Scalar
@@ -142,15 +141,17 @@ def train_fn(
142
141
  Wether or not the training run was killed by a callback.
143
142
  """
144
143
 
144
+ filter = filter.nojit()
145
+
145
146
  if checkpoint is not None:
146
147
  checkpoint = Path(checkpoint).with_suffix(".pickle")
147
148
  recv_checkpoint: dict = pickle_load(checkpoint)
148
- filter.params = recv_checkpoint["params"]
149
+ filter_params = recv_checkpoint["params"]
149
150
  opt_state = recv_checkpoint["opt_state"]
151
+ del recv_checkpoint
152
+ else:
153
+ filter_params = filter.search_attr("params")
150
154
 
151
- filter = filter.nojit()
152
-
153
- filter_params = filter.search_attr("params")
154
155
  if filter_params is None:
155
156
  X, _ = generator(jax.random.PRNGKey(1))
156
157
  filter_params, _ = filter.init(X=X, seed=seed_network)
@@ -215,7 +216,7 @@ def train_fn(
215
216
 
216
217
  callbacks_all.append(
217
218
  ml_callbacks.SaveParamsTrainingLoopCallback(
218
- path_to_file=parse_path(callback_save_params, extension=""),
219
+ path_to_file=callback_save_params,
219
220
  last_n_params=3,
220
221
  track_metrices=callback_save_params_track_metrices,
221
222
  cleanup=False,
ring/utils/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .batchsize import backend
1
+ from .batchsize import batchsize_thresholds
2
2
  from .batchsize import distribute_batchsize
3
3
  from .batchsize import expand_batchsize
4
4
  from .batchsize import merge_batchsize
ring/utils/backend.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import re
3
+
4
+
5
+ def set_host_device_count(n):
6
+ """
7
+ By default, XLA considers all CPU cores as one device. This utility tells XLA
8
+ that there are `n` host (CPU) devices available to use. As a consequence, this
9
+ allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
10
+
11
+ .. note:: This utility only takes effect at the beginning of your program.
12
+ Under the hood, this sets the environment variable
13
+ `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
14
+ `[num_device]` is the desired number of CPU devices `n`.
15
+
16
+ .. warning:: Our understanding of the side effects of using the
17
+ `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
18
+ observe some strange phenomenon when using this utility, please let us
19
+ know through our issue or forum page. More information is available in this
20
+ `JAX issue <https://github.com/google/jax/issues/1408>`_.
21
+
22
+ :param int n: number of CPU devices to use.
23
+ """
24
+ xla_flags = os.getenv("XLA_FLAGS", "")
25
+ xla_flags = re.sub(
26
+ r"--xla_force_host_platform_device_count=\S+", "", xla_flags
27
+ ).split()
28
+ os.environ["XLA_FLAGS"] = " ".join(
29
+ ["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
30
+ )
ring/utils/batchsize.py CHANGED
@@ -1,19 +1,37 @@
1
- from typing import Optional, Tuple
1
+ from typing import Tuple, TypeVar
2
2
 
3
3
  import jax
4
- from tree_utils import PyTree
4
+
5
+ PyTree = TypeVar("PyTree")
6
+
7
+
8
+ def batchsize_thresholds():
9
+ backend = jax.default_backend()
10
+ if backend == "cpu":
11
+ vmap_size_min = 1
12
+ eager_threshold = 4
13
+ elif backend == "gpu":
14
+ vmap_size_min = 8
15
+ eager_threshold = 32
16
+ else:
17
+ raise Exception(
18
+ f"Backend {backend} has no default values, please add them in this function"
19
+ )
20
+ return vmap_size_min, eager_threshold
5
21
 
6
22
 
7
23
  def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
8
24
  """Distributes batchsize accross pmap and vmap."""
9
- vmap_size_min = 8
25
+ vmap_size_min = batchsize_thresholds()[0]
10
26
  if batchsize <= vmap_size_min:
11
27
  return 1, batchsize
12
28
  else:
13
29
  n_devices = jax.local_device_count()
14
- assert (
15
- batchsize % n_devices
16
- ) == 0, f"Your GPU count of {n_devices} does not split batchsize {batchsize}"
30
+ msg = (
31
+ f"Your local device count of {n_devices} does not split batchsize"
32
+ + f" {batchsize}. local devices are {jax.local_devices()}"
33
+ )
34
+ assert (batchsize % n_devices) == 0, msg
17
35
  vmap_size = int(batchsize / n_devices)
18
36
  return int(batchsize / vmap_size), vmap_size
19
37
 
@@ -35,17 +53,3 @@ def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
35
53
  ),
36
54
  tree,
37
55
  )
38
-
39
-
40
- CPU_ONLY = False
41
-
42
-
43
- def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
44
- "Sets backend for all jax operations (including this library)."
45
- global CPU_ONLY
46
-
47
- if cpu_only and not CPU_ONLY:
48
- CPU_ONLY = True
49
- from jax import config
50
-
51
- config.update("jax_platform_name", "cpu")
ring/utils/utils.py CHANGED
@@ -122,6 +122,8 @@ def pytree_deepcopy(tree):
122
122
  return tuple(pytree_deepcopy(ele) for ele in tree)
123
123
  elif isinstance(tree, dict):
124
124
  return {key: pytree_deepcopy(value) for key, value in tree.items()}
125
+ elif isinstance(tree, _Base):
126
+ return tree
125
127
  else:
126
128
  raise NotImplementedError(f"Not implemented for type={type(tree)}")
127
129