imt-ring 1.6.36__py3-none-any.whl → 1.6.38__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: imt-ring
3
- Version: 1.6.36
3
+ Version: 1.6.38
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,25 +1,25 @@
1
1
  ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=_TgFrggsZfam0VPxvD4J5xp977vgiLnKTlDIJVzik5M,35362
3
+ ring/base.py,sha256=4Yxk6jk-B4UUm_6YYshxmHSHqOg0mhTOxtZP5fFS8nw,35373
4
4
  ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
7
  ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
8
8
  ring/algorithms/dynamics.py,sha256=dpe-F3Yq4sY2dY6DQW3v7TnPLRdxdkePtdbGPQIrijg,10997
9
9
  ring/algorithms/jcalc.py,sha256=QafnCKa1mjEl7nL_KuadPJB5ebW31NKnkdcKn2YtSsM,36171
10
- ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
- ring/algorithms/sensors.py,sha256=0xOzdQIc1kBF0CkoPXWWCx3MmV4SG3wj7knVnnMWq9M,18124
10
+ ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
11
+ ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
13
13
  ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
- ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
+ ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
18
  ring/algorithms/generator/base.py,sha256=jGQocoNZ5tkiMazBDCv-jD6FNYwebqn0_RgVFse49pg,16890
19
- ring/algorithms/generator/batch.py,sha256=P51UnAZl9TUF_eVq58VL1CsmPPStPHhRDdKjUyvu4EA,2652
20
- ring/algorithms/generator/finalize_fns.py,sha256=nY2RKiLbHriTkdec94lc4UGSZKd0v547MDNn4dr8I3E,10398
19
+ ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
20
+ ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
22
- ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
+ ring/algorithms/generator/pd_control.py,sha256=dHnhJZx_FqrHD4xFXpQZH-R7rputFkAVGwoBGccZnz4,5767
23
23
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
24
24
  ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
25
25
  ring/io/__init__.py,sha256=1gEJdyDCbldbbm8QeZbLmhzSKmaQ-UqTmQgu4DBH2Z4,328
@@ -47,46 +47,46 @@ ring/io/examples/test_morph_system/four_seg_seg1.xml,sha256=XJvGtEnvedejs_OmCVfQ
47
47
  ring/io/examples/test_morph_system/four_seg_seg3.xml,sha256=HktN7_a_Ly3YflWit5W-WncxApWGMORAGnRXyMEqnoA,1265
48
48
  ring/io/xml/__init__.py,sha256=-3k6ffvFyc4zm0oTyVz3ez-o3Lb9bPp2sjwSub_K1AA,242
49
49
  ring/io/xml/abstract.py,sha256=8Q2ebnUYLmuS9HJAQwDVrDTrRfD5z4G5RAB7MW8Oa60,9742
50
- ring/io/xml/from_xml.py,sha256=E7JQl_scL5U4LK6mqLMr5qaiZCc6J1fInxD7uwgNCJY,9356
50
+ ring/io/xml/from_xml.py,sha256=CR3OaBxoDuHK8k5N79XziHwS90lCaaw49UGzQirWiIw,9356
51
51
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
52
52
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
53
53
  ring/io/xml/to_xml.py,sha256=Wo4iySLw9nM-iVW42AGvMRqjtU2qRc2FD_Zlc7w1IrE,3438
54
54
  ring/ml/__init__.py,sha256=nbh48gaswWeY4S4vT1sply_3ROj2DQ7agjoLR4Ho3T8,1517
55
- ring/ml/base.py,sha256=lfwEZLBDglOSRWChUHoH1kezefhttPV9TMEpNIqsMNw,9972
55
+ ring/ml/base.py,sha256=eEpLuXlhFoJ-cpXoSGjctLaYduQhnSVpbv-FEYftNRs,9972
56
56
  ring/ml/callbacks.py,sha256=oCPXl4_Zcw3g0KRgyyUDmdiGxV0phnDVc_t8rEG4Lls,13737
