imt-ring 1.3.5__py3-none-any.whl → 1.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.5.dist-info → imt_ring-1.3.7.dist-info}/METADATA +1 -1
- {imt_ring-1.3.5.dist-info → imt_ring-1.3.7.dist-info}/RECORD +10 -10
- ring/algorithms/custom_joints/suntay.py +16 -1
- ring/algorithms/generator/base.py +4 -1
- ring/algorithms/sensors.py +3 -6
- ring/base.py +2 -1
- ring/ml/train.py +8 -6
- ring/utils/utils.py +2 -0
- {imt_ring-1.3.5.dist-info → imt_ring-1.3.7.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.5.dist-info → imt_ring-1.3.7.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=99GuspRH4QtRRJTAgyvS02FFxoaBptSsz_GPczX8vw0,33947
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
|
@@ -8,13 +8,13 @@ ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13
|
|
8
8
|
ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,10629
|
9
9
|
ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
|
-
ring/algorithms/sensors.py,sha256=
|
11
|
+
ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=AKH7GXEmRGV1kK8okiqa12uq0Ah9VYlqgdLw-99oFoQ,14840
|
18
18
|
ring/algorithms/generator/batch.py,sha256=EOCX0vOxDwVOweArGsUneeeYysdSY2mFB55W052Wd9o,9161
|
19
19
|
ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
|
20
20
|
ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
|
@@ -56,7 +56,7 @@ ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
|
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
|
59
|
-
ring/ml/train.py,sha256=
|
59
|
+
ring/ml/train.py,sha256=uDW6JMdbMcjUKr3wCL2drWzDUd0Pc3BoroUwLcYoUx4,10914
|
60
60
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
61
61
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
62
62
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
@@ -76,8 +76,8 @@ ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
|
76
76
|
ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
77
77
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
78
78
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
79
|
-
ring/utils/utils.py,sha256=
|
80
|
-
imt_ring-1.3.
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
79
|
+
ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
|
80
|
+
imt_ring-1.3.7.dist-info/METADATA,sha256=V6Oow_ZZwpBuHuIbyPIoKFtrhFboxMmuIPx1Rilq3-A,3104
|
81
|
+
imt_ring-1.3.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
82
|
+
imt_ring-1.3.7.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
83
|
+
imt_ring-1.3.7.dist-info/RECORD,,
|
@@ -293,12 +293,25 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
293
293
|
ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
|
294
294
|
|
295
295
|
|
296
|
+
def _scale_delta(method: str, key, xs, mn, mx, amin, amax, **kwargs):
|
297
|
+
if method == "normal":
|
298
|
+
delta = jnp.clip(jax.random.normal(key) + 0.5, 1.0)
|
299
|
+
elif method == "uniform":
|
300
|
+
delta = 1 / (jax.random.uniform(key) + 1e-2)
|
301
|
+
else:
|
302
|
+
raise NotImplementedError
|
303
|
+
|
304
|
+
return delta
|
305
|
+
|
306
|
+
|
296
307
|
def Polynomial_DrawnFnPair(
|
297
308
|
order: int = 2,
|
298
309
|
center: bool = False,
|
299
310
|
flexion_center_deg: Optional[float] = None,
|
300
311
|
include_bias: bool = True,
|
301
312
|
enable_scale_delta: bool = True,
|
313
|
+
scale_delta_method: str = "normal",
|
314
|
+
scale_delta_kwargs: dict = dict(),
|
302
315
|
) -> DrawnFnPairFactory:
|
303
316
|
assert not (order == 0 and not include_bias)
|
304
317
|
|
@@ -343,7 +356,9 @@ def Polynomial_DrawnFnPair(
|
|
343
356
|
amin, amax = jnp.min(values), jnp.max(values) + eps
|
344
357
|
if enable_scale_delta:
|
345
358
|
delta = amax - amin
|
346
|
-
scale_delta =
|
359
|
+
scale_delta = _scale_delta(
|
360
|
+
scale_delta_method, c3, xs, mn, mx, amin, amax, **scale_delta_kwargs
|
361
|
+
)
|
347
362
|
amax = amin + delta * scale_delta
|
348
363
|
return amin, amax, poly_factors, q0
|
349
364
|
|
@@ -140,8 +140,11 @@ class RCMG:
|
|
140
140
|
batchsize: int = 1,
|
141
141
|
sizes: int | list[int] = 1,
|
142
142
|
seed: int = 1,
|
143
|
+
shuffle: bool = True,
|
143
144
|
) -> types.BatchedGenerator:
|
144
|
-
return batch.batch_generators_eager(
|
145
|
+
return batch.batch_generators_eager(
|
146
|
+
self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
|
147
|
+
)
|
145
148
|
|
146
149
|
def to_lazy_gen(
|
147
150
|
self, sizes: int | list[int] = 1, jit: bool = True
|
ring/algorithms/sensors.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Optional
|
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
|
+
|
6
7
|
from ring import algebra
|
7
8
|
from ring import algorithms
|
8
9
|
from ring import base
|
@@ -445,22 +446,18 @@ def _joint_axes_from_xs(sys, xs, sys_xs):
|
|
445
446
|
|
446
447
|
def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
|
447
448
|
"`sys` should be `sys_noimu`. `N` is number of timesteps"
|
448
|
-
xaxis = jnp.array([1.0, 0, 0])
|
449
|
-
yaxis = jnp.array([0.0, 1, 0])
|
450
|
-
zaxis = jnp.array([0.0, 0, 1])
|
451
|
-
id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
|
452
449
|
X = {}
|
453
450
|
|
454
451
|
def f(_, __, name, link_type, link):
|
455
452
|
joint_params = link.joint_params
|
456
453
|
if link_type in ["rx", "ry", "rz"]:
|
457
|
-
joint_axes =
|
454
|
+
joint_axes = maths.unit_vectors(link_type[1])
|
458
455
|
elif link_type == "rr":
|
459
456
|
joint_axes = joint_params["rr"]["joint_axes"]
|
460
457
|
elif link_type[:6] == "rr_imp":
|
461
458
|
joint_axes = joint_params[link_type]["joint_axes"]
|
462
459
|
else:
|
463
|
-
joint_axes =
|
460
|
+
joint_axes = maths.x_unit_vector
|
464
461
|
X[name] = {"joint_axes": joint_axes}
|
465
462
|
|
466
463
|
sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
|
ring/base.py
CHANGED
@@ -571,6 +571,7 @@ class System(_Base):
|
|
571
571
|
new_damp: Optional[jax.Array] = None,
|
572
572
|
new_stif: Optional[jax.Array] = None,
|
573
573
|
new_zero: Optional[jax.Array] = None,
|
574
|
+
seed: int = 1,
|
574
575
|
):
|
575
576
|
"By default damping, stiffness are set to zero."
|
576
577
|
from ring.algorithms import get_joint_model
|
@@ -600,7 +601,7 @@ class System(_Base):
|
|
600
601
|
|
601
602
|
jm = get_joint_model(new_joint_type)
|
602
603
|
if jm.init_joint_params is not None:
|
603
|
-
sys = sys.from_str(sys.to_str())
|
604
|
+
sys = sys.from_str(sys.to_str(), seed=seed)
|
604
605
|
|
605
606
|
return sys
|
606
607
|
|
ring/ml/train.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import optax
|
8
|
+
import tree_utils
|
9
|
+
|
8
10
|
from ring import maths
|
9
11
|
from ring.algorithms.generator import types
|
10
12
|
from ring.ml import base as ml_base
|
@@ -15,8 +17,6 @@ from ring.utils import distribute_batchsize
|
|
15
17
|
from ring.utils import expand_batchsize
|
16
18
|
from ring.utils import parse_path
|
17
19
|
from ring.utils import pickle_load
|
18
|
-
import tree_utils
|
19
|
-
|
20
20
|
import wandb
|
21
21
|
|
22
22
|
# (T, N, F) -> Scalar
|
@@ -142,15 +142,17 @@ def train_fn(
|
|
142
142
|
Wether or not the training run was killed by a callback.
|
143
143
|
"""
|
144
144
|
|
145
|
+
filter = filter.nojit()
|
146
|
+
|
145
147
|
if checkpoint is not None:
|
146
148
|
checkpoint = Path(checkpoint).with_suffix(".pickle")
|
147
149
|
recv_checkpoint: dict = pickle_load(checkpoint)
|
148
|
-
|
150
|
+
filter_params = recv_checkpoint["params"]
|
149
151
|
opt_state = recv_checkpoint["opt_state"]
|
152
|
+
del recv_checkpoint
|
153
|
+
else:
|
154
|
+
filter_params = filter.search_attr("params")
|
150
155
|
|
151
|
-
filter = filter.nojit()
|
152
|
-
|
153
|
-
filter_params = filter.search_attr("params")
|
154
156
|
if filter_params is None:
|
155
157
|
X, _ = generator(jax.random.PRNGKey(1))
|
156
158
|
filter_params, _ = filter.init(X=X, seed=seed_network)
|
ring/utils/utils.py
CHANGED
@@ -122,6 +122,8 @@ def pytree_deepcopy(tree):
|
|
122
122
|
return tuple(pytree_deepcopy(ele) for ele in tree)
|
123
123
|
elif isinstance(tree, dict):
|
124
124
|
return {key: pytree_deepcopy(value) for key, value in tree.items()}
|
125
|
+
elif isinstance(tree, _Base):
|
126
|
+
return tree
|
125
127
|
else:
|
126
128
|
raise NotImplementedError(f"Not implemented for type={type(tree)}")
|
127
129
|
|
File without changes
|
File without changes
|