imt-ring 1.2.1 (imt_ring-1.2.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
ring/ml/optimizer.py ADDED
@@ -0,0 +1,149 @@
+ from typing import Any, NamedTuple, Optional
+
+ import jax
+ from jax import lax
+ import jax.numpy as jnp
+ from jax.tree_util import tree_map
+ import optax
+ from optax._src import base
+ from optax._src import numerics
+ from optax._src.transform import add_noise
+ from optax._src.transform import AddNoiseState
+
+
+ def make_optimizer(
+     lr: float,
+     n_episodes: int,
+     n_steps_per_episode: int,
+     adap_clip: Optional[float] = 0.1,
+     glob_clip: Optional[float] = 0.2,
+     skip_large_update_max_normsq: float = 5.0,
+     skip_large_update_warmup: int = 300,
+     inner_opt=optax.lamb,
+     cos_decay_twice: bool = False,
+     scale_grads: Optional[float] = None,
+     **inner_opt_kwargs,
+ ):
+     steps = n_steps_per_episode * n_episodes
+     if cos_decay_twice:
+         half_steps = int(steps / 2)
+         schedule = optax.join_schedules(
+             [
+                 optax.cosine_decay_schedule(lr, half_steps, 1e-2),
+                 optax.cosine_decay_schedule(lr * 1e-2, half_steps),
+             ],
+             [half_steps],
+         )
+     else:
+         schedule = optax.cosine_decay_schedule(lr, steps, 1e-7)
+
+     optimizer = optax.chain(
+         (
+             optax.scale_by_learning_rate(scale_grads, flip_sign=False)
+             if scale_grads is not None
+             else optax.identity()
+         ),
+         (
+             optax.adaptive_grad_clip(adap_clip)
+             if adap_clip is not None
+             else optax.identity()
+         ),
+         optax.clip_by_global_norm(glob_clip) if glob_clip is not None else optax.identity(),
+         inner_opt(schedule, **inner_opt_kwargs),
+     )
+     optimizer = skip_large_update(
+         optimizer,
+         skip_large_update_max_normsq,
+         max_consecutive_toolarge=6 * 25,
+         warmup=skip_large_update_warmup,
+     )
+     return optimizer
+
+
+ class SkipIfLargeUpdatesState(NamedTuple):
+     toolarge_count: jnp.array
+     count: jnp.array
+     inner_state: Any
+     add_noise_state: AddNoiseState
+
+
+ def _condition_not_toolarge(updates: base.Updates, max_norm_sq: float):
+     norm_sq = jnp.sum(
+         jnp.array([jnp.sum(p**2) for p in jax.tree_util.tree_leaves(updates)])
+     )
+     # This will also return False if `norm_sq` is NaN or Inf.
+     return norm_sq < max_norm_sq
+
+
+ def skip_large_update(
+     inner: base.GradientTransformation,
+     max_norm_sq: float,
+     max_consecutive_toolarge: int,
+     warmup: int = 0,
+     disturb_if_skip: bool = False,
+     disturb_adaptive: bool = False,
+     eta: float = 0.01,
+     gamma: float = 0.55,
+     seed: int = 0,
+ ) -> base.GradientTransformation:
+     "Also skips NaNs and Infs."
+     inner = base.with_extra_args_support(inner)
+
+     if disturb_adaptive:
+         raise NotImplementedError
+
+     add_noise_transform = add_noise(eta, gamma, seed)
+
+     def init(params):
+         return SkipIfLargeUpdatesState(
+             toolarge_count=jnp.zeros([], jnp.int32),
+             count=jnp.zeros([], jnp.int32),
+             inner_state=inner.init(params),
+             add_noise_state=add_noise_transform.init(params),
+         )
+
+     def update(updates, state: SkipIfLargeUpdatesState, params=None, **extra_args):
+         inner_state = state.inner_state
+         not_toolarge = _condition_not_toolarge(updates, max_norm_sq)
+         toolarge_count = jnp.where(
+             not_toolarge,
+             jnp.zeros([], jnp.int32),
+             numerics.safe_int32_increment(state.toolarge_count),
+         )
+
+         def do_update(updates):
+             updates, new_inner_state = inner.update(
+                 updates, inner_state, params, **extra_args
+             )
+             return updates, new_inner_state, state.add_noise_state
+
+         def reject_update(updates):
+             if disturb_if_skip:
+                 updates, new_add_noise_state = add_noise_transform.update(
+                     updates, state.add_noise_state, params
+                 )
+             else:
+                 updates, new_add_noise_state = (
+                     tree_map(jnp.zeros_like, updates),
+                     state.add_noise_state,
+                 )
+             return updates, inner_state, new_add_noise_state
+
+         updates, new_inner_state, new_add_noise_state = lax.cond(
+             jnp.logical_or(
+                 jnp.logical_or(not_toolarge, toolarge_count > max_consecutive_toolarge),
+                 state.count < warmup,
+             ),
+             do_update,
+             reject_update,
+             updates,
+         )
+
+         return updates, SkipIfLargeUpdatesState(
+             toolarge_count=toolarge_count,
+             count=numerics.safe_int32_increment(state.count),
+             inner_state=new_inner_state,
+             add_noise_state=new_add_noise_state,
+         )
+
+     return base.GradientTransformationExtraArgs(init=init, update=update)
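
A minimal usage sketch of `make_optimizer`, not taken from the package itself: the learning rate, episode counts, and toy parameter pytree below are made-up values, and the import path simply follows the module location listed above.

# Hypothetical usage sketch: build the optimizer and take one update step.
import jax.numpy as jnp
import optax

from ring.ml.optimizer import make_optimizer

opt = make_optimizer(lr=1e-3, n_episodes=10, n_steps_per_episode=100)

params = {"w": jnp.zeros((3,))}    # toy parameter pytree
opt_state = opt.init(params)

grads = {"w": jnp.ones((3,))}      # placeholder gradients
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

The returned object behaves like any other optax gradient transformation; the `skip_large_update` wrapper only zeroes the updates when their squared global norm exceeds `skip_large_update_max_normsq` (or is NaN/Inf) outside the warmup phase.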
ring/ml/params/0x13e3518065c21cd8.pickle ADDED
Binary file
ring/ml/ringnet.py ADDED
@@ -0,0 +1,279 @@
+ from functools import partial
+ from pathlib import Path
+ from types import SimpleNamespace
+ from typing import Callable, Optional
+
+ import haiku as hk
+ import jax
+ import jax.numpy as jnp
+ import tree_utils
+
+ from ring.maths import safe_normalize
+ from ring.ml import base as ml_base
+ from ring.utils import pickle_load
+
+
+ def _scan_sys(lam: list[int], f):
+     ys = []
+     for i, p in enumerate(lam):
+         ys.append(f(i, p))
+     return tree_utils.tree_batch(ys, backend="jax")
+
+
+ def _make_rnno_cell_apply_fn(
+     lam: list[int],
+     inner_cell,
+     send_msg,
+     send_output,
+     hidden_state_dim,
+     message_dim,
+     output_transform: Callable,
+ ):
+     N = len(lam)
+     parent_array = jnp.array(lam, dtype=jnp.int32)
+
+     def _rnno_cell_apply_fn(inputs, prev_state):
+         empty_message = jnp.zeros((1, message_dim))
+         mailbox = jnp.repeat(empty_message, N, axis=0)
+
+         # message is sent using the hidden state of the last cell
+         # for LSTM `prev_state` is of shape (2 * hidden_state_dim) due to cell state
+         prev_last_hidden_state = prev_state[:, -1, :hidden_state_dim]
+
+         msg = jnp.concatenate(
+             (jax.vmap(send_msg)(prev_last_hidden_state), empty_message)
+         )
+
+         def accumulate_message(link):
+             return jnp.sum(
+                 jnp.where(
+                     jnp.repeat((parent_array == link)[:, None], message_dim, axis=-1),
+                     msg[:-1],
+                     mailbox,
+                 ),
+                 axis=0,
+             )
+
+         mailbox = jax.vmap(accumulate_message)(jnp.arange(N))
+
+         def cell_input(i: int, p: int):
+             local_input = inputs[i]
+             local_cell_input = tree_utils.batch_concat_acme(
+                 (local_input, msg[p], mailbox[i]), num_batch_dims=0
+             )
+             return local_cell_input
+
+         stacked_cell_input = _scan_sys(lam, cell_input)
+
+         def update_state(cell_input, state):
+             cell_output, state = inner_cell(cell_input, state)
+             output = output_transform(send_output(cell_output))
+             return output, state
+
+         y, state = jax.vmap(update_state)(stacked_cell_input, prev_state)
+         return y, state
+
+     return _rnno_cell_apply_fn
+
+
+ def make_ring(
+     lam: list[int],
+     hidden_state_dim: int = 400,
+     message_dim: int = 200,
+     celltype: str = "gru",
+     stack_rnn_cells: int = 2,
+     send_message_n_layers: int = 1,
+     link_output_dim: int = 4,
+     link_output_normalize: bool = True,
+     link_output_transform: Optional[Callable] = None,
+     layernorm: bool = True,
+ ) -> SimpleNamespace:
+
+     if link_output_normalize:
+         assert link_output_transform is None
+         link_output_transform = safe_normalize
+     else:
+         if link_output_transform is None:
+             link_output_transform = lambda x: x
+
+     @hk.without_apply_rng
+     @hk.transform_with_state
+     def forward(X):
+         send_msg = hk.nets.MLP(
+             [hidden_state_dim] * send_message_n_layers + [message_dim]
+         )
+
+         inner_cell = StackedRNNCell(
+             celltype, hidden_state_dim, stack_rnn_cells, layernorm=layernorm
+         )
+         send_output = hk.nets.MLP([hidden_state_dim, link_output_dim])
+         state = hk.get_state(
+             "inner_cell_state",
+             [
+                 len(lam),
+                 stack_rnn_cells,
+                 (hidden_state_dim * 2 if celltype == "lstm" else hidden_state_dim),
+             ],
+             init=jnp.zeros,
+         )
+
+         y, state = hk.dynamic_unroll(
+             _make_rnno_cell_apply_fn(
+                 lam=lam,
+                 inner_cell=inner_cell,
+                 send_msg=send_msg,
+                 send_output=send_output,
+                 hidden_state_dim=hidden_state_dim,
+                 message_dim=message_dim,
+                 output_transform=link_output_transform,
+             ),
+             X,
+             state,
+         )
+         hk.set_state("inner_cell_state", state)
+         return y
+
+     return forward
+
+
+ class StackedRNNCell(hk.Module):
+     def __init__(
+         self,
+         celltype: str,
+         hidden_state_dim,
+         stacks: int,
+         layernorm: bool = False,
+         name: str | None = None,
+     ):
+         super().__init__(name)
+         cell = {"gru": hk.GRU, "lstm": LSTM}[celltype]
+
+         self.cells = [cell(hidden_state_dim) for _ in range(stacks)]
+         self.layernorm = layernorm
+
+     def __call__(self, x, state):
+         output = x
+         next_state = []
+         for i in range(len(self.cells)):
+             output, next_state_i = self.cells[i](output, state[i])
+             next_state.append(next_state_i)
+
+             if self.layernorm:
+                 output = hk.LayerNorm(-1, True, True)(output)
+
+         return output, jnp.stack(next_state)
+
+
+ class LSTM(hk.RNNCore):
+     def __init__(self, hidden_size: int, name=None):
+         super().__init__(name=name)
+         self.hidden_size = hidden_size
+
+     def __call__(
+         self,
+         inputs: jax.Array,
+         prev_state: jax.Array,
+     ):
+         if len(inputs.shape) > 2 or not inputs.shape:
+             raise ValueError("LSTM input must be rank-1 or rank-2.")
+         prev_state_h = prev_state[: self.hidden_size]
+         prev_state_c = prev_state[self.hidden_size :]
+         x_and_h = jnp.concatenate([inputs, prev_state_h], axis=-1)
+         gated = hk.Linear(4 * self.hidden_size)(x_and_h)
+         i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
+         f = jax.nn.sigmoid(f + 1)  # Forget bias, as in sonnet.
+         c = f * prev_state_c + jax.nn.sigmoid(i) * jnp.tanh(g)
+         h = jax.nn.sigmoid(o) * jnp.tanh(c)
+         return h, jnp.concatenate((h, c))
+
+     def initial_state(self, batch_size: int | None):
+         raise NotImplementedError
+
+
+ class RING(ml_base.AbstractFilter):
+     def __init__(self, params=None, lam=None, jit: bool = True, name=None, **kwargs):
+         self.forward_lam_factory = partial(make_ring, **kwargs)
+         self.params = self._load_params(params)
+         self.lam = lam
+         self._name = name
+
+         if jit:
+             self.apply = jax.jit(self.apply, static_argnames="lam")
+
+     def apply(self, X, params=None, state=None, y=None, lam=None):
+         if lam is None:
+             assert self.lam is not None
+             lam = self.lam
+
+         return super().apply(X, params, state, y, tuple(lam))
+
+     def init(self, bs: Optional[int] = None, X=None, lam=None, seed: int = 1):
+         assert X is not None, "Providing `X` via `ringnet.init(X=X)` is required"
+         if bs is not None:
+             assert X.ndim == 4
+
+         if X.ndim == 4:
+             if bs is not None:
+                 assert bs == X.shape[0]
+             else:
+                 bs = X.shape[0]
+             X = X[0]
+
+         # (T, N, F) -> (1, N, F) for faster .init call
+         X = X[0:1]
+
+         if lam is None:
+             assert self.lam is not None
+             lam = self.lam
+
+         key = jax.random.PRNGKey(seed)
+         params, state = self.forward_lam_factory(lam=lam).init(key, X)
+
+         if bs is not None:
+             state = jax.tree_map(lambda arr: jnp.repeat(arr[None], bs, axis=0), state)
+
+         return params, state
+
+     def _apply_batched(self, X, params, state, y, lam):
+         if (params is None and self.params is None) or state is None:
+             _params, _state = self.init(bs=X.shape[0], X=X, lam=lam)
+
+         if params is None and self.params is None:
+             params = _params
+         elif params is None:
+             params = self.params
+         else:
+             pass
+
+         if state is None:
+             state = _state
+
+         yhat, next_state = jax.vmap(
+             self.forward_lam_factory(lam=lam).apply, in_axes=(None, 0, 0)
+         )(params, state, X)
+
+         return yhat, next_state
+
+     @staticmethod
+     def _load_params(params: str | dict | None | Path):
+         assert isinstance(params, (str, dict, type(None), Path))
+         if isinstance(params, (Path, str)):
+             return pickle_load(params)
+         return params
+
+     def nojit(self) -> "RING":
+         ringnet = RING(params=self.params, lam=self.lam, jit=False)
+         ringnet.forward_lam_factory = self.forward_lam_factory
+         return ringnet
+
+     def _pre_save(self, params=None, lam=None) -> None:
+         if params is not None:
+             self.params = params
+         if lam is not None:
+             self.lam = lam
+
+     @staticmethod
+     def _post_load(ringnet: "RING", jit: bool = True) -> "RING":
+         if jit:
+             ringnet.apply = jax.jit(ringnet.apply, static_argnames="lam")
+         return ringnet
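
For orientation, a rough sketch of driving the haiku-transformed network returned by `make_ring` directly; it is not taken from the package docs. The parent array, sequence length, and the 9 input features per link are made-up values, and the reduced hidden/message sizes only keep the example small.

# Hypothetical sketch: unroll the RING graph-RNN over a dummy feature sequence.
import jax
import jax.numpy as jnp

from ring.ml.ringnet import make_ring

lam = [-1, 0, 1]             # parent index per link; -1 marks the root link
T, N, F = 50, len(lam), 9    # timesteps, links, input features per link

net = make_ring(lam, hidden_state_dim=100, message_dim=50)

X = jnp.zeros((T, N, F))
params, state = net.init(jax.random.PRNGKey(0), X)
y, state = net.apply(params, state, X)   # y: (T, N, 4), unit-norm per link and step

The `RING` class above wraps the same factory with parameter loading, jit, and batching; this sketch only exercises the underlying `hk.transform_with_state` object, whose `apply` takes no rng because of `hk.without_apply_rng`.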