57
- ring/ml/ml_utils.py,sha256=M--qkXRnhU7tHvgfTHfT9gyY0nhj3zMGEaK0X0drFLs,10915
57
+ ring/ml/ml_utils.py,sha256=hu189AnHcmkhkpEPZZ19O0gWz3T-YKpWQW9buqDTMow,10915
58
58
  ring/ml/optimizer.py,sha256=TZF0_LmnewzmGVso-zIQJtpWguUW0fW3HeRpIdG_qoI,4763
59
- ring/ml/ringnet.py,sha256=mef7jyN2QcApJmQGH3HYZyTV-00q8YpsYOKhW0-ku1k,8973
59
+ ring/ml/ringnet.py,sha256=uybMtLdQV0oTFldlbffrqku_l79y3coJYHwxX8P9QZQ,8973
60
60
  ring/ml/rnno_v1.py,sha256=2qE08OIvTJ5PvSxKpYGzGSrvEImWrdAT_qslZ7jP5tA,1372
61
- ring/ml/train.py,sha256=Da89HxiqXC7xuX2ldpTrJStqKWN-6Vcpml4PPQuihN4,10989
61
+ ring/ml/train.py,sha256=_CtQM3w9L01V5yn23lz0aaIPJN5sOWlL9e7G_9__11c,10989
62
62
  ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
63
63
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
64
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
65
65
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
66
66
  ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
67
- ring/rendering/mujoco_render.py,sha256=bSj1_7YL8wZV6cp9oD2CvbkZRSuxVhmPBA2JECxrnUE,8426
68
- ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
67
+ ring/rendering/mujoco_render.py,sha256=HMvZc04I0-lXPBL3hcnBzV2bNiXQAQM7QcHlG_Obmj4,8757
68
+ ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
69
69
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
70
70
  ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
71
71
  ring/sim2real/sim2real.py,sha256=B4nqBBnjGXhM-7PfTyxEq44ZidGNghqaq--qdFILX5A,9675
72
72
  ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E,193
73
73
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
74
- ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
74
+ ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
75
75
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
76
76
  ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
77
77
  ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
78
- ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
78
+ ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
79
79
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
- ring/utils/dataloader.py,sha256=2CcsbUY2AZs8LraS5HTJXlEseuF-1gKmfyBkSsib-tE,3748
81
- ring/utils/dataloader_torch.py,sha256=bravdBqbkxxcQDieg6OnmArGwGcWpMI3MmNFwTCt0qg,3808
82
- ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
83
- ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
80
+ ring/utils/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
81
+ ring/utils/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
82
+ ring/utils/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
83
+ ring/utils/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
84
84
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
85
85
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
86
- ring/utils/utils.py,sha256=tJaWXLGOTwkxJQj2l23dX97wO3aZYhM2qd7eNuMRs84,6907
86
+ ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
87
87
  ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
88
  ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.36.dist-info/METADATA,sha256=a-uW_s0jWJEBX9kW1q36Br4SNXPK7eGIVhlsyKDWruE,4251
90
- imt_ring-1.6.36.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
91
- imt_ring-1.6.36.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
- imt_ring-1.6.36.dist-info/RECORD,,
89
+ imt_ring-1.6.38.dist-info/METADATA,sha256=9rN1VzsIlGU8eyABz9-pTxj0OTCFOZRilEEzkB4gyvg,4251
90
+ imt_ring-1.6.38.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
91
+ imt_ring-1.6.38.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
92
+ imt_ring-1.6.38.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -184,7 +184,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
184
184
 
185
185
  suntay_link_name = _utils_find_suntay_joint(sys)
186
186
 
