imt-ring 1.3.8__py3-none-any.whl → 1.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.8.dist-info → imt_ring-1.3.9.dist-info}/METADATA +1 -1
- {imt_ring-1.3.8.dist-info → imt_ring-1.3.9.dist-info}/RECORD +11 -11
- ring/algorithms/dynamics.py +11 -5
- ring/algorithms/generator/base.py +11 -13
- ring/algorithms/generator/batch.py +5 -1
- ring/algorithms/generator/motion_artifacts.py +6 -4
- ring/algorithms/generator/pd_control.py +2 -1
- ring/base.py +1 -3
- ring/ml/rnno_v1.py +6 -2
- {imt_ring-1.3.8.dist-info → imt_ring-1.3.9.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.8.dist-info → imt_ring-1.3.9.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,11 @@
|
|
1
1
|
ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
|
7
7
|
ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
|
8
|
-
ring/algorithms/dynamics.py,sha256=
|
8
|
+
ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
|
9
9
|
ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
|
@@ -14,10 +14,10 @@ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSe
|
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
15
|
ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
19
|
-
ring/algorithms/generator/motion_artifacts.py,sha256=
|
20
|
-
ring/algorithms/generator/pd_control.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=sr-YZkjd8pZJAI5vFG_IqOO4AEeiEYtXr8uUsPMS6Q4,14779
|
18
|
+
ring/algorithms/generator/batch.py,sha256=kNlq78W-nAtbp6Xe82UjbPY-rXX2alGLxTokTITSbAc,9226
|
19
|
+
ring/algorithms/generator/motion_artifacts.py,sha256=_kiAl1VHoX1fW5AUlXOtPBWyHIIFof_M78AP-m9f1ME,8790
|
20
|
+
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
21
21
|
ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
22
22
|
ring/algorithms/generator/transforms.py,sha256=nvNDvM20tEw9Zd0ra0TxA25uf01L40Y2UKvtQmOrlGo,12782
|
23
23
|
ring/algorithms/generator/types.py,sha256=CAhvDK5qiHnrGtkCVccB07doiz_D6lHJ35B7sW0pyZA,1110
|
@@ -56,7 +56,7 @@ ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
|
59
|
-
ring/ml/rnno_v1.py,sha256=
|
59
|
+
ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
|
60
60
|
ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
|
61
61
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
62
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
@@ -78,7 +78,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
78
78
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
79
79
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
80
80
|
ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
84
|
-
imt_ring-1.3.
|
81
|
+
imt_ring-1.3.9.dist-info/METADATA,sha256=H65-QICwM4mtRPumYJbrenN74nmiMBGbeV3pecKEeOg,3104
|
82
|
+
imt_ring-1.3.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
83
|
+
imt_ring-1.3.9.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
84
|
+
imt_ring-1.3.9.dist-info/RECORD,,
|
ring/algorithms/dynamics.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
from typing import Optional, Tuple
|
2
|
+
import warnings
|
2
3
|
|
3
4
|
import jax
|
4
5
|
import jax.numpy as jnp
|
6
|
+
|
5
7
|
from ring import algebra
|
6
8
|
from ring import base
|
7
9
|
from ring import maths
|
@@ -213,7 +215,7 @@ def forward_dynamics(
|
|
213
215
|
q: jax.Array,
|
214
216
|
qd: jax.Array,
|
215
217
|
tau: jax.Array,
|
216
|
-
mass_mat_inv: jax.Array,
|
218
|
+
# mass_mat_inv: jax.Array,
|
217
219
|
) -> Tuple[jax.Array, jax.Array]:
|
218
220
|
C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
|
219
221
|
mass_matrix = compute_mass_matrix(sys)
|
@@ -235,6 +237,11 @@ def forward_dynamics(
|
|
235
237
|
|
236
238
|
mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
|
237
239
|
else:
|
240
|
+
warnings.warn(
|
241
|
+
f"You are using `sys.mass_mat_iters`={sys.mass_mat_iters} which is >0. "
|
242
|
+
"This feature is currently not fully supported. See the local TODO."
|
243
|
+
)
|
244
|
+
mass_mat_inv = jnp.diag(jnp.ones((sys.qd_size(),)))
|
238
245
|
mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)
|
239
246
|
|
240
247
|
return mass_mat_inv @ qf_smooth, mass_mat_inv
|
@@ -254,9 +261,8 @@ def _strapdown_integration(
|
|
254
261
|
def _semi_implicit_euler_integration(
|
255
262
|
sys: base.System, state: base.State, taus: jax.Array
|
256
263
|
) -> base.State:
|
257
|
-
qdd, mass_mat_inv = forward_dynamics(
|
258
|
-
|
259
|
-
)
|
264
|
+
qdd, mass_mat_inv = forward_dynamics(sys, state.q, state.qd, taus)
|
265
|
+
del mass_mat_inv
|
260
266
|
qd_next = state.qd + sys.dt * qdd
|
261
267
|
|
262
268
|
q_next = []
|
@@ -277,7 +283,7 @@ def _semi_implicit_euler_integration(
|
|
277
283
|
sys.scan(q_integrate, "qdl", state.q, qd_next, sys.link_types)
|
278
284
|
q_next = jnp.concatenate(q_next)
|
279
285
|
|
280
|
-
state = state.replace(q=q_next, qd=qd_next
|
286
|
+
state = state.replace(q=q_next, qd=qd_next)
|
281
287
|
return state
|
282
288
|
|
283
289
|
|
@@ -4,6 +4,7 @@ import warnings
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
import tqdm
|
7
8
|
import tree_utils
|
8
9
|
|
9
10
|
from ring import base
|
@@ -83,10 +84,14 @@ class RCMG:
|
|
83
84
|
), "If `randomize_anchors`, then only one system is expected"
|
84
85
|
sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
|
85
86
|
|
86
|
-
zip_sys_config = False
|
87
87
|
if randomize_hz:
|
88
|
-
zip_sys_config = True
|
89
88
|
sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
|
89
|
+
else:
|
90
|
+
# create zip
|
91
|
+
N_sys = len(sys)
|
92
|
+
sys = sum([len(config) * [s] for s in sys], start=[])
|
93
|
+
config = N_sys * config
|
94
|
+
assert len(sys) == len(config)
|
90
95
|
|
91
96
|
if sys_ml is None:
|
92
97
|
# TODO
|
@@ -97,17 +102,10 @@ class RCMG:
|
|
97
102
|
sys_ml = sys[0]
|
98
103
|
|
99
104
|
self.gens = []
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
)
|
105
|
-
else:
|
106
|
-
for _sys in sys:
|
107
|
-
for _config in config:
|
108
|
-
self.gens.append(
|
109
|
-
partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
|
110
|
-
)
|
105
|
+
for _sys, _config in tqdm.tqdm(
|
106
|
+
zip(sys, config), desc="building generators", total=len(sys)
|
107
|
+
):
|
108
|
+
self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
|
111
109
|
|
112
110
|
def _to_data(self, sizes, seed):
|
113
111
|
return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
|
@@ -86,7 +86,11 @@ def batch_generators_eager_to_list(
|
|
86
86
|
|
87
87
|
key = jax.random.PRNGKey(seed)
|
88
88
|
data = []
|
89
|
-
for gen, size in tqdm(
|
89
|
+
for gen, size in tqdm(
|
90
|
+
zip(generators, sizes),
|
91
|
+
desc="executing generators",
|
92
|
+
total=len(sizes),
|
93
|
+
):
|
90
94
|
|
91
95
|
n_calls = _number_of_executions_required(size)
|
92
96
|
# decrease size by n_calls times
|
@@ -49,6 +49,7 @@ def inject_subsystems(
|
|
49
49
|
rotational_damp: float = 0.1,
|
50
50
|
translational_stif: float = 50.0,
|
51
51
|
translational_damp: float = 0.1,
|
52
|
+
disable_warning: bool = False,
|
52
53
|
**kwargs,
|
53
54
|
) -> base.System:
|
54
55
|
imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
|
@@ -92,10 +93,11 @@ def inject_subsystems(
|
|
92
93
|
# TODO set all joint_params to zeros; they can not be preserved anyways and
|
93
94
|
# otherwise many warnings will be rose
|
94
95
|
# instead warn explicitly once now and move on
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
96
|
+
if not disable_warning:
|
97
|
+
warnings.warn(
|
98
|
+
"`sys.links.joint_params` has been set to zero, this might lead to "
|
99
|
+
"unexpected behaviour unless you use `randomize_joint_params`"
|
100
|
+
)
|
99
101
|
joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
|
100
102
|
sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))
|
101
103
|
|
@@ -4,6 +4,7 @@ from typing import Optional
|
|
4
4
|
from flax import struct
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
|
7
8
|
from ring import base
|
8
9
|
from ring.algorithms import dynamics
|
9
10
|
from ring.algorithms import jcalc
|
@@ -49,7 +50,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
|
|
49
50
|
assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
|
50
51
|
assert sys.qd_size() == P.size
|
51
52
|
if D is not None:
|
52
|
-
sys.qd_size() == D.size
|
53
|
+
assert sys.qd_size() == D.size
|
53
54
|
|
54
55
|
q_ref_as_dict = {}
|
55
56
|
qd_ref_as_dict = {}
|
ring/base.py
CHANGED
@@ -997,13 +997,11 @@ class State(_Base):
|
|
997
997
|
q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
|
998
998
|
qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
|
999
999
|
x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
|
1000
|
-
mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
|
1001
1000
|
"""
|
1002
1001
|
|
1003
1002
|
q: jax.Array
|
1004
1003
|
qd: jax.Array
|
1005
1004
|
x: Transform
|
1006
|
-
mass_mat_inv: jax.Array
|
1007
1005
|
|
1008
1006
|
@classmethod
|
1009
1007
|
def create(
|
@@ -1057,4 +1055,4 @@ class State(_Base):
|
|
1057
1055
|
if x is None:
|
1058
1056
|
x = Transform.zero((sys.num_links(),))
|
1059
1057
|
|
1060
|
-
return cls(q, qd, x
|
1058
|
+
return cls(q, qd, x)
|
ring/ml/rnno_v1.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Sequence
|
1
|
+
from typing import Optional, Sequence
|
2
2
|
|
3
3
|
import haiku as hk
|
4
4
|
import jax
|
@@ -12,14 +12,18 @@ def rnno_v1_forward_factory(
|
|
12
12
|
layernorm: bool = True,
|
13
13
|
act_fn_linear=jax.nn.relu,
|
14
14
|
act_fn_rnn=jax.nn.elu,
|
15
|
+
lam: Optional[tuple[int]] = None,
|
15
16
|
):
|
17
|
+
# unused
|
18
|
+
del lam
|
19
|
+
|
16
20
|
@hk.without_apply_rng
|
17
21
|
@hk.transform_with_state
|
18
22
|
def forward_fn(X):
|
19
23
|
assert X.shape[-2] == 1
|
20
24
|
|
21
25
|
for i, n_units in enumerate(rnn_layers):
|
22
|
-
state = hk.get_state(f"rnn_{i}", shape=[n_units], init=jnp.zeros)
|
26
|
+
state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
|
23
27
|
X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
|
24
28
|
hk.set_state(f"rnn_{i}", state)
|
25
29
|
|
File without changes
|
File without changes
|