imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
ring/ml/base.py
ADDED
@@ -0,0 +1,292 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from abc import abstractmethod
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import tree_utils
|
7
|
+
|
8
|
+
import ring
|
9
|
+
from ring.utils import pickle_load
|
10
|
+
from ring.utils import pickle_save
|
11
|
+
|
12
|
+
|
13
|
+
def _to_3d(tree):
    """Add a leading axis of size 1 to every array leaf of `tree`.

    Turns an unbatched input into a batch of one so it can be passed to a
    batched code path. ``None`` is passed through unchanged so optional inputs
    need no special-casing by the caller.
    """
    if tree is None:
        return None
    # `jax.tree_map` is deprecated (and removed in newer JAX releases);
    # `jax.tree_util.tree_map` is the stable spelling with the same semantics.
    return jax.tree_util.tree_map(lambda arr: arr[None], tree)
|
17
|
+
|
18
|
+
|
19
|
+
def _to_2d(tree, i: int = 0):
    """Select entry `i` along the leading axis of every array leaf of `tree`.

    This is the inverse of `_to_3d` for ``i == 0``: it removes the leading
    (batch) axis. ``None`` is passed through unchanged.
    """
    if tree is None:
        return None
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.
    return jax.tree_util.tree_map(lambda arr: arr[i], tree)
