imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
ring/ml/base.py ADDED
@@ -0,0 +1,292 @@
1
+ from abc import ABC
2
+ from abc import abstractmethod
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import tree_utils
7
+
8
+ import ring
9
+ from ring.utils import pickle_load
10
+ from ring.utils import pickle_save
11
+
12
+
13
+ def _to_3d(tree):
14
+ if tree is None:
15
+ return None
16
+ return jax.tree_map(lambda arr: arr[None], tree)
17
+
18
+
19
+ def _to_2d(tree, i: int = 0):
20
+ if tree is None:
21
+ return None
22
+ return jax.tree_map(lambda arr: arr[i], tree)
23
+
24
+
25
class AbstractFilter(ABC):
    """Abstract base class for filters.

    Subclasses implement `_apply_batched` and `init`. `apply` dispatches on
    the rank of `X` so callers may pass either batched or unbatched input.
    """

    def _apply_unbatched(self, X, params, state, y, lam):
        """Run the batched implementation on a single, unbatched input.

        A leading axis of size one is added to `X`, `state` and `y`, the
        batched implementation is called, and entry 0 along the leading axis
        is taken from every leaf of the result.
        """
        return _to_2d(
            self._apply_batched(
                X=_to_3d(X), params=params, state=_to_3d(state), y=_to_3d(y), lam=lam
            )
        )

    @abstractmethod
    def _apply_batched(self, X, params, state, y, lam):
        """Apply the filter to batched input; implemented by subclasses."""

    @abstractmethod
    def init(self, bs, X, lam, seed: int):
        """Initialise the filter; implemented by subclasses."""

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Apply the filter to `X`.

        X.shape = (B, T, N, F) or (T, N, F)

        A 4-dimensional `X` is routed to `_apply_batched`, a 3-dimensional
        one to `_apply_unbatched`. Any other rank fails the assertion.
        """
        assert X.ndim in [3, 4]
        if X.ndim == 4:
            return self._apply_batched(X, params, state, y, lam)
        else:
            return self._apply_unbatched(X, params, state, y, lam)

    @property
    def name(self) -> str:
        """The filter's name, read from `self._name`.

        Raises:
            NotImplementedError: If the instance has no `_name` attribute.
            RuntimeError: If `_name` is `None`.
        """
        if not hasattr(self, "_name"):
            raise NotImplementedError

        if self._name is None:
            raise RuntimeError("No `name` was given.")
        return self._name

    def nojit(self) -> "AbstractFilter":
        """Return the filter to be saved; the base implementation returns `self`."""
        return self

    def _pre_save(self, *args, **kwargs) -> None:
        """Hook called by `save` before pickling; no-op by default."""

    def save(self, path: str, *args, **kwargs):
        """Call `_pre_save` and write `self.nojit()` to `path`, overwriting it.

        Extra positional and keyword arguments are forwarded to `_pre_save`.
        """
        self._pre_save(*args, **kwargs)
        pickle_save(self.nojit(), path, overwrite=True)

    @staticmethod
    def _post_load(filter: "AbstractFilter", *args, **kwargs) -> "AbstractFilter":
        """Hook called by `load` on the freshly loaded filter.

        The default returns `filter` unchanged. It previously returned
        `None`, which made `load` return `None` for every subclass that did
        not override this hook.
        """
        return filter

    @classmethod
    def load(cls, path: str, *args, **kwargs):
        """Load a filter from `path` and pass it through `cls._post_load`.

        Extra positional and keyword arguments are forwarded to `_post_load`.
        """
        # NOTE(review): `pickle_load` presumably unpickles the file -- only
        # load files from trusted sources; confirm against `ring.utils`.
        filter = pickle_load(path)
        return cls._post_load(filter, *args, **kwargs)

    def search_attr(self, attr: str):
        """Return the attribute named `attr` of this filter."""
        return getattr(self, attr)
79
+
80
+
81
class AbstractFilterUnbatched(AbstractFilter):
    """Filter base class for implementations that only handle unbatched input.

    Subclasses implement `_apply_unbatched`; batched input is handled by
    calling it once per entry along the leading axis of `X`.
    """

    @abstractmethod
    def _apply_unbatched(self, X, params, state, y, lam):
        """Apply the filter to a single, unbatched input."""

    def _apply_batched(self, X, params, state, y, lam):
        """Apply the unbatched implementation to each batch entry in turn.

        `X`, `state` and `y` are indexed along their leading axis; `params`
        and `lam` are passed through unchanged. The per-entry results are
        combined with `tree_utils.tree_batch`.
        """
        batch_size = X.shape[0]
        per_entry = [
            self._apply_unbatched(
                _to_2d(X, idx), params, _to_2d(state, idx), _to_2d(y, idx), lam
            )
            for idx in range(batch_size)
        ]
        return tree_utils.tree_batch(per_entry)
96
+
97
+
98
class AbstractFilterWrapper(AbstractFilter):
    """Base class for wrappers that delegate to an inner `AbstractFilter`.

    `apply`, `init`, `nojit`, `_pre_save` and `_post_load` are forwarded to
    the wrapped filter; subclasses override the ones they want to change.
    """

    def __init__(self, filter: AbstractFilter, name=None) -> None:
        self._filter = filter
        self._name = name

    def _apply_batched(self, X, params, state, y, lam):
        # Wrappers go through `apply`, which delegates to the wrapped filter.
        raise NotImplementedError

    @property
    def unwrapped(self) -> AbstractFilter:
        """The filter directly wrapped by this wrapper."""
        return self._filter

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Forward the call to the wrapped filter's `apply`."""
        inner = self.unwrapped
        return inner.apply(X=X, params=params, state=state, y=y, lam=lam)

    def init(self, bs=None, X=None, lam=None, seed: int = 1):
        """Forward the call to the wrapped filter's `init`."""
        inner = self.unwrapped
        return inner.init(bs=bs, X=X, lam=lam, seed=seed)

    def nojit(self) -> "AbstractFilterWrapper":
        """Replace the wrapped filter by its `nojit()` version, in place."""
        self._filter = self.unwrapped.nojit()
        return self

    def search_attr(self, attr: str):
        """Look up `attr` on this wrapper, falling back to the wrapped filter."""
        if not hasattr(self, attr):
            return self.unwrapped.search_attr(attr)
        return super().search_attr(attr)

    def _pre_save(self, *args, **kwargs):
        """Forward the pre-save hook to the wrapped filter."""
        self.unwrapped._pre_save(*args, **kwargs)

    @staticmethod
    def _post_load(
        wrapper: "AbstractFilterWrapper", *args, **kwargs
    ) -> "AbstractFilterWrapper":
        """Run the wrapped filter's post-load hook and store its result."""
        inner = wrapper._filter
        wrapper._filter = inner._post_load(inner, *args, **kwargs)
        return wrapper

    @property
    def name(self):
        """Name of the wrapped filter, then " ->" and a newline, then this wrapper's name."""
        inner_name = self.unwrapped.name
        own_name = super().name
        return inner_name + " ->\n" + own_name
