imt_ring-1.2.1-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
ring/ml/callbacks.py
ADDED
@@ -0,0 +1,434 @@
from collections import deque
from functools import partial
import os
from pathlib import Path
import time
from typing import Callable, NamedTuple, Optional
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import tree_utils

import ring
from ring.ml import base
from ring.ml import ml_utils
from ring.ml import training_loop
from ring.utils import distribute_batchsize
from ring.utils import expand_batchsize
from ring.utils import merge_batchsize
from ring.utils import parse_path
from ring.utils import pickle_save
import wandb


def _build_eval_fn2(
    eval_metrices: dict[str, Callable],
    filter: base.AbstractFilter,
    X: jax.Array,
    y: jax.Array,
    lam: tuple[int] | None,
    link_names: list[str] | None,
):
    filter = filter.nojit()
    assert X.ndim == 5
    assert y.ndim == 5
    y_4d = merge_batchsize(y, X.shape[0], X.shape[1])

    if link_names is None:
        link_names = ml_utils._unknown_link_names(y.shape[-2])

    @partial(jax.pmap, in_axes=(None, 0, 0))
    def pmap_vmap_apply(params, X, y):
        return filter.apply(X=X, params=params, lam=lam, y=y)[0]

    def eval_fn(params):
        yhat = pmap_vmap_apply(params, X, y)
        yhat = merge_batchsize(yhat, X.shape[0], X.shape[1])

        values = {}
        for metric_name, metric_fn in eval_metrices.items():
            assert (
                metric_name not in values
            ), f"The metric identifier {metric_name} is not unique"
            value = jax.vmap(metric_fn, in_axes=(2, 2))(y_4d, yhat)
            assert value.ndim == 1, f"{value.shape}"
            value = {name: value[i] for i, name in enumerate(link_names)}
            values[metric_name] = value
        return values

    return eval_fn


class EvalXyTrainingLoopCallback(training_loop.TrainingLoopCallback):
    def __init__(
        self,
        filter: base.AbstractFilter,
        eval_metrices: dict[str, Callable],
        X: jax.Array,
        y: jax.Array,
        lam: tuple[int] | None,
        metric_identifier: str,
        eval_every: int = 5,
        link_names: Optional[list[str]] = None,
    ):
        """X, y can be batched or unbatched.

        Args:
            eval_metrices: "(B, T, 1) -> () and links N are vmapped."
        """
        if X.ndim == 3:
            X, y = X[None], y[None]
        B = X.shape[0]
        X, y = expand_batchsize((X, y), *distribute_batchsize(B))
        self.eval_fn = _build_eval_fn2(
            eval_metrices,
            filter,
            X,
            y,
            lam,
            link_names,
        )
        self.eval_every = eval_every
        self.metric_identifier = metric_identifier

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ):
        if self.eval_every == -1:
            return

        if (i_episode % self.eval_every) == 0:
            point_estimates = self.eval_fn(params)
            self.last_metrices = {self.metric_identifier: point_estimates}
        metrices.update(self.last_metrices)


class AverageMetricesTLCB(training_loop.TrainingLoopCallback):
    def __init__(self, metrices_names: list[list[str]], name: str):
        self.zoom_ins = metrices_names
        self.name = name

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        value = 0
        N = 0
        for zoom_in in self.zoom_ins:
            value_zoom_in = _zoom_into_metrices(metrices, zoom_in)

            if np.isnan(value_zoom_in) or np.isinf(value_zoom_in):
                warning = (
                    f"Value of zoom_in={zoom_in} is {value_zoom_in}. "
                    + f"It is not added to the metric {self.name}"
                )
                warnings.warn(warning)
                continue

            value += value_zoom_in
            N += 1

        if N > 0:
            metrices.update({self.name: value / N})


class QueueElement(NamedTuple):
    value: float
    params: dict
    episode: int


class Queue:
    def __init__(self, maxlen: int = 1):
        self._storage: list[QueueElement] = []
        self.maxlen = maxlen

    def __len__(self) -> int:
        return len(self._storage)

    def insert(self, ele: QueueElement) -> None:
        sort = True
        if len(self) < self.maxlen:
            self._storage.append(ele)
        elif ele.value < self._storage[-1].value:
            self._storage[-1] = ele
        else:
            sort = False

        if sort:
            self._storage.sort(key=lambda ele: ele.value)

    def __iter__(self):
        return iter(self._storage)


def _zoom_into_metrices(metrices: dict, zoom_in: list[str]) -> float:
    zoomed_out = metrices
    for key in zoom_in:
        zoomed_out = zoomed_out[key]
    return float(zoomed_out)


class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
    def __init__(
        self,
        path_to_file: str,
        upload: bool = True,
        last_n_params: int = 1,
        track_metrices: Optional[list[list[str]]] = None,
        track_metrices_eval_every: int = 5,
        cleanup: bool = False,
    ):
        self.path_to_file = path_to_file
        self.upload = upload
        self._queue = Queue(maxlen=last_n_params)
        self._loggers = []
        self._track_metrices = track_metrices
        self._value = 0.0
        self._cleanup = cleanup
        self._track_metrices_eval_every = track_metrices_eval_every

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger | ml_utils.MixinLogger],
        opt_state,
    ) -> None:
        if self._track_metrices is None:
            self._value -= 1.0
            value = self._value
        else:
            if (i_episode % self._track_metrices_eval_every) == 0:
                value = 0.0
                N = 0
                for combination in self._track_metrices:
                    value += _zoom_into_metrices(metrices, combination)
                    N += 1
                value /= N
            else:
                # some very large loss such that it doesn't get added because
                # we have already added this parameter set
                value = 1e16

        ele = QueueElement(value, params, i_episode)
        self._queue.insert(ele)

        self._loggers = loggers

    def close(self):
        filenames = []
        for ele in self._queue:
            if len(self._queue) == 1:
                filename = parse_path(self.path_to_file, extension="pickle")
            else:
                value = "{:.2f}".format(ele.value).replace(".", ",")
                filename = parse_path(
                    self.path_to_file + f"_episode={ele.episode}_value={value}",
                    extension="pickle",
                )

            pickle_save(ele.params, filename, overwrite=True)
            if self.upload:
                success = False
                for logger in self._loggers:
                    try:
                        logger.log_params(filename)
                        success = True
                    except NotImplementedError:
                        pass
                if not success:
                    warnings.warn(
                        "Upload of parameters was requested but no `ml_utils.Logger"
                        "` that implements `logger.log_params` was found."
                    )

            filenames.append(filename)

        if self._cleanup:
            # wait for upload
            time.sleep(3)

            for filename in filenames:
                os.system(f"rm {filename}")

            # delete folder
            os.system(f"rmdir {str(Path(filename).parent)}")


class LogGradsTrainingLoopCallBack(training_loop.TrainingLoopCallback):
    def __init__(
        self,
        kill_if_larger: Optional[float] = None,
        consecutive_larger: int = 1,
    ) -> None:
        self.kill_if_larger = kill_if_larger
        self.consecutive_larger = consecutive_larger
        self.last_larger = deque(maxlen=consecutive_larger)

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        gradient_log = {}
        for i, grads_tbp in enumerate(grads):
            grads_flat = tree_utils.batch_concat(grads_tbp, num_batch_dims=0)
            grads_max = jnp.max(jnp.abs(grads_flat))
            grads_norm = jnp.linalg.norm(grads_flat)
            if self.kill_if_larger is not None:
                if grads_norm > self.kill_if_larger:
                    self.last_larger.append(True)
                else:
                    self.last_larger.append(False)
                if all(self.last_larger):
                    training_loop.send_kill_run_signal()
            gradient_log[f"grads_tbp_{i}_max"] = grads_max
            gradient_log[f"grads_tbp_{i}_l2norm"] = grads_norm

        metrices.update(gradient_log)


class NanKillRunCallback(training_loop.TrainingLoopCallback):
    def __init__(
        self,
        print: bool = True,
    ) -> None:
        self.print = print

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        params_fast_flat = tree_utils.batch_concat(params, num_batch_dims=0)
        params_is_nan = jnp.any(jnp.isnan(params_fast_flat))

        if params_is_nan:
            training_loop.send_kill_run_signal()

        if params_is_nan and self.print:
            print(
                f"Parameters have converged to NaN at step {i_episode}. Exiting run."
            )


class LogEpisodeTrainingLoopCallback(training_loop.TrainingLoopCallback):
    def __init__(self, kill_after_episode: Optional[int] = None) -> None:
        self.kill_after_episode = kill_after_episode

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        if self.kill_after_episode is not None and (
            i_episode >= self.kill_after_episode
        ):
            training_loop.send_kill_run_signal()
        metrices.update({"i_episode": i_episode})


class TimingKillRunCallback(training_loop.TrainingLoopCallback):
    def __init__(self, max_run_time_seconds: float) -> None:
        self.max_run_time_seconds = max_run_time_seconds

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        runtime = time.time() - ring._TRAIN_TIMING_START
        if runtime > self.max_run_time_seconds:
            runtime_h = runtime / 3600
            print(f"Run is killed due to timing. Current runtime is {runtime_h}h.")
            training_loop.send_kill_run_signal()


class CheckpointCallback(training_loop.TrainingLoopCallback):
    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state: tree_utils.PyTree,
    ) -> None:
        self.params = params
        self.opt_state = opt_state

    def close(self):
        # only checkpoint if run has been killed
        if training_loop.recv_kill_run_signal():
            path = parse_path(
                "~/.xxy_checkpoints", ml_utils.unique_id(), extension="pickle"
            )
            data = {"params": self.params, "opt_state": self.opt_state}
            pickle_save(
                obj=jax.device_get(data),
                path=path,
                overwrite=True,
            )


class WandbKillRun(training_loop.TrainingLoopCallback):
    def __init__(self, stop_tag: str = "stop"):
        self.stop_tag = stop_tag

    def after_training_step(
        self,
        i_episode: int,
        metrices: dict,
        params: dict,
        grads: list[dict],
        sample_eval: dict,
        loggers: list[ml_utils.Logger],
        opt_state,
    ) -> None:
        if wandb.run is not None:
            tags = (
                wandb.Api(timeout=99)
                .run(path=f"{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}")
                .tags
            )
            if self.stop_tag in tags:
                training_loop.send_kill_run_signal()
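For orientation, here is a minimal sketch of the selection logic implemented by `Queue` above, which `SaveParamsTrainingLoopCallback` uses to retain the `last_n_params` parameter sets with the lowest tracked value. The loss values below are made up for illustration:

from ring.ml.callbacks import Queue, QueueElement

# Queue keeps the `maxlen` elements with the smallest `value`, sorted
# ascending; once full, a new element replaces the worst stored one.
q = Queue(maxlen=2)
for episode, loss in enumerate([3.0, 1.0, 2.0, 0.5]):  # made-up losses
    q.insert(QueueElement(value=loss, params={}, episode=episode))

print([(e.episode, e.value) for e in q])  # [(3, 0.5), (1, 1.0)]

This is also why `SaveParamsTrainingLoopCallback` inserts `value = 1e16` on episodes where the tracked metrics are not re-evaluated: such an element cannot displace an already-stored, better parameter set.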
ring/ml/ml_utils.py
ADDED
@@ -0,0 +1,272 @@
from collections import defaultdict
from functools import partial
import os
from pathlib import Path
import pickle
import random
import time
from typing import Optional, Protocol
import warnings

import jax
import numpy as np
import ring
from ring.utils import import_lib
from tree_utils import PyTree

import wandb

# An arbitrarily nested dictionary with Array leaves; or strings.
NestedDict = PyTree
STEP_METRIC_NAME = "i_episode"


class Logger(Protocol):
    def close(self) -> None: ...  # noqa: E704

    def log(self, metrics: NestedDict) -> None: ...  # noqa: E704

    @staticmethod
    def n_params(params) -> int:
        "Number of parameters in Pytree `params`."
        return sum([arr.flatten().size for arr in jax.tree_util.tree_leaves(params)])


class MixinLogger(Logger):
    def close(self):
        pass

    def log_image(self, path: str, caption: Optional[str] = None):
        raise NotImplementedError

    def log_video(
        self,
        path: str,
        fps: int = 25,
        caption: Optional[str] = None,
        step: Optional[int] = None,
    ):
        raise NotImplementedError

    def log_params(self, path: str):
        raise NotImplementedError

    def log(self, metrics: NestedDict):
        step = metrics[STEP_METRIC_NAME] if STEP_METRIC_NAME in metrics else None
        for key, value in _flatten_convert_filter_nested_dict(metrics).items():
            self.log_key_value(key, value, step=step)

    def log_key_value(self, key: str, value: str | float, step: Optional[int] = None):
        raise NotImplementedError

    def log_command_output(self, command: str):
        path = command.replace(" ", "_") + ".txt"
        os.system(f"{command} >> {path}")
        self.log_txt(path, wait=True)
        os.system(f"rm {path}")

    def log_txt(self, path: str, wait: bool = True):
        raise NotImplementedError

    def _log_environment(self):
        self.log_command_output("pip list")
        self.log_command_output("conda list")
        self.log_command_output("nvidia-smi")


class DictLogger(MixinLogger):
    def __init__(self, output_path: Optional[str] = None):
        self._logs = defaultdict(lambda: [])
        self._output_path = output_path

    def log_key_value(self, key: str, value: str | float, step: int | None = None):
        self._logs[key].append(value)

    def close(self):
        if self._output_path is None:
            return
        self.save(self._output_path)

    def save(self, path: str):
        path = Path(path).with_suffix(".pickle").expanduser()
        # create the parent directory; `path` itself is the pickle file
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            pickle.dump(self.get_logs(), file, protocol=5)

    def get_logs(self):
        return self._logs


class WandbLogger(MixinLogger):
    def __init__(self):
        self._log_environment()
        wandb.run.define_metric(STEP_METRIC_NAME)

    def log_key_value(self, key: str, value: str | float, step: Optional[int] = None):
        data = {key: value}
        if step is not None:
            data.update({STEP_METRIC_NAME: step})
        wandb.log(data)

    def log_params(self, path: str):
        wandb.save(path, policy="now")

    def log_video(
        self,
        path: str,
        fps: int = 25,
        caption: Optional[str] = None,
        step: Optional[int] = None,
    ):
        # TODO >>>
        wandb.save(path, policy="now")
        return
        # <<<
        data = {"video": wandb.Video(path, caption=caption, fps=fps)}
        if step is not None:
            data.update({STEP_METRIC_NAME: step})
        wandb.log(data)

    def log_image(self, path: str, caption: Optional[str] = None):
        # wandb.log({"image": wandb.Image(path, caption=caption)})
        wandb.save(path, policy="now")

    def log_txt(self, path: str, wait: bool = True):
        wandb.save(path, policy="now")
        # TODO: `wandb` is not async at all?
        if wait:
            time.sleep(3)

    def close(self):
        wandb.run.finish()


def _flatten_convert_filter_nested_dict(
    metrices: NestedDict, filter_nan_inf: bool = True
):
    metrices = _flatten_dict(metrices)
    metrices = jax.tree_map(_to_float_if_not_string, metrices)

    if not filter_nan_inf:
        return metrices

    filtered_metrices = {}
    for key, value in metrices.items():
        if not isinstance(value, str) and (np.isnan(value) or np.isinf(value)):
            warning = f"Warning: Value of metric {key} is {value}. We skip it."
            warnings.warn(warning)
            continue
        filtered_metrices[key] = value
    return filtered_metrices


def _flatten_dict(d, parent_key="", sep="_"):
    items = []
    for k, v in d.items():
        k = str(k) if isinstance(k, int) else k
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            items.extend(_flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def _to_float_if_not_string(value):
    if isinstance(value, str):
        return value
    else:
        return float(value)


def on_cluster() -> bool:
    """Return `True` if executed on cluster."""
    env_var = os.environ.get("ON_CLUSTER", None)
    return False if env_var is None else True


def unique_id() -> str:
    return ring._UNIQUE_ID


def save_model_tf(jax_func, path: str, *input, validate: bool = True):
    from jax.experimental import jax2tf

    tf = import_lib("tensorflow", "the function `save_model_tf`")

    def _create_module(jax_func, input):
        signature = jax.tree_map(
            lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
        )

        class RingTFModule(tf.Module):
            def __init__(self, jax_func):
                super().__init__()
                self.tf_func = jax2tf.convert(jax_func, with_gradient=False)

            @partial(
                tf.function,
                autograph=False,
                jit_compile=True,
                input_signature=signature,
            )
            def __call__(self, *args):
                return self.tf_func(*args)

        return RingTFModule(jax_func)

    model = _create_module(jax_func, input)
    tf.saved_model.save(
        model,
        path,
        options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
    )
    if validate:
        output_jax = jax_func(*input)
        output_tf = tf.saved_model.load(path)(*input)
        jax.tree_map(
            lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
            output_jax,
            output_tf,
        )


def train_val_split(
    tps: list[str],
    bs: int,
    n_batches_for_val: int = 1,
    transform_gen=None,
    tree_transform=None,
):
    "Uses the `random` module for shuffling."
    if transform_gen is None:
        transform_gen = lambda gen: gen

    len_val = n_batches_for_val * bs

    _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
    include_samples = list(range(N))
    random.shuffle(include_samples)

    train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
    X_val, y_val = transform_gen(
        ring.RCMG.eager_gen_from_paths(
            tps, len_val, val_data, tree_transform=tree_transform
        )[0]
    )(jax.random.PRNGKey(420))

    generator = transform_gen(
        ring.RCMG.eager_gen_from_paths(
            tps,
            bs,
            train_data,
            load_all_into_memory=True,
            tree_transform=tree_transform,
        )[0]
    )

    return generator, (X_val, y_val)


def _unknown_link_names(N: int):
    return [f"link{i}" for i in range(N)]