|
23
|
+
|
24
|
+
|
25
|
+
class AbstractFilter(ABC):
    """Common interface of all filters.

    Subclasses implement `_apply_batched` (operating on batched input) and
    `init`. `apply` accepts batched or unbatched input and dispatches to
    `_apply_batched` / `_apply_unbatched`. Filters are persisted with `save`
    and restored with `load`; `_pre_save` and `_post_load` are hooks that
    subclasses may override.
    """

    def _apply_unbatched(self, X, params, state, y, lam):
        # Default implementation: wrap X, state and y in a batch of one, run
        # the batched implementation, then strip the batch axis from every
        # leaf of the result again.
        return _to_2d(
            self._apply_batched(
                X=_to_3d(X), params=params, state=_to_3d(state), y=_to_3d(y), lam=lam
            )
        )

    @abstractmethod
    def _apply_batched(self, X, params, state, y, lam):
        """Apply the filter to batched input `X` of shape (B, T, N, F)."""

    @abstractmethod
    def init(self, bs, X, lam, seed: int):
        """Initialise the filter; the return value is defined by the subclass."""

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Run the filter on `X`.

        X.shape = (B, T, N, F) or (T, N, F). A 4-D input is treated as batched
        (leading axis B) and handed to `_apply_batched`; a 3-D input goes
        through `_apply_unbatched`. (T, N, F presumably denote time steps,
        nodes and features — confirm against concrete filters.)
        """
        assert X.ndim in [3, 4]
        if X.ndim == 4:
            return self._apply_batched(X, params, state, y, lam)
        return self._apply_unbatched(X, params, state, y, lam)

    @property
    def name(self) -> str:
        """Human-readable name stored in `self._name`.

        Raises:
            NotImplementedError: if the subclass never set `_name`.
            RuntimeError: if `_name` was set to None.
        """
        if not hasattr(self, "_name"):
            raise NotImplementedError

        if self._name is None:
            raise RuntimeError("No `name` was given.")
        return self._name

    def nojit(self) -> "AbstractFilter":
        """Return the object that `save` pickles; the base class returns `self`."""
        return self

    def _pre_save(self, *args, **kwargs) -> None:
        """Hook called by `save` (with its extra arguments) before pickling."""

    def save(self, path: str, *args, **kwargs):
        """Run `_pre_save(*args, **kwargs)` and pickle `self.nojit()` to `path` (overwriting)."""
        self._pre_save(*args, **kwargs)
        pickle_save(self.nojit(), path, overwrite=True)

    @staticmethod
    def _post_load(filter: "AbstractFilter", *args, **kwargs) -> "AbstractFilter":
        """Hook applied to a freshly unpickled filter; its return value is what `load` returns.

        The default leaves the filter untouched. It must return it; returning
        None would make `load` (and wrappers forwarding to this hook) yield None.
        """
        return filter

    @classmethod
    def load(cls, path: str, *args, **kwargs):
        """Unpickle a filter from `path` and pass it through `_post_load`."""
        # NOTE: `pickle_load` presumably unpickles the file; unpickling can
        # execute arbitrary code, so only load files from trusted sources.
        loaded = pickle_load(path)
        return cls._post_load(loaded, *args, **kwargs)

    def search_attr(self, attr: str):
        """Return attribute `attr` of this filter (wrappers also search the wrapped filter)."""
        return getattr(self, attr)
|
79
|
+
|
80
|
+
|
81
|
+
class AbstractFilterUnbatched(AbstractFilter):
    """Filter whose native implementation processes a single sequence.

    Subclasses implement `_apply_unbatched`. Batched input is handled by
    running that implementation once per element of the leading batch axis
    and combining the per-element results with `tree_utils.tree_batch`
    (presumably stacking them along a new leading axis).
    """

    @abstractmethod
    def _apply_unbatched(self, X, params, state, y, lam):
        """Apply the filter to one unbatched input `X` of shape (T, N, F)."""

    def _apply_batched(self, X, params, state, y, lam):
        # `X`, `state` and `y` are sliced per batch element; `params` and `lam`
        # are shared by every element.
        per_element = [
            self._apply_unbatched(
                _to_2d(X, b), params, _to_2d(state, b), _to_2d(y, b), lam
            )
            for b in range(X.shape[0])
        ]
        return tree_utils.tree_batch(per_element)
|
96
|
+
|
97
|
+
|
98
|
+
class AbstractFilterWrapper(AbstractFilter):
    """Base class for filters that wrap another filter.

    By default every call (`apply`, `init`, save/load hooks) is forwarded to
    the wrapped filter; subclasses override `apply`/`init` to pre- or
    post-process inputs and outputs.
    """

    def __init__(self, filter: AbstractFilter, name=None) -> None:
        self._filter = filter
        self._name = name

    def _apply_batched(self, X, params, state, y, lam):
        # Wrappers forward `apply` directly to the wrapped filter instead.
        raise NotImplementedError

    @property
    def unwrapped(self) -> AbstractFilter:
        """The directly wrapped filter (one level down, not necessarily the innermost)."""
        return self._filter

    def apply(self, X, params=None, state=None, y=None, lam=None):
        return self.unwrapped.apply(X=X, params=params, state=state, y=y, lam=lam)

    def init(self, bs=None, X=None, lam=None, seed: int = 1):
        return self.unwrapped.init(bs=bs, X=X, lam=lam, seed=seed)

    def nojit(self) -> "AbstractFilterWrapper":
        """Return a shallow copy of this wrapper whose wrapped filter is `unwrapped.nojit()`.

        A copy is returned instead of mutating `self`: `save` pickles
        `self.nojit()`, and mutating would permanently replace the live
        wrapper's filter with whatever `nojit()` returns.
        """
        import copy

        wrapper = copy.copy(self)
        wrapper._filter = self.unwrapped.nojit()
        return wrapper

    def search_attr(self, attr: str):
        """Return `attr` from this wrapper if present, else search the wrapped filter(s)."""
        if hasattr(self, attr):
            return super().search_attr(attr)
        return self.unwrapped.search_attr(attr)

    def _pre_save(self, *args, **kwargs):
        self.unwrapped._pre_save(*args, **kwargs)

    @staticmethod
    def _post_load(
        wrapper: "AbstractFilterWrapper", *args, **kwargs
    ) -> "AbstractFilterWrapper":
        # Delegate to the wrapped filter's own hook and keep its return value.
        wrapper._filter = wrapper._filter._post_load(wrapper._filter, *args, **kwargs)
        return wrapper

    @property
    def name(self):
        """Wrapped filter's name followed by " ->\\n" and this wrapper's own name."""
        return self.unwrapped.name + " ->\n" + super().name
|
138
|
+
|
139
|
+
|
140
|
+
class LPF_FilterWrapper(AbstractFilterWrapper):
    """Low-pass filter the output of the wrapped filter, node by node.

    Every node's output sequence is passed to `ring.maths.quat_lowpassfilter`
    with the configured `cutoff_freq` and `filtfilt` options. The sampling
    rate is the fixed `samp_freq` given at construction or, when that is None,
    1 / dt where dt is read from the last feature of `X` (which then must have
    exactly 10 features).
    """

    def __init__(
        self,
        filter: AbstractFilter,
        cutoff_freq: float,
        samp_freq: float | None,
        filtfilt: bool = True,
        name="LPF_FilterWrapper",
    ) -> None:
        super().__init__(filter, name)
        self.samp_freq = samp_freq
        self._kwargs = {"cutoff_freq": cutoff_freq, "filtfilt": filtfilt}

    def apply(self, X, params=None, state=None, y=None, lam=None):
        is_batched = X.ndim == 4

        if self.samp_freq is None:
            assert X.shape[-1] == 10
            # dt is taken from the first time step of the first node; for
            # batched input this gives one sampling rate per batch element.
            dt = X[:, 0, 0, -1] if is_batched else X[0, 0, -1]
            samp_freq = 1 / dt
            print(f"Detected the following sampling rates from `X`: {samp_freq}")
        else:
            samp_freq = jnp.array(self.samp_freq)
            if is_batched:
                # One (identical) rate per batch element so the batched case
                # can be vmapped over the batch axis uniformly.
                samp_freq = jnp.repeat(samp_freq, X.shape[0])

        yhat, state = super().apply(X, params, state, y, lam)

        def lowpass(q, fs):
            return ring.maths.quat_lowpassfilter(q, samp_freq=fs, **self._kwargs)

        # Map over the node axis (axis 1); all nodes share one sampling rate.
        filter_nodes = jax.vmap(lowpass, in_axes=(1, None), out_axes=1)
        if yhat.ndim == 4:
            # Also map over the batch axis, pairing each element with its rate.
            filter_nodes = jax.vmap(filter_nodes)
        return filter_nodes(yhat, samp_freq), state
|
193
|
+
|
194
|
+
|
195
|
+
class GroundTruthHeading_FilterWrapper(AbstractFilterWrapper):
    """Overwrite the heading of selected predictions with the ground-truth heading.

    After the wrapped filter has run, every node `i` with `lam[i] == -1` gets
    its prediction replaced by `ring.maths.quat_transfer_heading(y_i, yhat_i)`
    (NOTE(review): presumably these are root nodes without a parent, and the
    helper copies the heading of `y_i` into `yhat_i` — confirm). Without
    ground truth `y` the predictions are returned unchanged.
    """

    def __init__(
        self, filter: AbstractFilter, name="GroundTruthHeading_FilterWrapper"
    ) -> None:
        super().__init__(filter, name)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        yhat, state = super().apply(X, params, state, y, lam)
        if y is None:
            # Nothing to transfer. Returning early also avoids requiring the
            # wrapped filter to expose a `lam` attribute it does not need.
            return yhat, state
        if lam is None:
            lam = self.search_attr("lam")
        yhat = self.transfer_ground_truth_heading(lam, y, yhat)
        return yhat, state

    @staticmethod
    def transfer_ground_truth_heading(lam, y, yhat):
        """Transfer the ground-truth heading into `yhat` for all nodes with `lam[i] == -1`.

        Args:
            lam: per-node sequence; node `i` is updated iff `lam[i] == -1`.
            y: ground-truth array indexed as `y[..., i, :]`, or None.
            yhat: predictions indexed like `y`.

        Returns:
            `yhat` itself if `y` is None, otherwise a new `jnp` array with the
            selected nodes replaced.
        """
        if y is None:
            return yhat

        assert lam is not None
        yhat = jnp.array(yhat)
        for i, p in enumerate(lam):
            if p == -1:
                yhat = yhat.at[..., i, :].set(
                    ring.maths.quat_transfer_heading(y[..., i, :], yhat[..., i, :])
                )
        return yhat
|
222
|
+
|
223
|
+
|
224
|
+
# Default multiplicative factor per feature group, used by `ScaleX_FilterWrapper`.
# NOTE(review): the divisors look like typical magnitudes (e.g. 9.81 ~ gravity
# for `acc`) used for normalisation — confirm the intended units.
_default_factors = {
    "gyr": 1 / 2.2,
    "acc": 1 / 9.81,
    "joint_axes": 1 / 0.57,
    "dt": 10.0,
}
|
225
|
+
|
226
|
+
|
227
|
+
class ScaleX_FilterWrapper(AbstractFilterWrapper):
    """Rescale groups of input features before forwarding to the wrapped filter.

    The last axis of `X` is split into named groups according to its size F:

    * F == 6:  acc = [0:3], gyr = [3:6]
    * F == 9:  acc = [0:3], gyr = [3:6], joint_axes = [6:9]
    * F == 10: acc = [0:3], gyr = [3:6], joint_axes = [6:9], dt = [9:10]

    Each group is multiplied by `factors[group]` (a missing key raises
    KeyError), and the groups are concatenated back with
    `tree_utils.batch_concat_acme`.
    """

    def __init__(
        self,
        filter: AbstractFilter,
        factors: dict[str, float] = _default_factors,
        name="ScaleX_FilterWrapper",
    ) -> None:
        super().__init__(filter, name)
        # Private copy: never alias the caller's dict or the module-level
        # default that every instance would otherwise share.
        self._factors = dict(factors)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        F = X.shape[-1]
        num_batch_dims = X.ndim - 1

        if F == 6:
            groups = dict(acc=X[..., :3], gyr=X[..., 3:])
        elif F == 9:
            groups = dict(acc=X[..., :3], gyr=X[..., 3:6], joint_axes=X[..., 6:])
        elif F == 10:
            groups = dict(
                acc=X[..., :3], gyr=X[..., 3:6], joint_axes=X[..., 6:9], dt=X[..., 9:10]
            )
        else:
            # ValueError subclasses Exception, so existing handlers still catch it.
            raise ValueError(
                f"Expected X.shape[-1] to be 6, 9 or 10, got X.shape={X.shape}"
            )
        scaled = {key: val * self._factors[key] for key, val in groups.items()}
        # NOTE(review): `batch_concat_acme` presumably concatenates leaves in
        # JAX pytree order (dict keys sorted), so for F == 10 the output layout
        # may be acc, dt, gyr, joint_axes rather than the input order — confirm.
        # Trained parameters may depend on this layout, so do not reorder blindly.
        X = tree_utils.batch_concat_acme(scaled, num_batch_dims=num_batch_dims)
        return super().apply(X, params, state, y, lam)
|
255
|
+
|
256
|
+
|
257
|
+
class NoGraph_FilterWrapper(AbstractFilterWrapper):
    """Present all N nodes to the wrapped filter as one single node.

    The node and feature axes of `X` are merged, (..., N, F) -> (..., 1, N * F),
    and the wrapped filter is always called with `lam=(-1,)` (the `lam`
    argument of `apply` is ignored). The output is reshaped back to
    (..., N, -1), i.e. split evenly across the N original nodes, and optionally
    normalised with `ring.maths.safe_normalize`.
    """

    def __init__(
        self, filter: AbstractFilter, quat_normalize: bool = False, name=None
    ) -> None:
        # Fall back to the class name like the other wrappers in this module;
        # with `_name=None` the inherited `name` property raises RuntimeError.
        super().__init__(filter, "NoGraph_FilterWrapper" if name is None else name)
        self._quat_normalize = quat_normalize

    @staticmethod
    def _merge_nodes(X):
        """Reshape (B, T, N, F) -> (B, T, 1, N*F), or (T, N, F) -> (T, 1, N*F).

        Returns the reshaped array, the leading shape ((B, T) or (T,)) and N.
        """
        if X.ndim == 4:
            B, T, N, F = X.shape
            lead = (B, T)
        else:
            T, N, F = X.shape
            lead = (T,)
        return X.reshape((*lead, 1, N * F)), lead, N

    def init(self, bs=None, X=None, lam=None, seed: int = 1):
        # `X` is optional in the signature: only reshape it when given and let
        # the wrapped filter decide whether it can initialise without it.
        if X is not None:
            X, _, _ = self._merge_nodes(X)
        return super().init(bs, X, (-1,), seed)

    def apply(self, X: jax.Array, params=None, state=None, y=None, lam=None):
        X, lead, N = self._merge_nodes(X)
        # `(-1,)` presumably describes a single parentless node — confirm.
        yhat, state = super().apply(X, params, state, y, (-1,))
        yhat = yhat.reshape((*lead, N, -1))

        if self._quat_normalize:
            assert yhat.shape[-1] == 4
            yhat = ring.maths.safe_normalize(yhat)

        return yhat, state
|