imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
ring/ml/train.py ADDED
@@ -0,0 +1,318 @@
+ from functools import partial
+ from pathlib import Path
+ from typing import Callable, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from ring import maths
+ from ring.algorithms.generator import types
+ from ring.ml import base as ml_base
+ from ring.ml import callbacks as ml_callbacks
+ from ring.ml import ml_utils
+ from ring.ml import training_loop
+ from ring.utils import distribute_batchsize
+ from ring.utils import expand_batchsize
+ from ring.utils import parse_path
+ from ring.utils import pickle_load
+ import tree_utils
+
+ import wandb
+
+ # (T, N, F) -> Scalar
+ LOSS_FN = Callable[[jax.Array, jax.Array], float]
+ _default_loss_fn = lambda q, qhat: maths.angle_error(q, qhat) ** 2
+
+ # reduces (batch_axis, time_axis) -> Scalar
+ ACCUMULATOR_FN = Callable[[jax.Array], float]
+ # Loss_fn here is: (F,) -> Scalar
+ METRICES = dict[str, Tuple[LOSS_FN, ACCUMULATOR_FN]]
+ _default_metrices = {
+     "mae_deg": (
+         lambda q, qhat: maths.angle_error(q, qhat),
+         lambda arr: jnp.rad2deg(jnp.mean(arr, axis=(0, 1))),
+     ),
+ }
+
+
+ def _build_step_fn(
+     metric_fn: LOSS_FN,
+     filter: ml_base.AbstractFilter,
+     optimizer,
+     tbp,
+ ):
+     """Build step function that optimizes filter parameters based on `metric_fn`.
+     `initial_state` has shape (pmap, vmap, state_dim)"""
+
+     @partial(jax.value_and_grad, has_aux=True)
+     def loss_fn(params, state, X, y):
+         yhat, state = filter.apply(params=params, state=state, X=X)
+         # this vmap maps along batch-axis, not time-axis
+         # time-axis is handled by `metric_fn`
+         pipe = lambda q, qhat: jnp.mean(jax.vmap(metric_fn)(q, qhat))
+         error_tree = jax.tree_map(pipe, y, yhat)
+         return jnp.mean(tree_utils.batch_concat(error_tree, 0)), state
+
+     @partial(
+         jax.pmap,
+         in_axes=(None, 0, 0, 0),
+         out_axes=((None, 0), None),
+         axis_name="devices",
+     )
+     def pmapped_loss_fn(params, state, X, y):
+         pmean = lambda arr: jax.lax.pmean(arr, axis_name="devices")
+         (loss, state), grads = loss_fn(params, state, X, y)
+         return (pmean(loss), state), pmean(grads)
+
+     @jax.jit
+     def apply_grads(grads, params, opt_state):
+         updates, opt_state = optimizer.update(grads, opt_state, params)
+         params = optax.apply_updates(params, updates)
+         return params, opt_state
+
+     initial_state = None
+
+     def step_fn(params, opt_state, X, y):
+         assert X.ndim == y.ndim == 4
+         B, T, N, F = X.shape
+         pmap_size, vmap_size = distribute_batchsize(B)
+
+         nonlocal initial_state
+         if initial_state is None:
+             initial_state = expand_batchsize(filter.init(B, X)[1], pmap_size, vmap_size)
+
+         X, y = expand_batchsize((X, y), pmap_size, vmap_size)
+
+         state = initial_state
+         debug_grads = []
+         for i, (X_tbp, y_tbp) in enumerate(
+             tree_utils.tree_split((X, y), int(T / tbp), axis=-3)
+         ):
+             (loss, state), grads = pmapped_loss_fn(params, state, X_tbp, y_tbp)
+             debug_grads.append(grads)
+             state = jax.lax.stop_gradient(state)
+             params, opt_state = apply_grads(grads, params, opt_state)
+
+         return params, opt_state, {"loss": loss}, debug_grads
+
+     return step_fn
+
+
+ def train_fn(
+     generator: types.BatchedGenerator,
+     n_episodes: int,
+     filter: ml_base.AbstractFilter | ml_base.AbstractFilterWrapper,
+     optimizer: Optional[optax.GradientTransformation] = optax.adam(1e-3),
+     tbp: int = 1000,
+     loggers: list[ml_utils.Logger] = [],
+     callbacks: list[training_loop.TrainingLoopCallback] = [],
+     checkpoint: Optional[str] = None,
+     seed_network: int = 1,
+     seed_generator: int = 2,
+     callback_save_params: bool | str = False,
+     callback_save_params_track_metrices: Optional[list[list[str]]] = None,
+     callback_kill_if_grads_larger: Optional[float] = None,
+     callback_kill_if_nan: bool = False,
+     callback_kill_after_episode: Optional[int] = None,
+     callback_kill_after_seconds: Optional[float] = None,
+     callback_kill_tag: Optional[str] = None,
+     callback_create_checkpoint: bool = True,
+     loss_fn: LOSS_FN = _default_loss_fn,
+     metrices: Optional[METRICES] = _default_metrices,
+     link_names: Optional[list[str]] = None,
+ ) -> bool:
+     """Trains RNNO.
+
+     Args:
+         generator (Callable): data generator; the output of `build_generator`
+         n_episodes (int): number of episodes to train for
+         filter: the filter (network) whose parameters are optimized
+         optimizer (optax.GradientTransformation, optional): optimizer, see
+             optimizer.py module
+         tbp (int, optional): truncated backpropagation through time step size
+         loggers: list of Loggers used to log the training progress.
+         callbacks: callbacks of the TrainingLoop.
+         checkpoint (str, optional): path to a `.pickle` checkpoint from which
+             parameters and optimizer state are restored.
+         seed_network (int): PRNG seed that inits the network state and parameters.
+         seed_generator (int): PRNG seed that inits the data stream of the generator.
+         callback_*: configure the default callbacks; see `ring.ml.callbacks`.
+         loss_fn: loss function mapping a (T, N, F) sequence of targets and
+             predictions to a scalar; defaults to the squared angle error.
+         metrices: evaluation metrics as `name -> (metric_fn, reduce_fn)` pairs.
+         link_names: names of the links; used to report per-link metrics.
+
+     Returns:
+         bool: Whether or not the training run was killed by a callback.
+     """
+
+     if checkpoint is not None:
+         checkpoint = Path(checkpoint).with_suffix(".pickle")
+         recv_checkpoint: dict = pickle_load(checkpoint)
+         filter.params = recv_checkpoint["params"]
+         opt_state = recv_checkpoint["opt_state"]
+
+     filter = filter.nojit()
+
+     filter_params = filter.search_attr("params")
+     if filter_params is None:
+         X, _ = generator(jax.random.PRNGKey(1))
+         filter_params, _ = filter.init(X=X, seed=seed_network)
+         del X
+
+     if checkpoint is None:
+         opt_state = optimizer.init(filter_params)
+
+     step_fn = _build_step_fn(
+         loss_fn,
+         filter,
+         optimizer,
+         tbp=tbp,
+     )
+
+     default_callbacks = []
+     if metrices is not None:
+         eval_fn = _build_eval_fn(metrices, filter, link_names)
+         default_callbacks.append(_DefaultEvalFnCallback(eval_fn))
+
+     if callback_kill_tag is not None:
+         default_callbacks.append(ml_callbacks.WandbKillRun(stop_tag=callback_kill_tag))
+
+     if not (callback_save_params is False):
+         if callback_save_params is True:
+             callback_save_params = f"~/params/{ml_utils.unique_id()}.pickle"
+         default_callbacks.append(
+             ml_callbacks.SaveParamsTrainingLoopCallback(callback_save_params)
+         )
+
+     if callback_kill_if_grads_larger is not None:
+         default_callbacks.append(
+             ml_callbacks.LogGradsTrainingLoopCallBack(
+                 callback_kill_if_grads_larger, consecutive_larger=18
+             )
+         )
+
+     if callback_kill_if_nan:
+         default_callbacks.append(ml_callbacks.NanKillRunCallback())
+
+     # always log, because we also want `i_episode` to be logged in wandb
+     default_callbacks.append(
+         ml_callbacks.LogEpisodeTrainingLoopCallback(callback_kill_after_episode)
+     )
+
+     if callback_kill_after_seconds is not None:
+         default_callbacks.append(
+             ml_callbacks.TimingKillRunCallback(callback_kill_after_seconds)
+         )
+
+     if callback_create_checkpoint:
+         default_callbacks.append(ml_callbacks.CheckpointCallback())
+
+     callbacks_all = default_callbacks + callbacks
+
+     # we add this callback afterwards because it might require the metrices
+     # calculated from one of the user-provided callbacks
+     if callback_save_params_track_metrices is not None:
+         assert (
+             callback_save_params is not False
+         ), "Required field if `callback_save_params_track_metrices` is set. Used below."
+
+         callbacks_all.append(
+             ml_callbacks.SaveParamsTrainingLoopCallback(
+                 path_to_file=parse_path(callback_save_params, extension=""),
+                 last_n_params=3,
+                 track_metrices=callback_save_params_track_metrices,
+                 cleanup=False,
+             )
+         )
+
+     # if wandb is initialized, then add the appropriate logger
+     if wandb.run is not None:
+         wandb_logger_found = False
+         for logger in loggers:
+             if isinstance(logger, ml_utils.WandbLogger):
+                 wandb_logger_found = True
+         if not wandb_logger_found:
+             loggers.append(ml_utils.WandbLogger())
+
+     loop = training_loop.TrainingLoop(
+         jax.random.PRNGKey(seed_generator),
+         generator,
+         filter_params,
+         opt_state,
+         step_fn,
+         loggers=loggers,
+         callbacks=callbacks_all,
+     )
+
+     return loop.run(n_episodes)
+
+
+ def _arr_to_dict(y: jax.Array, link_names: list[str] | None):
+     assert y.ndim == 4
+     B, T, N, F = y.shape
+
+     if link_names is None:
+         link_names = ml_utils._unknown_link_names(N)
+
+     return {name: y[..., i, :] for i, name in enumerate(link_names)}
+
+
+ def _build_eval_fn(
+     eval_metrices: dict[str, Tuple[Callable, Callable]],
+     filter: ml_base.AbstractFilter,
+     link_names: Optional[list[str]] = None,
+ ):
+     """Build function that evaluates the filter performance."""
+
+     def eval_fn(params, state, X, y):
+         yhat, _ = filter.apply(params=params, state=state, X=X)
+
+         y = _arr_to_dict(y, link_names)
+         yhat = _arr_to_dict(yhat, link_names)
+
+         values = {}
+         for metric_name, (metric_fn, reduce_fn) in eval_metrices.items():
+             assert (
+                 metric_name not in values
+             ), f"The metric identifier {metric_name} is not unique"
+
+             pipe = lambda q, qhat: reduce_fn(jax.vmap(jax.vmap(metric_fn))(q, qhat))
+             values.update({metric_name: jax.tree_map(pipe, y, yhat)})
+
+         return values
+
+     @partial(jax.pmap, in_axes=(None, 0, 0, 0), out_axes=None, axis_name="devices")
+     def pmapped_eval_fn(params, state, X, y):
+         pmean = lambda arr: jax.lax.pmean(arr, axis_name="devices")
+         values = eval_fn(params, state, X, y)
+         return pmean(values)
+
+     initial_state = None
+
+     def expand_then_pmap_eval_fn(params, X, y):
+         assert X.ndim == y.ndim == 4
+         B, T, N, F = X.shape
+         pmap_size, vmap_size = distribute_batchsize(B)
+
+         nonlocal initial_state
+         if initial_state is None:
+             initial_state = expand_batchsize(filter.init(B, X)[1], pmap_size, vmap_size)
+
+         X, y = expand_batchsize((X, y), pmap_size, vmap_size)
+         return pmapped_eval_fn(params, initial_state, X, y)
+
+     return expand_then_pmap_eval_fn
+
+
+ class _DefaultEvalFnCallback(training_loop.TrainingLoopCallback):
+     def __init__(self, eval_fn):
+         self.eval_fn = eval_fn
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ):
+         metrices.update(self.eval_fn(params, sample_eval[0], sample_eval[1]))
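For orientation, here is a minimal, hedged usage sketch of `train_fn`. The generator and filter are placeholders (any `types.BatchedGenerator` producing `(X, y)` of shape `(B, T, N, F)` and any `ml_base.AbstractFilter` would do), and the custom metric merely follows the `METRICES` layout `name -> (metric_fn, reduce_fn)` defined above; none of the placeholder names are part of the package.

```python
import jax.numpy as jnp
import optax

from ring import maths
from ring.ml import train as ml_train

# Placeholders (assumptions, not package API): a batched generator
# mapping a PRNGKey to (X, y) with shape (B, T, N, F), and a filter.
my_batched_generator = ...
my_filter = ...

# A custom metric in the METRICES layout: a per-sample metric_fn plus
# a reduce_fn that collapses the (batch, time) axes and converts to degrees.
my_metrices = {
    "median_deg": (
        lambda q, qhat: maths.angle_error(q, qhat),
        lambda arr: jnp.rad2deg(jnp.median(arr, axis=(0, 1))),
    ),
}

killed = ml_train.train_fn(
    my_batched_generator,
    n_episodes=1500,
    filter=my_filter,
    optimizer=optax.adam(3e-4),
    tbp=1000,
    metrices=my_metrices,
    callback_kill_if_nan=True,
)
print("killed by callback:", killed)
```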
ring/ml/training_loop.py ADDED
@@ -0,0 +1,131 @@
+ import random
+ from typing import Optional
+
+ import jax
+ from ring.algorithms import Generator
+ from ring.ml import ml_utils
+ import tqdm
+ import tree_utils
+
+ _KILL_RUN = False
+
+
+ def send_kill_run_signal(value: bool = True) -> None:
+     global _KILL_RUN
+     _KILL_RUN = value
+
+
+ def recv_kill_run_signal() -> bool:
+     global _KILL_RUN
+     return _KILL_RUN
+
+
+ class TrainingLoopCallback:
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state: tree_utils.PyTree,
+     ) -> None:
+         pass
+
+     def close(self):
+         pass
+
+
+ class TrainingLoop:
+     def __init__(
+         self,
+         key,
+         generator: Generator,
+         params,
+         opt_state,
+         step_fn,
+         loggers: list[ml_utils.Logger],
+         callbacks: list[TrainingLoopCallback] = [],
+         cycle_seed: Optional[int] = None,
+     ):
+         self._key = key
+         self.i_episode = -1
+         self._generator = generator
+         self._params = params
+         self._opt_state = opt_state
+         self._step_fn = step_fn
+         self._loggers = loggers
+         self._callbacks = callbacks
+         self._seeds = list(range(cycle_seed)) if cycle_seed else None
+         if cycle_seed is not None:
+             random.seed(1)
+
+         self._sample_eval = generator(jax.random.PRNGKey(0))
+         batchsize = tree_utils.tree_shape(self._sample_eval, 0)
+         T = tree_utils.tree_shape(self._sample_eval, 1)
+
+         for logger in loggers:
+             logger.log(dict(n_params=logger.n_params(params), batchsize=batchsize, T=T))
+
+     @property
+     def key(self):
+         if self._seeds is not None:
+             seed_idx = self.i_episode % len(self._seeds)
+             if seed_idx == 0:
+                 random.shuffle(self._seeds)
+             return jax.random.PRNGKey(self._seeds[seed_idx])
+         else:
+             self._key, consume = jax.random.split(self._key)
+             return consume
+
+     def run(self, n_episodes: int = 1, close_afterwards: bool = True) -> bool:
+         # reset the kill_run flag from previous runs
+         send_kill_run_signal(value=False)
+
+         for _ in tqdm.tqdm(range(n_episodes)):
+             self.step()
+
+             if recv_kill_run_signal():
+                 break
+
+         if close_afterwards:
+             self.close()
+
+         return recv_kill_run_signal()
+
+     def step(self):
+         self.i_episode += 1
+
+         sample_train = self._sample_eval
+         self._sample_eval = self._generator(self.key)
+
+         self._params, self._opt_state, loss, debug_grads = self._step_fn(
+             self._params, self._opt_state, sample_train[0], sample_train[1]
+         )
+
+         metrices = {}
+         metrices.update(loss)
+
+         for callback in self._callbacks:
+             callback.after_training_step(
+                 self.i_episode,
+                 metrices,
+                 self._params,
+                 debug_grads,
+                 self._sample_eval,
+                 self._loggers,
+                 self._opt_state,
+             )
+
+         for logger in self._loggers:
+             logger.log(metrices)
+
+         return metrices
+
+     def close(self):
+         for callback in self._callbacks:
+             callback.close()
+
+         for logger in self._loggers:
+             logger.close()
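The kill signal is cooperative: `TrainingLoop.run` checks `recv_kill_run_signal()` after every episode, so any callback can end training early by calling `send_kill_run_signal()`. Below is a minimal sketch of an early-stopping callback built on that contract; it is an illustration, not part of the package, it watches the scalar `"loss"` entry that `step_fn` reports, and the threshold value is arbitrary.

```python
from ring.ml import training_loop

class StopBelowLoss(training_loop.TrainingLoopCallback):
    """Illustrative early stopping: kill the run once the training
    loss drops below a (hypothetical) threshold."""

    def __init__(self, threshold: float = 1e-3):
        self.threshold = threshold

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list,
        opt_state,
    ) -> None:
        loss = metrices.get("loss")
        if loss is not None and float(loss) < self.threshold:
            training_loop.send_kill_run_signal()

# e.g. train_fn(..., callbacks=[StopBelowLoss()]); `run` then returns True.
```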
ring/rendering/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .base_render import render
+ from .base_render import render_prediction