138
+
139
+
140
class LPF_FilterWrapper(AbstractFilterWrapper):
    """Wrapper that passes the wrapped filter's output through
    `ring.maths.quat_lowpassfilter`.

    If `samp_freq` is `None`, the sampling frequency is derived from `X`:
    `X` must then have 10 features and its last feature is read as `dt`.
    """

    def __init__(
        self,
        filter: AbstractFilter,
        cutoff_freq: float,
        samp_freq: float | None,
        filtfilt: bool = True,
        name="LPF_FilterWrapper",
    ) -> None:
        super().__init__(filter, name)
        self.samp_freq = samp_freq
        self._kwargs = dict(cutoff_freq=cutoff_freq, filtfilt=filtfilt)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Apply the wrapped filter, then filter its output `yhat`.

        Returns:
            Tuple `(yhat, state)` with the filtered `yhat`.
        """
        is_batched = X.ndim == 4

        if self.samp_freq is None:
            # Derive the rate from `dt`, taken from the last feature at the
            # first time step of the first entry along the segment axis.
            assert X.shape[-1] == 10
            dt = X[:, 0, 0, -1] if is_batched else X[0, 0, -1]
            samp_freq = 1 / dt
            print(f"Detected the following sampling rates from `X`: {samp_freq}")
        elif is_batched:
            # one sampling frequency per batch entry
            samp_freq = jnp.repeat(jnp.array(self.samp_freq), X.shape[0])
        else:
            samp_freq = jnp.array(self.samp_freq)

        yhat, state = super().apply(X, params, state, y, lam)

        def _lowpass(q, freq):
            return ring.maths.quat_lowpassfilter(q, samp_freq=freq, **self._kwargs)

        # Map over axis 1 of the unbatched array; the frequency is shared.
        filter_fn = jax.vmap(_lowpass, in_axes=(1, None), out_axes=1)
        if yhat.ndim == 4:
            # Additionally map over the leading batch axis of both arguments.
            filter_fn = jax.vmap(filter_fn)
        yhat = filter_fn(yhat, samp_freq)
        return yhat, state
193
+
194
+
195
class GroundTruthHeading_FilterWrapper(AbstractFilterWrapper):
    """Wrapper that post-processes the wrapped filter's output using `y`.

    For every index `i` where `lam[i] == -1`, the estimate at `[..., i, :]`
    is replaced by `ring.maths.quat_transfer_heading(y, yhat)` evaluated on
    that slice.
    """

    def __init__(
        self, filter: AbstractFilter, name="GroundTruthHeading_FilterWrapper"
    ) -> None:
        super().__init__(filter, name)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Apply the wrapped filter, then transfer the heading from `y`.

        If `y` is `None` the wrapped filter's output is returned unchanged.
        If `lam` is `None` it is looked up via `search_attr("lam")`.

        Returns:
            Tuple `(yhat, state)`.
        """
        yhat, state = super().apply(X, params, state, y, lam)
        if y is None:
            # Without `y` the transfer is a no-op, so `lam` is not needed;
            # skipping the lookup avoids an AttributeError when no filter in
            # the chain defines `lam`.
            return yhat, state
        if lam is None:
            lam = self.search_attr("lam")
        yhat = self.transfer_ground_truth_heading(lam, y, yhat)
        return yhat, state

    @staticmethod
    def transfer_ground_truth_heading(lam, y, yhat) -> jax.Array:
        """Replace entries of `yhat` where `lam` is -1 using `y`.

        Args:
            lam: Iterable of integers, one per index along the second-to-last
                axis of `yhat`. Must not be `None` when `y` is given.
                NOTE(review): presumably a parent array in which -1 marks
                bodies attached to the world -- confirm.
            y: Array indexable as `y[..., i, :]`, or `None`.
            yhat: Estimate, indexable as `yhat[..., i, :]`.

        Returns:
            `yhat` unchanged if `y` is `None`; otherwise a new array with the
            selected entries replaced.
        """
        if y is None:
            return yhat

        assert lam is not None
        yhat = jnp.array(yhat)
        for i, p in enumerate(lam):
            if p == -1:
                yhat = yhat.at[..., i, :].set(
                    ring.maths.quat_transfer_heading(y[..., i, :], yhat[..., i, :])
                )
        return yhat
