imt-ring 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.0.dist-info → imt_ring-1.3.2.dist-info}/METADATA +1 -1
- {imt_ring-1.3.0.dist-info → imt_ring-1.3.2.dist-info}/RECORD +14 -14
- ring/algorithms/custom_joints/__init__.py +1 -0
- ring/algorithms/custom_joints/suntay.py +57 -0
- ring/algorithms/generator/base.py +9 -15
- ring/algorithms/generator/batch.py +33 -11
- ring/base.py +6 -7
- ring/ml/__init__.py +5 -2
- ring/ml/optimizer.py +2 -2
- ring/rendering/base_render.py +11 -6
- ring/utils/__init__.py +1 -0
- ring/utils/utils.py +14 -0
- {imt_ring-1.3.0.dist-info → imt_ring-1.3.2.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.0.dist-info → imt_ring-1.3.2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=gqdXejZ4E4liB5mZ6gPof3EDYTThlfro2MQs0bc5eOM,33530
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
|
@@ -9,13 +9,13 @@ ring/algorithms/dynamics.py,sha256=nqq5I0RYSbHNlGiLMlohz08IfL9Njsrid4upDnwkGbI,1
|
|
9
9
|
ring/algorithms/jcalc.py,sha256=oqSiwz3Be1VfIpmJXEFTNM_9_o3tyuTtyZt2aqttyN4,28213
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=Y3Wo9qj3BWKoIHB0V04QwyD-Z5m4BrAjfBX8Pn6y9Lg,18005
|
12
|
-
ring/algorithms/custom_joints/__init__.py,sha256=
|
12
|
+
ring/algorithms/custom_joints/__init__.py,sha256=_kUyC4TbzjngTQrJVtS6JBKPzTMNbH27jVRJYXViepI,270
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=3aFDfqdC2vUAhD30kkuQltgU_WZmYDyVhKPSoEotEYo,15292
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=zmrRK_I6BWoo4WbEcEVK7iFKdPfetc6txs7U8iu1xEk,14771
|
18
|
+
ring/algorithms/generator/batch.py,sha256=BGzmwH1AItXjPRyHtsYnAfYnoogw8jxhng9oyVw72lw,9019
|
19
19
|
ring/algorithms/generator/motion_artifacts.py,sha256=aKdkZU5OF4_aKyL4Yo-ftZRwrDCve1LuuREGAUlTqtI,8551
|
20
20
|
ring/algorithms/generator/pd_control.py,sha256=3pOaYig26vmp8gippDfy2KNJRZO3kr0rGd_PBIuEROM,5759
|
21
21
|
ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
@@ -50,17 +50,17 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
|
|
50
50
|
ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
|
51
51
|
ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
|
52
52
|
ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
53
|
-
ring/ml/__init__.py,sha256
|
53
|
+
ring/ml/__init__.py,sha256=-bryExVoKJYSF_G_KYc5hI_GciIhj2xZ8WGi6TdRghw,1836
|
54
54
|
ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
|
55
55
|
ring/ml/callbacks.py,sha256=DkSy5c7IRqAAks2dx8acEBExYxUv-xiUFwZn4odPYq4,13253
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
|
-
ring/ml/optimizer.py,sha256=
|
57
|
+
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=OWRDu2COmptzbpJWlRLbPIn_ioKZCAd_iu-eiY_aPjk,8521
|
59
59
|
ring/ml/train.py,sha256=ftt2MOSSNGCdL7ZoAXcbIgeHW1Wkpgp6XYyLIBUIClI,10872
|
60
60
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
61
61
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
62
62
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
63
|
-
ring/rendering/base_render.py,sha256=
|
63
|
+
ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
|
64
64
|
ring/rendering/mujoco_render.py,sha256=aluzQJp3jrDdPfAyNmQuXIHRfgfBTCCZQqxKOx_0D2s,7770
|
65
65
|
ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
|
66
66
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
@@ -70,14 +70,14 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
70
70
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
71
71
|
ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
|
72
72
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
73
|
-
ring/utils/__init__.py,sha256=
|
73
|
+
ring/utils/__init__.py,sha256=rTvSA4RiJAVCY_A64FUMd8IJTv94LgoSA3Ps5X63_jA,799
|
74
74
|
ring/utils/batchsize.py,sha256=mPFGD7AedFMycHtyIuZtNWCaAvKLLWSWaB7X6u54xvM,1358
|
75
75
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
76
76
|
ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
77
77
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
78
78
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
79
|
-
ring/utils/utils.py,sha256=
|
80
|
-
imt_ring-1.3.
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
79
|
+
ring/utils/utils.py,sha256=I2f6-DMBrrgy5tpLzPLlezifQgkO2fERZWyX3cfb4sI,5303
|
80
|
+
imt_ring-1.3.2.dist-info/METADATA,sha256=OQmB5-CEpy-JWv2K9vYo9fUNQDW0Jg-fFU4kBkiVRGQ,3104
|
81
|
+
imt_ring-1.3.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
82
|
+
imt_ring-1.3.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
83
|
+
imt_ring-1.3.2.dist-info/RECORD,,
|
@@ -2,5 +2,6 @@ from .rr_imp_joint import register_rr_imp_joint
|
|
2
2
|
from .rr_joint import register_rr_joint
|
3
3
|
from .suntay import GP_DrawFnPair
|
4
4
|
from .suntay import MLP_DrawnFnPair
|
5
|
+
from .suntay import Polynomial_DrawnFnPair
|
5
6
|
from .suntay import register_suntay
|
6
7
|
from .suntay import SuntayConfig
|
@@ -293,6 +293,63 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
293
293
|
ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
|
294
294
|
|
295
295
|
|
296
|
+
def Polynomial_DrawnFnPair(
|
297
|
+
order: int = 2,
|
298
|
+
val: float = 2.0,
|
299
|
+
center: bool = False,
|
300
|
+
flexion_center: Optional[float] = None,
|
301
|
+
) -> DrawnFnPairFactory:
|
302
|
+
assert val >= 0.0
|
303
|
+
|
304
|
+
# because 0-th order is also counted
|
305
|
+
order += 1
|
306
|
+
|
307
|
+
def factory(xs, mn, mx):
|
308
|
+
nonlocal flexion_center
|
309
|
+
|
310
|
+
flexion_mn = jnp.min(xs)
|
311
|
+
flexion_mx = jnp.max(xs)
|
312
|
+
|
313
|
+
def _apply_poly_factors(poly_factors, q):
|
314
|
+
return poly_factors @ jnp.power(q, jnp.arange(order))
|
315
|
+
|
316
|
+
if flexion_center is None:
|
317
|
+
flexion_center = (flexion_mn + flexion_mx) / 2
|
318
|
+
else:
|
319
|
+
flexion_center = jnp.array(flexion_center)
|
320
|
+
|
321
|
+
def init(key):
|
322
|
+
c1, c2 = jax.random.split(key)
|
323
|
+
poly_factors = jax.random.uniform(
|
324
|
+
c1, shape=(order,), minval=-val, maxval=val
|
325
|
+
)
|
326
|
+
q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
|
327
|
+
values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
|
328
|
+
poly_factors, xs - q0
|
329
|
+
)
|
330
|
+
amax = jnp.max(values)
|
331
|
+
amin = jnp.min(values)
|
332
|
+
return amin, amax, poly_factors, q0
|
333
|
+
|
334
|
+
def _apply(params, q):
|
335
|
+
amin, amax, poly_factors, q0 = params
|
336
|
+
q = q - q0
|
337
|
+
value = _apply_poly_factors(poly_factors, q)
|
338
|
+
return restrict(value, mn, mx, amin, amax)
|
339
|
+
|
340
|
+
if center:
|
341
|
+
|
342
|
+
def apply(params, q):
|
343
|
+
return _apply(params, q) - _apply(params, flexion_center)
|
344
|
+
|
345
|
+
else:
|
346
|
+
apply = _apply
|
347
|
+
|
348
|
+
return DrawnFnPair(init, apply)
|
349
|
+
|
350
|
+
return factory
|
351
|
+
|
352
|
+
|
296
353
|
def MLP_DrawnFnPair(
|
297
354
|
center: bool = False, flexion_center: Optional[float] = None
|
298
355
|
) -> DrawnFnPairFactory:
|
@@ -4,6 +4,8 @@ import warnings
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
import tree_utils
|
8
|
+
|
7
9
|
from ring import base
|
8
10
|
from ring import utils
|
9
11
|
from ring.algorithms import jcalc
|
@@ -13,7 +15,6 @@ from ring.algorithms.generator import motion_artifacts
|
|
13
15
|
from ring.algorithms.generator import randomize
|
14
16
|
from ring.algorithms.generator import transforms
|
15
17
|
from ring.algorithms.generator import types
|
16
|
-
import tree_utils
|
17
18
|
|
18
19
|
|
19
20
|
class RCMG:
|
@@ -108,23 +109,20 @@ class RCMG:
|
|
108
109
|
partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
|
109
110
|
)
|
110
111
|
|
111
|
-
def _to_data(self, sizes, seed
|
112
|
-
return batch.batch_generators_eager_to_list(
|
113
|
-
self.gens, sizes, seed=seed, jit=jit
|
114
|
-
)
|
112
|
+
def _to_data(self, sizes, seed):
|
113
|
+
return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
|
115
114
|
|
116
|
-
def to_list(self, sizes: int | list[int] = 1, seed: int = 1
|
117
|
-
return self._to_data(sizes, seed
|
115
|
+
def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
|
116
|
+
return self._to_data(sizes, seed)
|
118
117
|
|
119
118
|
def to_pickle(
|
120
119
|
self,
|
121
120
|
path: str,
|
122
121
|
sizes: int | list[int] = 1,
|
123
122
|
seed: int = 1,
|
124
|
-
jit: bool = False,
|
125
123
|
overwrite: bool = True,
|
126
124
|
) -> None:
|
127
|
-
data = tree_utils.tree_batch(self._to_data(sizes, seed
|
125
|
+
data = tree_utils.tree_batch(self._to_data(sizes, seed))
|
128
126
|
utils.pickle_save(data, path, overwrite=overwrite)
|
129
127
|
|
130
128
|
def to_hdf5(
|
@@ -132,10 +130,9 @@ class RCMG:
|
|
132
130
|
path: str,
|
133
131
|
sizes: int | list[int] = 1,
|
134
132
|
seed: int = 1,
|
135
|
-
jit: bool = False,
|
136
133
|
overwrite: bool = True,
|
137
134
|
) -> None:
|
138
|
-
data = tree_utils.tree_batch(self._to_data(sizes, seed
|
135
|
+
data = tree_utils.tree_batch(self._to_data(sizes, seed))
|
139
136
|
utils.hdf5_save(path, data, overwrite=overwrite)
|
140
137
|
|
141
138
|
def to_eager_gen(
|
@@ -143,11 +140,8 @@ class RCMG:
|
|
143
140
|
batchsize: int = 1,
|
144
141
|
sizes: int | list[int] = 1,
|
145
142
|
seed: int = 1,
|
146
|
-
jit: bool = False,
|
147
143
|
) -> types.BatchedGenerator:
|
148
|
-
return batch.batch_generators_eager(
|
149
|
-
self.gens, sizes, batchsize, seed=seed, jit=jit
|
150
|
-
)
|
144
|
+
return batch.batch_generators_eager(self.gens, sizes, batchsize, seed=seed)
|
151
145
|
|
152
146
|
def to_lazy_gen(
|
153
147
|
self, sizes: int | list[int] = 1, jit: bool = True
|
@@ -6,12 +6,13 @@ import warnings
|
|
6
6
|
import jax
|
7
7
|
import jax.numpy as jnp
|
8
8
|
import numpy as np
|
9
|
-
from ring import utils
|
10
|
-
from ring.algorithms.generator import types
|
11
9
|
from tqdm import tqdm
|
12
10
|
import tree_utils
|
13
11
|
from tree_utils import tree_batch
|
14
12
|
|
13
|
+
from ring import utils
|
14
|
+
from ring.algorithms.generator import types
|
15
|
+
|
15
16
|
|
16
17
|
def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
|
17
18
|
arr = []
|
@@ -61,11 +62,24 @@ def batch_generators_lazy(
|
|
61
62
|
return generator
|
62
63
|
|
63
64
|
|
65
|
+
def _number_of_executions_required(size: int) -> int:
|
66
|
+
vmap_threshold = 128
|
67
|
+
_, vmap = utils.distribute_batchsize(size)
|
68
|
+
|
69
|
+
primes = iter(utils.primes(vmap))
|
70
|
+
n_calls = 1
|
71
|
+
while vmap > vmap_threshold:
|
72
|
+
prime = next(primes)
|
73
|
+
n_calls *= prime
|
74
|
+
vmap /= prime
|
75
|
+
|
76
|
+
return n_calls
|
77
|
+
|
78
|
+
|
64
79
|
def batch_generators_eager_to_list(
|
65
80
|
generators: types.Generator | list[types.Generator],
|
66
81
|
sizes: int | list[int],
|
67
82
|
seed: int = 1,
|
68
|
-
jit: bool = True,
|
69
83
|
) -> list[tree_utils.PyTree]:
|
70
84
|
"Returns list of unbatched sequences as numpy arrays."
|
71
85
|
generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)
|
@@ -73,11 +87,20 @@ def batch_generators_eager_to_list(
|
|
73
87
|
key = jax.random.PRNGKey(seed)
|
74
88
|
data = []
|
75
89
|
for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
|
76
|
-
|
77
|
-
|
78
|
-
#
|
79
|
-
|
80
|
-
|
90
|
+
|
91
|
+
n_calls = _number_of_executions_required(size)
|
92
|
+
# decrease size by n_calls times
|
93
|
+
size = int(size / n_calls)
|
94
|
+
jit = True if n_calls > 1 else False
|
95
|
+
gen_jit = batch_generators_lazy(gen, size, jit=jit)
|
96
|
+
|
97
|
+
for _ in range(n_calls):
|
98
|
+
key, consume = jax.random.split(key)
|
99
|
+
sample = gen_jit(consume)
|
100
|
+
# converts also to numpy
|
101
|
+
sample = jax.device_get(sample)
|
102
|
+
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
|
103
|
+
|
81
104
|
return data
|
82
105
|
|
83
106
|
|
@@ -243,12 +266,11 @@ def batch_generators_eager(
|
|
243
266
|
shuffle: bool = True,
|
244
267
|
drop_last: bool = True,
|
245
268
|
seed: int = 1,
|
246
|
-
jit: bool = True,
|
247
269
|
) -> types.BatchedGenerator:
|
248
270
|
"""Eagerly create a large precomputed generator by calling multiple generators
|
249
271
|
and stacking their output."""
|
250
272
|
|
251
|
-
data = batch_generators_eager_to_list(generators, sizes, seed=seed
|
273
|
+
data = batch_generators_eager_to_list(generators, sizes, seed=seed)
|
252
274
|
return batched_generator_from_list(data, batchsize, shuffle, drop_last)
|
253
275
|
|
254
276
|
|
@@ -270,7 +292,7 @@ def _process_sizes_batchsizes_generators(
|
|
270
292
|
|
271
293
|
assert len(generators) == len(list_sizes)
|
272
294
|
|
273
|
-
_WARN_SIZE =
|
295
|
+
_WARN_SIZE = 1e6 # disable this warning
|
274
296
|
for size in list_sizes:
|
275
297
|
if size >= _WARN_SIZE:
|
276
298
|
warnings.warn(
|
ring/base.py
CHANGED
@@ -99,15 +99,15 @@ class _Base:
|
|
99
99
|
def ndim(self):
|
100
100
|
return tu.tree_ndim(self)
|
101
101
|
|
102
|
-
def shape(self, axis=0) -> int:
|
103
|
-
|
104
|
-
|
105
|
-
def __len__(self) -> int:
|
106
|
-
Bs = tree_map(lambda arr: arr.shape[0], self)
|
102
|
+
def shape(self, axis: int = 0) -> int:
|
103
|
+
Bs = tree_map(lambda arr: arr.shape[axis], self)
|
107
104
|
Bs = set(jax.tree_util.tree_flatten(Bs)[0])
|
108
105
|
assert len(Bs) == 1
|
109
106
|
return list(Bs)[0]
|
110
107
|
|
108
|
+
def __len__(self) -> int:
|
109
|
+
return self.shape(axis=0)
|
110
|
+
|
111
111
|
|
112
112
|
@struct.dataclass
|
113
113
|
class Transform(_Base):
|
@@ -685,14 +685,13 @@ class System(_Base):
|
|
685
685
|
self,
|
686
686
|
xs: Transform | list[Transform],
|
687
687
|
yhat: dict | jax.Array | np.ndarray,
|
688
|
-
stepframe: int = 1,
|
689
688
|
# by default we don't predict the global rotation
|
690
689
|
transparent_segment_to_root: bool = True,
|
691
690
|
**kwargs,
|
692
691
|
):
|
693
692
|
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
|
694
693
|
return ring.rendering.render_prediction(
|
695
|
-
self, xs, yhat,
|
694
|
+
self, xs, yhat, transparent_segment_to_root, **kwargs
|
696
695
|
)
|
697
696
|
|
698
697
|
def delete_system(self, link_name: str | list[str], strict: bool = True):
|
ring/ml/__init__.py
CHANGED
@@ -12,6 +12,8 @@ from .optimizer import make_optimizer
|
|
12
12
|
from .ringnet import RING
|
13
13
|
from .train import train_fn
|
14
14
|
|
15
|
+
_lpf_cutoff_freq = 10.0
|
16
|
+
|
15
17
|
|
16
18
|
def RING_ICML24(params=None, eval: bool = True, **kwargs):
|
17
19
|
"""Create the RING network used in the icml24 paper.
|
@@ -29,7 +31,7 @@ def RING_ICML24(params=None, eval: bool = True, **kwargs):
|
|
29
31
|
ringnet = RING(params=params, **kwargs) # noqa: F811
|
30
32
|
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
31
33
|
if eval:
|
32
|
-
ringnet = base.LPF_FilterWrapper(ringnet,
|
34
|
+
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
|
33
35
|
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
34
36
|
return ringnet
|
35
37
|
|
@@ -39,6 +41,7 @@ def RNNO(
|
|
39
41
|
return_quats: bool = False,
|
40
42
|
params=None,
|
41
43
|
eval: bool = True,
|
44
|
+
samp_freq: float | None = None,
|
42
45
|
**kwargs,
|
43
46
|
):
|
44
47
|
assert "message_dim" not in kwargs
|
@@ -55,7 +58,7 @@ def RNNO(
|
|
55
58
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
56
59
|
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
57
60
|
if eval and return_quats:
|
58
|
-
ringnet = base.LPF_FilterWrapper(ringnet,
|
61
|
+
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=samp_freq)
|
59
62
|
if return_quats:
|
60
63
|
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
61
64
|
return ringnet
|
ring/ml/optimizer.py
CHANGED
@@ -14,10 +14,10 @@ from optax._src.transform import AddNoiseState
|
|
14
14
|
def make_optimizer(
|
15
15
|
lr: float,
|
16
16
|
n_episodes: int,
|
17
|
-
n_steps_per_episode: int,
|
17
|
+
n_steps_per_episode: int = 6,
|
18
18
|
adap_clip: Optional[float] = 0.1,
|
19
19
|
glob_clip: Optional[float] = 0.2,
|
20
|
-
skip_large_update_max_normsq: float =
|
20
|
+
skip_large_update_max_normsq: float = 100.0,
|
21
21
|
skip_large_update_warmup: int = 300,
|
22
22
|
inner_opt=optax.lamb,
|
23
23
|
cos_decay_twice: bool = False,
|
ring/rendering/base_render.py
CHANGED
@@ -136,12 +136,15 @@ def render_prediction(
|
|
136
136
|
sys: base.System,
|
137
137
|
xs: base.Transform | list[base.Transform],
|
138
138
|
yhat: dict | jax.Array | np.ndarray,
|
139
|
-
stepframe: int = 1,
|
140
139
|
# by default we don't predict the global rotation
|
141
140
|
transparent_segment_to_root: bool = True,
|
142
141
|
**kwargs,
|
143
142
|
):
|
144
143
|
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
|
144
|
+
|
145
|
+
offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
|
146
|
+
offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
|
147
|
+
|
145
148
|
if isinstance(xs, list):
|
146
149
|
# list -> batched Transform
|
147
150
|
xs = xs[0].batch(*xs[1:])
|
@@ -180,18 +183,23 @@ def render_prediction(
|
|
180
183
|
|
181
184
|
# swap time axis, and link axis
|
182
185
|
xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
|
186
|
+
|
187
|
+
add_offset = lambda x, offset: algebra.transform_mul(
|
188
|
+
x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
|
189
|
+
)
|
190
|
+
|
183
191
|
# create mapping from `name` -> Transform
|
184
192
|
xs_dict = dict(
|
185
193
|
zip(
|
186
194
|
["hat_" + name for name in sys_noimu.link_names],
|
187
|
-
[xshat[i] for i in range(sys_noimu.num_links())],
|
195
|
+
[add_offset(xshat[i], offset_pred) for i in range(sys_noimu.num_links())],
|
188
196
|
)
|
189
197
|
)
|
190
198
|
xs_dict.update(
|
191
199
|
dict(
|
192
200
|
zip(
|
193
201
|
sys.link_names,
|
194
|
-
[xs[i] for i in range(sys.num_links())],
|
202
|
+
[add_offset(xs[i], offset_truth) for i in range(sys.num_links())],
|
195
203
|
)
|
196
204
|
)
|
197
205
|
)
|
@@ -202,11 +210,8 @@ def render_prediction(
|
|
202
210
|
xs_render.append(xs_dict[name])
|
203
211
|
xs_render = xs_render[0].batch(*xs_render[1:])
|
204
212
|
xs_render = xs_render.transpose((1, 0, 2))
|
205
|
-
N = xs_render.shape()
|
206
|
-
xs_render = [xs_render[t] for t in range(0, N, stepframe)]
|
207
213
|
|
208
214
|
frames = render(sys_render, xs_render, **kwargs)
|
209
|
-
|
210
215
|
return frames
|
211
216
|
|
212
217
|
|
ring/utils/__init__.py
CHANGED
ring/utils/utils.py
CHANGED
@@ -159,3 +159,17 @@ def pickle_load(
|
|
159
159
|
with open(path, "rb") as file:
|
160
160
|
obj = pickle.load(file)
|
161
161
|
return obj
|
162
|
+
|
163
|
+
|
164
|
+
def primes(n: int) -> list[int]:
|
165
|
+
"Primefactor decomposition in ascending order."
|
166
|
+
primfac = []
|
167
|
+
d = 2
|
168
|
+
while d * d <= n:
|
169
|
+
while (n % d) == 0:
|
170
|
+
primfac.append(d) # supposing you want multiple factors repeated
|
171
|
+
n //= d
|
172
|
+
d += 1
|
173
|
+
if n > 1:
|
174
|
+
primfac.append(n)
|
175
|
+
return primfac
|
File without changes
|
File without changes
|