187
- params = jax.tree_map(
187
+ params = jax.tree.map(
188
188
  lambda arr: arr[sys.idx_map("l")[suntay_link_name]],
189
189
  sys.links.joint_params[name],
190
190
  )
@@ -80,11 +80,11 @@ def generators_eager(
80
80
  # converts also to numpy; but with np.array.flags.writeable = False
81
81
  sample = jax.device_get(sample)
82
82
  # this then sets this flag to True
83
- sample = jax.tree_map(np.array, sample)
83
+ sample = jax.tree.map(np.array, sample)
84
84
 
85
85
  sample_flat, _ = jax.tree_util.tree_flatten(sample)
86
86
  size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
87
- callback([jax.tree_map(lambda a: a[i].copy(), sample) for i in range(size)])
87
+ callback([jax.tree.map(lambda a: a[i].copy(), sample) for i in range(size)])
88
88
 
89
89
  # cleanup
90
90
  del sample, sample_flat
@@ -311,7 +311,7 @@ def _expand_then_flatten(Xy):
311
311
 
312
312
  X, y = _flatten(X), _flatten(y)
313
313
  if not batched:
314
- X, y = jax.tree_map(lambda arr: arr[0], (X, y))
314
+ X, y = jax.tree.map(lambda arr: arr[0], (X, y))
315
315
  return X, y
316
316
 
317
317
 
@@ -86,7 +86,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
86
86
  controller_state: PDControllerState, sys: base.System, state: base.State
87
87
  ) -> jax.Array:
88
88
  taus = jnp.zeros((sys.qd_size()))
89
- q_ref, qd_ref = jax.tree_map(
89
+ q_ref, qd_ref = jax.tree.map(
90
90
  lambda arr: jax.lax.dynamic_index_in_dim(
91
91
  arr, controller_state.i, keepdims=False
92
92
  ),
@@ -4,6 +4,7 @@ import jax
4
4
  import jax.numpy as jnp
5
5
  import jaxopt
6
6
  from jaxopt._src.base import Solver
7
+
7
8
  from ring import algebra
8
9
  from ring import base
9
10
  from ring import maths
@@ -171,7 +172,7 @@ def inverse_kinematics_endeffector(
171
172
 
172
173
  # find result of best q0 initial value
173
174
  best_q_index = jnp.argmin(values)
174
- best_q, best_q_value = jax.tree_map(
175
+ best_q, best_q_value = jax.tree.map(
175
176
  lambda arr: jax.lax.dynamic_index_in_dim(
176
177
  arr, best_q_index, keepdims=False
177
178
  ),
@@ -244,7 +244,7 @@ def imu(
244
244
  measurements["mag"] = magnetometer(xs.rot, magvec)
245
245
 
246
246
  if smoothen_degree is not None:
247
- measurements = jax.tree_map(
247
+ measurements = jax.tree.map(
248
248
  lambda arr: _moving_average(arr, smoothen_degree),
249
249
  measurements,
250
250
  )
@@ -257,7 +257,7 @@ def imu(
257
257
  delay = half_window
258
258
 
259
259
  if delay is not None and delay > 0:
260
- measurements = jax.tree_map(
260
+ measurements = jax.tree.map(
261
261
  lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
262
262
  )
263
263
 
@@ -473,7 +473,7 @@ def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
473
473
  X[name] = {"joint_axes": joint_axes}
474
474
 
475
475
  sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
476
- X = jax.tree_map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
476
+ X = jax.tree.map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
477
477
  return X
478
478
 
479
479
 
@@ -498,12 +498,12 @@ _quasi_physical_sys_str = r"""
498
498
  <x_xy>
499
499
  <options gravity="0 0 0"/>
500
500
  <worldbody>
501
- <body name="IMU" joint="p3d" damping="0.1 0.1 0.1" spring_stiff="3 3 3">
502
- <geom type="box" mass="0.002" dim="0.01 0.01 0.01"/>
501
+ <body name="IMU" joint="free" damping="1 1 1 10 10 10" spring_stiff="20 20 20 500 500 500">
502
+ <geom type="box" mass="1" dim="0.01 0.01 0.01"/>
503
503
  </body>
504
504
  </worldbody>
505
505
  </x_xy>
506
- """
506
+ """ # noqa: E501
507
507
 
508
508
 
509
509
  def _quasi_physical_simulation_beautiful(
@@ -512,12 +512,14 @@ def _quasi_physical_simulation_beautiful(
512
512
  sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)
513
513
 
514
514
  def step_dynamics(state: base.State, x):
515
- state = algorithms.step(sys.replace(link_spring_zeropoint=x.pos), state)
515
+ state = algorithms.step(
516
+ sys.replace(link_spring_zeropoint=jnp.concatenate((x.rot, x.pos))), state
517
+ )
516
518
  return state, state.q
517
519
 
518
- state = base.State.create(sys, q=xs.pos[0])
519
- _, pos = jax.lax.scan(step_dynamics, state, xs)
520
- return xs.replace(pos=pos)
520
+ state = base.State.create(sys, q=jnp.concatenate((xs.rot[0], xs.pos[0])))
521
+ _, qs = jax.lax.scan(step_dynamics, state, xs)
522
+ return xs.replace(rot=qs[:, :4], pos=qs[:, 4:])
521
523
 
522
524
 
523
525
  _constants = {
ring/base.py CHANGED
@@ -807,7 +807,7 @@ class System(_Base):
807
807
  if exists:
808
808
  return cls.from_xml(path, seed=seed)
809
809
  else:
810
- return cls.from_str(path_or_str)
810
+ return cls.from_str(path_or_str, seed=seed)
811
811
 
812
812
  def coordinate_vector_to_q(
813
813
  self,
ring/io/xml/from_xml.py CHANGED
@@ -252,7 +252,7 @@ def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
252
252
 
253
253
  # numpy -> jax
254
254
  # we load using numpy in order to have float64 precision
255
- sys = jax.tree_map(jax.numpy.asarray, sys)
255
+ sys = jax.tree.map(jax.numpy.asarray, sys)
256
256
 
257
257
  sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)
258
258
 
ring/ml/base.py CHANGED
@@ -13,13 +13,13 @@ from ring.utils import pickle_save
13
13
  def _to_3d(tree):
14
14
  if tree is None:
15
15
  return None
16
- return jax.tree_map(lambda arr: arr[None], tree)
16
+ return jax.tree.map(lambda arr: arr[None], tree)
17
17
 
18
18
 
19
19
  def _to_2d(tree, i: int = 0):
20
20
  if tree is None:
21
21
  return None
22
- return jax.tree_map(lambda arr: arr[i], tree)
22
+ return jax.tree.map(lambda arr: arr[i], tree)
23
23
 
24
24
 
25
25
  class AbstractFilter(ABC):
ring/ml/ml_utils.py CHANGED
@@ -161,7 +161,7 @@ def _flatten_convert_filter_nested_dict(
161
161
  metrices: NestedDict, filter_nan_inf: bool = True
162
162
  ):
163
163
  metrices = _flatten_dict(metrices)
164
- metrices = jax.tree_map(_to_float_if_not_string, metrices)
164
+ metrices = jax.tree.map(_to_float_if_not_string, metrices)
165
165
 
166
166
  if not filter_nan_inf:
167
167
  return metrices
@@ -216,7 +216,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
216
216
  from jax.experimental import jax2tf
217
217
  import tensorflow as tf
218
218
 
219
- signature = jax.tree_map(
219
+ signature = jax.tree.map(
220
220
  lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
221
221
  )
222
222
 
@@ -241,7 +241,7 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
241
241
  if validate:
242
242
  output_jax = jax_func(*input)
243
243
  output_tf = tf.saved_model.load(path)(*input)
244
- jax.tree_map(
244
+ jax.tree.map(
245
245
  lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
246
246
  output_jax,
247
247
  output_tf,
ring/ml/ringnet.py CHANGED
@@ -248,7 +248,7 @@ class RING(ml_base.AbstractFilter):
248
248
  params, state = self.forward_lam_factory(lam=lam).init(key, X)
249
249
 
250
250
  if bs is not None:
251
- state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
251
+ state = jax.tree.map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
252
252
 
253
253
  return params, state
254
254
 
ring/ml/train.py CHANGED
@@ -50,7 +50,7 @@ def _build_step_fn(
50
50
  # this vmap maps along batch-axis, not time-axis
51
51
  # time-axis is handled by `metric_fn`
52
52
  pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
53
- error_tree = jax.tree_map(pipe, y, yhat)
53
+ error_tree = jax.tree.map(pipe, y, yhat)
54
54
  return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
55
55
 
56
56
  @partial(
@@ -274,7 +274,7 @@ def _build_eval_fn(
274
274
  ), f"The metric identitifier {metric_name} is not unique"
275
275
 
276
276
  pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
277
- values.update({metric_name: jax.tree_map(pipe, y, yhat)})
277
+ values.update({metric_name: jax.tree.map(pipe, y, yhat)})
278
278
 
279
279
  return values
280
280
 
@@ -10,8 +10,8 @@ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6
10
10
  _skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
11
11
 
12
12
 
13
- def _floor(floor_z: float) -> str:
14
- return f"""<geom name="floor" pos="0 0 {floor_z}" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
13
+ def _floor(z: float, material: str) -> str:
14
+ return f"""<geom name="floor" pos="0 0 {z}" size="0 0 1" type="plane" material="{material}" mass="0"/>""" # noqa: E501
15
15
 
16
16
 
17
17
  def _build_model_of_geoms(
@@ -19,7 +19,7 @@ def _build_model_of_geoms(
19
19
  cameras: dict[int, Sequence[str]],
20
20
  lights: dict[int, Sequence[str]],
21
21
  floor: bool,
22
- floor_z: float,
22
+ floor_kwargs: dict,
23
23
  stars: bool,
24
24
  debug: bool,
25
25
  ) -> mujoco.MjModel:
@@ -77,10 +77,13 @@ def _build_model_of_geoms(
77
77
  xml_str = f""" # noqa: E501
78
78
  <mujoco>
79
79
  <asset>
80
- <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
80
+ <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".3 .3 .3"/>
81
81
  <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
82
82
  <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
83
83
  <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="2 2" reflectance="0.2"/>
84
+ <material name="beige" rgba="0.76 0.80 0.50 1.0" specular="0.3" shininess="0.1" />
85
+ <material name="white" rgba="0.9 0.9 0.9 1.0" reflectance="0"/>
86
+ <material name="gray" rgba="0.4 0.5 0.5 1.0" reflectance="0.25"/>
84
87
  {_skybox if stars else ''}
85
88
  <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
86
89
  <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
@@ -98,7 +101,7 @@ def _build_model_of_geoms(
98
101
  <camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
99
102
  <camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
100
103
  <camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
101
- {_floor(floor_z) if floor else ''}
104
+ {_floor(**floor_kwargs) if floor else ''}
102
105
  {inside_worldbody_cameras}
103
106
  {inside_worldbody_lights}
104
107
  {inside_worldbody}
@@ -176,6 +179,7 @@ class MujocoScene:
176
179
  show_stars: bool = True,
177
180
  show_floor: bool = True,
178
181
  floor_z: float = -0.84,
182
+ floor_material: str = "matplane",
179
183
  debug: bool = False,
180
184
  ) -> None:
181
185
  self.debug = debug
@@ -190,7 +194,7 @@ class MujocoScene:
190
194
  self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
191
195
  self.show_stars = show_stars
192
196
  self.show_floor = show_floor
193
- self.floor_z = floor_z
197
+ self.floor_kwargs = dict(z=floor_z, material=floor_material)
194
198
 
195
199
  def init(self, geoms: list[base.Geometry]):
196
200
  self._parent_ids = list(set([geom.link_idx for geom in geoms]))
@@ -199,7 +203,7 @@ class MujocoScene:
199
203
  self.add_cameras,
200
204
  self.add_lights,
201
205
  floor=self.show_floor,
202
- floor_z=self.floor_z,
206
+ floor_kwargs=self.floor_kwargs,
203
207
  stars=self.show_stars,
204
208
  debug=self.debug,
205
209
  )
@@ -7,14 +7,15 @@ from typing import Optional, TypeVar
7
7
  import jax
8
8
  import jax.numpy as jnp
9
9
  import numpy as np
10
- from ring import algebra
11
- from ring import base
12
- from ring import maths
13
10
  from tree_utils import PyTree
14
11
  from tree_utils import tree_batch
15
12
  from vispy import scene
16
13
  from vispy.scene import MatrixTransform
17
14
 
15
+ from ring import algebra
16
+ from ring import base
17
+ from ring import maths
18
+
18
19
  from . import vispy_visuals
19
20
 
20
21
  Camera = TypeVar("Camera")
@@ -192,7 +193,7 @@ class Scene(ABC):
192
193
 
193
194
  # step 3: update visuals
194
195
  for i, (visual, geom) in enumerate(zip(self.visuals, self.geoms)):
195
- t = jax.tree_map(lambda arr: arr[i], transform_per_visual)
196
+ t = jax.tree.map(lambda arr: arr[i], transform_per_visual)
196
197
  if self._fresh_init:
197
198
  self._init_visual(visual, t, geom)
198
199
  else:
@@ -2,12 +2,13 @@ from typing import Optional
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- from ring import base
6
5
  from tree_utils import tree_batch
7
6
 
7
+ from ring import base
8
+
8
9
 
9
10
  def _tree_nan_like(tree, repeats: int):
10
- return jax.tree_map(
11
+ return jax.tree.map(
11
12
  lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
12
13
  )
13
14
 
ring/utils/batchsize.py CHANGED
@@ -39,19 +39,19 @@ def merge_batchsize(
39
39
  tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
40
40
  ) -> PyTree:
41
41
  if third_dim_also:
42
- return jax.tree_map(
42
+ return jax.tree.map(
43
43
  lambda arr: arr.reshape(
44
44
  (pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
45
45
  ),
46
46
  tree,
47
47
  )
48
- return jax.tree_map(
48
+ return jax.tree.map(
49
49
  lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
50
50
  )
51
51
 
52
52
 
53
53
  def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
54
- return jax.tree_map(
54
+ return jax.tree.map(
55
55
  lambda arr: arr.reshape(
56
56
  (
57
57
  pmap_size,
ring/utils/dataloader.py CHANGED
@@ -4,14 +4,15 @@ from typing import Callable, Optional
4
4
 
5
5
  import jax
6
6
  import numpy as np
7
- from ring.utils import parse_path
8
- from ring.utils import pickle_load
9
7
  import torch
10
8
  from torch.utils.data import DataLoader
11
9
  from torch.utils.data import Dataset
12
10
  import tqdm
13
11
  from tree_utils import PyTree
14
12
 
13
+ from ring.utils import parse_path
14
+ from ring.utils import pickle_load
15
+
15
16
 
16
17
  def make_generator(
17
18
  *paths,
@@ -103,7 +104,7 @@ def pytorch_generator(
103
104
  dl_iter = iter(dl)
104
105
 
105
106
  def to_numpy(tree: PyTree[torch.Tensor]):
106
- return jax.tree_map(lambda tensor: tensor.numpy(), tree)
107
+ return jax.tree.map(lambda tensor: tensor.numpy(), tree)
107
108
 
108
109
  def generator(_):
109
110
  nonlocal dl, dl_iter
@@ -1,16 +1,25 @@
1
1
  import os
2
+ import pickle
2
3
  from typing import Any, Optional
3
4
  import warnings
4
5
 
5
- import jax
6
6
  import numpy as np
7
7
  import torch
8
8
  from torch.utils.data import DataLoader
9
9
  from torch.utils.data import Dataset
10
+ import tree
10
11
  from tree_utils import PyTree
11
12
 
12
- from ring.utils import parse_path
13
- from ring.utils import pickle_load
13
+ from ring.utils.path import parse_path
14
+
15
+
16
+ def pickle_load(
17
+ path,
18
+ ):
19
+ path = parse_path(path, extension="pickle", require_is_file=True)
20
+ with open(path, "rb") as file:
21
+ obj = pickle.load(file)
22
+ return obj
14
23
 
15
24
 
16
25
  class FolderOfFilesDataset(Dataset):
@@ -60,8 +69,8 @@ def dataset_to_generator(
60
69
  )
61
70
  dl_iter = iter(dl)
62
71
 
63
- def to_numpy(tree: PyTree[torch.Tensor]):
64
- return jax.tree_map(lambda tensor: tensor.numpy(), tree)
72
+ def to_numpy(data: PyTree[torch.Tensor]):
73
+ return tree.map_structure(lambda tensor: tensor.numpy(), data)
65
74
 
66
75
  def generator(_):
67
76
  nonlocal dl, dl_iter
ring/utils/hdf5.py CHANGED
@@ -121,7 +121,7 @@ def _parse_path(
121
121
 
122
122
  def _tree_concat(trees: list):
123
123
  # otherwise scalar-arrays will lead to indexing error
124
- trees = jax.tree_map(lambda arr: np.atleast_1d(arr), trees)
124
+ trees = jax.tree.map(lambda arr: np.atleast_1d(arr), trees)
125
125
 
126
126
  if len(trees) == 0:
127
127
  return trees
ring/utils/normalizer.py CHANGED
@@ -3,9 +3,10 @@ from typing import Callable, TypeVar
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
- from ring.algorithms.generator import types
7
6
  import tree_utils
8
7
 
8
+ from ring.algorithms.generator import types
9
+
9
10
  KEY = jax.random.PRNGKey(777)
10
11
  KEY_PERMUTATION = jax.random.PRNGKey(888)
11
12
 
@@ -37,12 +38,12 @@ def make_normalizer_from_generator(
37
38
  # permute 0-th axis, since batchsize of generator might be larger than
38
39
  # `approx_with_large_batchsize`, then we would not get a representative
39
40
  # subsample otherwise
40
- Xs = jax.tree_map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
41
+ Xs = jax.tree.map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
41
42
  Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)
42
43
 
43
44
  # obtain statistics
44
- mean = jax.tree_map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
45
- std = jax.tree_map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
45
+ mean = jax.tree.map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
46
+ std = jax.tree.map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)
46
47
 
47
48
  if verbose:
48
49
  print("Mean: ", mean)
@@ -51,6 +52,6 @@ def make_normalizer_from_generator(
51
52
  eps = 1e-8
52
53
 
53
54
  def normalizer(X):
54
- return jax.tree_map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
55
+ return jax.tree.map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)
55
56
 
56
57
  return normalizer
ring/utils/utils.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from importlib import import_module as _import_module
2
2
  import io
3
+ from pathlib import Path
3
4
  import pickle
4
5
  import random
5
6
  from typing import Optional
@@ -152,13 +153,28 @@ def import_lib(
152
153
 
153
154
  def pickle_save(obj, path, overwrite: bool = False):
154
155
  path = parse_path(path, extension="pickle", file_exists_ok=overwrite)
155
- with open(path, "wb") as file:
156
- pickle.dump(obj, file, protocol=5)
156
+ try:
157
+ with open(path, "wb") as file:
158
+ pickle.dump(obj, file, protocol=5)
159
+ except OSError as e:
160
+ print(
161
+ f"saving with `pickle` throws exception {e}. "
162
+ + "Attempting to save using `joblib`"
163
+ )
164
+ path = parse_path(path, extension="joblib", file_exists_ok=overwrite)
165
+ import joblib
166
+
167
+ joblib.dump(obj, path)
157
168
 
158
169
 
159
170
  def pickle_load(
160
171
  path,
161
172
  ):
173
+ if Path(path).suffix == ".joblib":
174
+ import joblib
175
+
176
+ return joblib.load(path)
177
+
162
178
  path = parse_path(path, extension="pickle", require_is_file=True)
163
179
  with open(path, "rb") as file:
164
180
  obj = pickle.load(file)