222
+
223
+
224
+ _default_factors = dict(gyr=1 / 2.2, acc=1 / 9.81, joint_axes=1 / 0.57, dt=10.0)
225
+
226
+
227
class ScaleX_FilterWrapper(AbstractFilterWrapper):
    """Wrapper that rescales the features of `X` before the wrapped filter.

    The last axis of `X` is split into named feature groups, each group is
    multiplied by its entry in `factors`, and the groups are concatenated
    again with `tree_utils.batch_concat_acme`.
    """

    def __init__(
        self,
        filter: AbstractFilter,
        factors: dict[str, float] = _default_factors,
        name="ScaleX_FilterWrapper",
    ) -> None:
        super().__init__(filter, name)
        # Copy so the instance neither aliases the shared module-level
        # default nor sees later mutations of the caller's dict.
        self._factors = dict(factors)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Scale `X` feature-wise and apply the wrapped filter.

        The size of the last axis of `X` selects the layout:
        6 -> acc, gyr; 9 -> acc, gyr, joint_axes; 10 -> acc, gyr,
        joint_axes, dt.

        Raises:
            ValueError: If the last axis of `X` is not of size 6, 9 or 10.
            KeyError: If `factors` lacks an entry for a feature group in use.
        """
        F = X.shape[-1]
        num_batch_dims = X.ndim - 1

        if F == 6:
            X = dict(acc=X[..., :3], gyr=X[..., 3:])
        elif F == 9:
            X = dict(acc=X[..., :3], gyr=X[..., 3:6], joint_axes=X[..., 6:])
        elif F == 10:
            X = dict(
                acc=X[..., :3], gyr=X[..., 3:6], joint_axes=X[..., 6:9], dt=X[..., 9:10]
            )
        else:
            raise ValueError(
                f"Expected the last axis of X to have size 6, 9 or 10; X.shape={X.shape}"
            )
        X = {key: val * self._factors[key] for key, val in X.items()}
        # NOTE(review): if `batch_concat_acme` flattens the dict as a JAX
        # pytree, groups are concatenated in sorted-key order rather than in
        # the order above -- confirm this matches what the wrapped filter
        # expects.
        X = tree_utils.batch_concat_acme(X, num_batch_dims=num_batch_dims)
        return super().apply(X, params, state, y, lam)
255
+
256
+
257
class NoGraph_FilterWrapper(AbstractFilterWrapper):
    """Wrapper that folds the `N` axis of `X` into the feature axis.

    The wrapped filter is always called with `X` reshaped to `N = 1` and
    with `lam = (-1,)`; the caller's `lam` is ignored. The output is
    reshaped back to `N` entries afterwards.
    """

    def __init__(
        self, filter: AbstractFilter, quat_normalize: bool = False, name=None
    ) -> None:
        super().__init__(filter, name)
        self._quat_normalize = quat_normalize

    @staticmethod
    def _fold_segments(X):
        """Reshape `X` to a single segment.

        Returns:
            Tuple of the reshaped `X` and the shape used to unfold the
            wrapped filter's output again.
        """
        if X.ndim == 4:
            B, T, N, F = X.shape
            return X.reshape((B, T, 1, N * F)), (B, T, N, -1)
        T, N, F = X.shape
        return X.reshape((T, 1, N * F)), (T, N, -1)

    def init(self, bs=None, X=None, lam=None, seed: int = 1):
        """Initialise the wrapped filter with the folded `X` and `lam = (-1,)`."""
        folded, _ = self._fold_segments(X)
        return super().init(bs, folded, (-1,), seed)

    def apply(self, X: jax.Array, params=None, state=None, y=None, lam=None):
        """Apply the wrapped filter to the folded `X` and unfold its output.

        If `quat_normalize` was set, the unfolded output must have a last
        axis of size 4 and is passed through `ring.maths.safe_normalize`.

        Returns:
            Tuple `(yhat, state)`.
        """
        folded, unfolded_shape = self._fold_segments(X)
        yhat, state = super().apply(folded, params, state, y, (-1,))
        yhat = yhat.reshape(unfolded_shape)

        if self._quat_normalize:
            assert yhat.shape[-1] == 4
            yhat = ring.maths.safe_normalize(yhat)

        return yhat, state