imt-ring 1.3.1__py3-none-any.whl → 1.3.3__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.1.dist-info → imt_ring-1.3.3.dist-info}/METADATA +1 -1
- {imt_ring-1.3.1.dist-info → imt_ring-1.3.3.dist-info}/RECORD +9 -9
- ring/algorithms/custom_joints/__init__.py +2 -0
- ring/algorithms/custom_joints/suntay.py +73 -0
- ring/base.py +10 -4
- ring/ml/__init__.py +2 -1
- ring/ml/optimizer.py +2 -2
- {imt_ring-1.3.1.dist-info → imt_ring-1.3.3.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.1.dist-info → imt_ring-1.3.3.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=ZAoe9B1HbAX9NYiKaisssTBn-1VBXoJTsWgFAvlQoZw,33705
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
|
@@ -9,10 +9,10 @@ ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,1
|
|
9
9
|
ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
|
12
|
-
ring/algorithms/custom_joints/__init__.py,sha256=
|
12
|
+
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=0Oym3KQssj3QDOldnz9PTy5jPg9ZLk85mMK2YX1qvB4,15600
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
17
|
ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
|
18
18
|
ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
|
@@ -50,11 +50,11 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
|
|
50
50
|
ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
|
51
51
|
ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
|
52
52
|
ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
53
|
-
ring/ml/__init__.py,sha256
|
53
|
+
ring/ml/__init__.py,sha256=-bryExVoKJYSF_G_KYc5hI_GciIhj2xZ8WGi6TdRghw,1836
|
54
54
|
ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
|
55
55
|
ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
|
-
ring/ml/optimizer.py,sha256=
|
57
|
+
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
|
59
59
|
ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
|
60
60
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
@@ -77,7 +77,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
77
77
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
78
78
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
79
79
|
ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
|
80
|
-
imt_ring-1.3.
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
80
|
+
imt_ring-1.3.3.dist-info/METADATA,sha256=nTihurycKYmLCI61Cojd7VLrnb1gpd-H8nwUupysaC8,3104
|
81
|
+
imt_ring-1.3.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
82
|
+
imt_ring-1.3.3.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
83
|
+
imt_ring-1.3.3.dist-info/RECORD,,
|
@@ -1,6 +1,8 @@
|
|
1
1
|
from .rr_imp_joint import register_rr_imp_joint
|
2
2
|
from .rr_joint import register_rr_joint
|
3
|
+
from .suntay import ConstantValue_DrawnFnPair
|
3
4
|
from .suntay import GP_DrawFnPair
|
4
5
|
from .suntay import MLP_DrawnFnPair
|
6
|
+
from .suntay import Polynomial_DrawnFnPair
|
5
7
|
from .suntay import register_suntay
|
6
8
|
from .suntay import SuntayConfig
|
@@ -293,6 +293,79 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
293
293
|
ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
|
294
294
|
|
295
295
|
|
296
|
+
def Polynomial_DrawnFnPair(
|
297
|
+
order: int = 2,
|
298
|
+
val: float = 2.0,
|
299
|
+
center: bool = False,
|
300
|
+
flexion_center: Optional[float] = None,
|
301
|
+
) -> DrawnFnPairFactory:
|
302
|
+
assert val >= 0.0
|
303
|
+
|
304
|
+
# because 0-th order is also counted
|
305
|
+
order += 1
|
306
|
+
|
307
|
+
def factory(xs, mn, mx):
|
308
|
+
nonlocal flexion_center
|
309
|
+
|
310
|
+
flexion_mn = jnp.min(xs)
|
311
|
+
flexion_mx = jnp.max(xs)
|
312
|
+
|
313
|
+
def _apply_poly_factors(poly_factors, q):
|
314
|
+
return poly_factors @ jnp.power(q, jnp.arange(order))
|
315
|
+
|
316
|
+
if flexion_center is None:
|
317
|
+
flexion_center = (flexion_mn + flexion_mx) / 2
|
318
|
+
else:
|
319
|
+
flexion_center = jnp.array(flexion_center)
|
320
|
+
|
321
|
+
def init(key):
|
322
|
+
c1, c2 = jax.random.split(key)
|
323
|
+
poly_factors = jax.random.uniform(
|
324
|
+
c1, shape=(order,), minval=-val, maxval=val
|
325
|
+
)
|
326
|
+
q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
|
327
|
+
values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
|
328
|
+
poly_factors, xs - q0
|
329
|
+
)
|
330
|
+
eps = 1e-6
|
331
|
+
amin, amax = jnp.min(values), jnp.max(values) + eps
|
332
|
+
return amin, amax, poly_factors, q0
|
333
|
+
|
334
|
+
def _apply(params, q):
|
335
|
+
amin, amax, poly_factors, q0 = params
|
336
|
+
q = q - q0
|
337
|
+
value = _apply_poly_factors(poly_factors, q)
|
338
|
+
return restrict(value, mn, mx, amin, amax)
|
339
|
+
|
340
|
+
if center:
|
341
|
+
|
342
|
+
def apply(params, q):
|
343
|
+
return _apply(params, q) - _apply(params, flexion_center)
|
344
|
+
|
345
|
+
else:
|
346
|
+
apply = _apply
|
347
|
+
|
348
|
+
return DrawnFnPair(init, apply)
|
349
|
+
|
350
|
+
return factory
|
351
|
+
|
352
|
+
|
353
|
+
def ConstantValue_DrawnFnPair(value: float) -> DrawnFnPairFactory:
|
354
|
+
value = jnp.array(value)
|
355
|
+
|
356
|
+
def factory(xs, mn, mx):
|
357
|
+
|
358
|
+
def init(key):
|
359
|
+
return {}
|
360
|
+
|
361
|
+
def apply(params, q):
|
362
|
+
return value
|
363
|
+
|
364
|
+
return DrawnFnPair(init, apply)
|
365
|
+
|
366
|
+
return factory
|
367
|
+
|
368
|
+
|
296
369
|
def MLP_DrawnFnPair(
|
297
370
|
center: bool = False, flexion_center: Optional[float] = None
|
298
371
|
) -> DrawnFnPairFactory:
|
ring/base.py
CHANGED
@@ -929,16 +929,22 @@ def _parse_system_calculate_inertia(sys: System):
|
|
929
929
|
def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = False):
|
930
930
|
assert len(args) == len(in_types)
|
931
931
|
for in_type, arg in zip(in_types, args):
|
932
|
-
|
932
|
+
|
933
933
|
if in_type == "l":
|
934
|
-
|
934
|
+
required_length = sys.num_links()
|
935
935
|
elif in_type == "q":
|
936
|
-
|
936
|
+
required_length = sys.q_size()
|
937
937
|
elif in_type == "d":
|
938
|
-
|
938
|
+
required_length = sys.qd_size()
|
939
939
|
else:
|
940
940
|
raise Exception("`in_types` must be one of `l` or `q` or `d`")
|
941
941
|
|
942
|
+
B = len(arg)
|
943
|
+
B_re = required_length
|
944
|
+
assert (
|
945
|
+
B == B_re
|
946
|
+
), f"arg={arg} has a length of B={B} which isn't the required length={B_re}"
|
947
|
+
|
942
948
|
order = range(sys.num_links())
|
943
949
|
q_idx, qd_idx = 0, 0
|
944
950
|
q_idxs, qd_idxs = {}, {}
|
ring/ml/__init__.py
CHANGED
@@ -41,6 +41,7 @@ def RNNO(
|
|
41
41
|
return_quats: bool = False,
|
42
42
|
params=None,
|
43
43
|
eval: bool = True,
|
44
|
+
samp_freq: float | None = None,
|
44
45
|
**kwargs,
|
45
46
|
):
|
46
47
|
assert "message_dim" not in kwargs
|
@@ -57,7 +58,7 @@ def RNNO(
|
|
57
58
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
58
59
|
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
59
60
|
if eval and return_quats:
|
60
|
-
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=
|
61
|
+
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=samp_freq)
|
61
62
|
if return_quats:
|
62
63
|
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
63
64
|
return ringnet
|
ring/ml/optimizer.py
CHANGED
@@ -14,10 +14,10 @@ from optax._src.transform import AddNoiseState
|
|
14
14
|
def make_optimizer(
|
15
15
|
lr: float,
|
16
16
|
n_episodes: int,
|
17
|
-
n_steps_per_episode: int,
|
17
|
+
n_steps_per_episode: int = 6,
|
18
18
|
adap_clip: Optional[float] = 0.1,
|
19
19
|
glob_clip: Optional[float] = 0.2,
|
20
|
-
skip_large_update_max_normsq: float =
|
20
|
+
skip_large_update_max_normsq: float = 100.0,
|
21
21
|
skip_large_update_warmup: int = 300,
|
22
22
|
inner_opt=optax.lamb,
|
23
23
|
cos_decay_twice: bool = False,
|
File without changes
|
File without changes
|