imt-ring 1.2.1 (imt_ring-1.2.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
ring/ml/callbacks.py ADDED
@@ -0,0 +1,434 @@
+ from collections import deque
+ from functools import partial
+ import os
+ from pathlib import Path
+ import time
+ from typing import Callable, NamedTuple, Optional
+ import warnings
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import tree_utils
+
+ import ring
+ from ring.ml import base
+ from ring.ml import ml_utils
+ from ring.ml import training_loop
+ from ring.utils import distribute_batchsize
+ from ring.utils import expand_batchsize
+ from ring.utils import merge_batchsize
+ from ring.utils import parse_path
+ from ring.utils import pickle_save
+ import wandb
+
+
+ def _build_eval_fn2(
+     eval_metrices: dict[str, Callable],
+     filter: base.AbstractFilter,
+     X: jax.Array,
+     y: jax.Array,
+     lam: tuple[int] | None,
+     link_names: list[str] | None,
+ ):
+     filter = filter.nojit()
+     assert X.ndim == 5
+     assert y.ndim == 5
+     y_4d = merge_batchsize(y, X.shape[0], X.shape[1])
+
+     if link_names is None:
+         link_names = ml_utils._unknown_link_names(y.shape[-2])
+
+     @partial(jax.pmap, in_axes=(None, 0, 0))
+     def pmap_vmap_apply(params, X, y):
+         return filter.apply(X=X, params=params, lam=lam, y=y)[0]
+
+     def eval_fn(params):
+         yhat = pmap_vmap_apply(params, X, y)
+         yhat = merge_batchsize(yhat, X.shape[0], X.shape[1])
+
+         values = {}
+         for metric_name, metric_fn in eval_metrices.items():
+             assert (
+                 metric_name not in values
+             ), f"The metric identifier {metric_name} is not unique"
+             value = jax.vmap(metric_fn, in_axes=(2, 2))(y_4d, yhat)
+             assert value.ndim == 1, f"{value.shape}"
+             value = {name: value[i] for i, name in enumerate(link_names)}
+             values[metric_name] = value
+         return values
+
+     return eval_fn
+
+
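For orientation, a minimal sketch of a metric compatible with `_build_eval_fn2`: because `jax.vmap(metric_fn, in_axes=(2, 2))` maps over the link axis of the (B, T, N, F) arrays, the metric sees per-link slices and must reduce to a scalar (see the "(B, T, 1) -> ()" note in `EvalXyTrainingLoopCallback` below). The `rmse` name and the placeholder arrays are illustrative, not part of the package:

    import jax.numpy as jnp

    # hypothetical metric: per-link RMSE, reduces a per-link slice to a scalar
    rmse = lambda y, yhat: jnp.sqrt(jnp.mean((y - yhat) ** 2))

    # eval_fn = _build_eval_fn2({"rmse": rmse}, filter, X, y, lam=None, link_names=None)
    # eval_fn(params) -> {"rmse": {"link0": ..., "link1": ..., ...}}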
+ class EvalXyTrainingLoopCallback(training_loop.TrainingLoopCallback):
+     def __init__(
+         self,
+         filter: base.AbstractFilter,
+         eval_metrices: dict[str, Callable],
+         X: jax.Array,
+         y: jax.Array,
+         lam: tuple[int] | None,
+         metric_identifier: str,
+         eval_every: int = 5,
+         link_names: Optional[list[str]] = None,
+     ):
+         """X, y can be batched or unbatched.
+
+         Args:
+             eval_metrices: "(B, T, 1) -> () and links N are vmapped."
+         """
+         if X.ndim == 3:
+             X, y = X[None], y[None]
+         B = X.shape[0]
+         X, y = expand_batchsize((X, y), *distribute_batchsize(B))
+         self.eval_fn = _build_eval_fn2(
+             eval_metrices,
+             filter,
+             X,
+             y,
+             lam,
+             link_names,
+         )
+         self.eval_every = eval_every
+         self.metric_identifier = metric_identifier
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ):
+         if self.eval_every == -1:
+             return
+
+         if (i_episode % self.eval_every) == 0:
+             point_estimates = self.eval_fn(params)
+             self.last_metrices = {self.metric_identifier: point_estimates}
+         metrices.update(self.last_metrices)
+
+
+ class AverageMetricesTLCB(training_loop.TrainingLoopCallback):
+     def __init__(self, metrices_names: list[list[str]], name: str):
+         self.zoom_ins = metrices_names
+         self.name = name
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ) -> None:
+         value = 0
+         N = 0
+         for zoom_in in self.zoom_ins:
+             value_zoom_in = _zoom_into_metrices(metrices, zoom_in)
+
+             if np.isnan(value_zoom_in) or np.isinf(value_zoom_in):
+                 warning = (
+                     f"Value of zoom_in={zoom_in} is {value_zoom_in}. "
+                     + f"It is not added to the metric {self.name}"
+                 )
+                 warnings.warn(warning)
+                 continue
+
+             value += value_zoom_in
+             N += 1
+
+         if N > 0:
+             metrices.update({self.name: value / N})
+
+
+ class QueueElement(NamedTuple):
+     value: float
+     params: dict
+     episode: int
+
+
+ class Queue:
+     def __init__(self, maxlen: int = 1):
+         self._storage: list[QueueElement] = []
+         self.maxlen = maxlen
+
+     def __len__(self) -> int:
+         return len(self._storage)
+
+     def insert(self, ele: QueueElement) -> None:
+         sort = True
+         if len(self) < self.maxlen:
+             self._storage.append(ele)
+         elif ele.value < self._storage[-1].value:
+             self._storage[-1] = ele
+         else:
+             sort = False
+
+         if sort:
+             self._storage.sort(key=lambda ele: ele.value)
+
+     def __iter__(self):
+         return iter(self._storage)
+
+
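`Queue` keeps the `maxlen` lowest-value elements seen so far, sorted ascending; when full, a new element replaces the current worst only if it is better. A quick illustrative walk-through (values and episodes are made up):

    queue = Queue(maxlen=2)
    for value, episode in [(3.0, 0), (1.0, 1), (2.0, 2)]:
        queue.insert(QueueElement(value, params={}, episode=episode))

    # the worst element (3.0) was evicted when (2.0) arrived
    print([ele.value for ele in queue])  # [1.0, 2.0]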
+ def _zoom_into_metrices(metrices: dict, zoom_in: list[str]) -> float:
+     zoomed_out = metrices
+     for key in zoom_in:
+         zoomed_out = zoomed_out[key]
+     return float(zoomed_out)
+
+
+ class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
+     def __init__(
+         self,
+         path_to_file: str,
+         upload: bool = True,
+         last_n_params: int = 1,
+         track_metrices: Optional[list[list[str]]] = None,
+         track_metrices_eval_every: int = 5,
+         cleanup: bool = False,
+     ):
+         self.path_to_file = path_to_file
+         self.upload = upload
+         self._queue = Queue(maxlen=last_n_params)
+         self._loggers = []
+         self._track_metrices = track_metrices
+         self._value = 0.0
+         self._cleanup = cleanup
+         self._track_metrices_eval_every = track_metrices_eval_every
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger | ml_utils.MixinLogger],
+         opt_state,
+     ) -> None:
+         if self._track_metrices is None:
+             self._value -= 1.0
+             value = self._value
+         else:
+             if (i_episode % self._track_metrices_eval_every) == 0:
+                 value = 0.0
+                 N = 0
+                 for combination in self._track_metrices:
+                     value += _zoom_into_metrices(metrices, combination)
+                     N += 1
+                 value /= N
+             else:
+                 # some very large loss such that it doesn't get added because
+                 # we have already added this parameter set
+                 value = 1e16
+
+         ele = QueueElement(value, params, i_episode)
+         self._queue.insert(ele)
+
+         self._loggers = loggers
+
+     def close(self):
+         filenames = []
+         for ele in self._queue:
+             if len(self._queue) == 1:
+                 filename = parse_path(self.path_to_file, extension="pickle")
+             else:
+                 value = "{:.2f}".format(ele.value).replace(".", ",")
+                 filename = parse_path(
+                     self.path_to_file + f"_episode={ele.episode}_value={value}",
+                     extension="pickle",
+                 )
+
+             pickle_save(ele.params, filename, overwrite=True)
+             if self.upload:
+                 success = False
+                 for logger in self._loggers:
+                     try:
+                         logger.log_params(filename)
+                         success = True
+                     except NotImplementedError:
+                         pass
+                 if not success:
+                     warnings.warn(
+                         "Upload of parameters was requested but no `ml_utils.Logger"
+                         "` that implements `logger.log_params` was found."
+                     )
+
+             filenames.append(filename)
+
+         if self._cleanup:
+             # wait for upload
+             time.sleep(3)
+
+             for filename in filenames:
+                 os.system(f"rm {filename}")
+
+             # delete folder
+             os.system(f"rmdir {str(Path(filename).parent)}")
+
+
+ class LogGradsTrainingLoopCallBack(training_loop.TrainingLoopCallback):
+     def __init__(
+         self,
+         kill_if_larger: Optional[float] = None,
+         consecutive_larger: int = 1,
+     ) -> None:
+         self.kill_if_larger = kill_if_larger
+         self.consecutive_larger = consecutive_larger
+         self.last_larger = deque(maxlen=consecutive_larger)
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ) -> None:
+         gradient_log = {}
+         for i, grads_tbp in enumerate(grads):
+             grads_flat = tree_utils.batch_concat(grads_tbp, num_batch_dims=0)
+             grads_max = jnp.max(jnp.abs(grads_flat))
+             grads_norm = jnp.linalg.norm(grads_flat)
+             if self.kill_if_larger is not None:
+                 if grads_norm > self.kill_if_larger:
+                     self.last_larger.append(True)
+                 else:
+                     self.last_larger.append(False)
+                 # kill only after `consecutive_larger` consecutive violations;
+                 # `all` alone would trigger on a partially filled deque
+                 if len(self.last_larger) == self.consecutive_larger and all(
+                     self.last_larger
+                 ):
+                     training_loop.send_kill_run_signal()
+             gradient_log[f"grads_tbp_{i}_max"] = grads_max
+             gradient_log[f"grads_tbp_{i}_l2norm"] = grads_norm
+
+         metrices.update(gradient_log)
+
+
+ class NanKillRunCallback(training_loop.TrainingLoopCallback):
+     def __init__(
+         self,
+         print: bool = True,
+     ) -> None:
+         self.print = print
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ) -> None:
+         params_fast_flat = tree_utils.batch_concat(params, num_batch_dims=0)
+         params_is_nan = jnp.any(jnp.isnan(params_fast_flat))
+
+         if params_is_nan:
+             training_loop.send_kill_run_signal()
+
+         if params_is_nan and self.print:
+             print(
+                 f"Parameters have converged to NaN at step {i_episode}. Exiting run.."
+             )
+
+
+ class LogEpisodeTrainingLoopCallback(training_loop.TrainingLoopCallback):
+     def __init__(self, kill_after_episode: Optional[int] = None) -> None:
+         self.kill_after_episode = kill_after_episode
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ) -> None:
+         if self.kill_after_episode is not None and (
+             i_episode >= self.kill_after_episode
+         ):
+             training_loop.send_kill_run_signal()
+         metrices.update({"i_episode": i_episode})
+
+
+ class TimingKillRunCallback(training_loop.TrainingLoopCallback):
+     def __init__(self, max_run_time_seconds: float) -> None:
+         self.max_run_time_seconds = max_run_time_seconds
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ) -> None:
+         runtime = time.time() - ring._TRAIN_TIMING_START
+         if runtime > self.max_run_time_seconds:
+             runtime_h = runtime / 3600
+             print(
+                 f"Run is killed due to timing. Current runtime is {runtime_h:.2f}h."
+             )
+             training_loop.send_kill_run_signal()
+
+
+ class CheckpointCallback(training_loop.TrainingLoopCallback):
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state: tree_utils.PyTree,
+     ) -> None:
+         self.params = params
+         self.opt_state = opt_state
+
+     def close(self):
+         # only checkpoint if run has been killed
+         if training_loop.recv_kill_run_signal():
+             path = parse_path(
+                 "~/.xxy_checkpoints", ml_utils.unique_id(), extension="pickle"
+             )
+             data = {"params": self.params, "opt_state": self.opt_state}
+             pickle_save(
+                 obj=jax.device_get(data),
+                 path=path,
+                 overwrite=True,
+             )
+
+
+ class WandbKillRun(training_loop.TrainingLoopCallback):
+     def __init__(self, stop_tag: str = "stop"):
+         self.stop_tag = stop_tag
+
+     def after_training_step(
+         self,
+         i_episode: int,
+         metrices: dict,
+         params: dict,
+         grads: list[dict],
+         sample_eval: dict,
+         loggers: list[ml_utils.Logger],
+         opt_state,
+     ) -> None:
+         if wandb.run is not None:
+             tags = (
+                 wandb.Api(timeout=99)
+                 .run(path=f"{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}")
+                 .tags
+             )
+             if self.stop_tag in tags:
+                 training_loop.send_kill_run_signal()
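Taken together, these callbacks are meant to be handed to the training loop defined in ring/ml/training_loop.py (file 64 in the list above). A purely hypothetical composition, since the loop's constructor is not shown in this diff:

    # hypothetical wiring; the callback order and the loop API are assumptions
    callbacks = [
        NanKillRunCallback(),                                   # abort on NaN parameters
        TimingKillRunCallback(max_run_time_seconds=24 * 3600),  # wall-clock budget
        SaveParamsTrainingLoopCallback("~/params/run.pickle", last_n_params=3),
        CheckpointCallback(),                                   # checkpoint on kill signal
    ]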
ring/ml/ml_utils.py ADDED
@@ -0,0 +1,272 @@
+ from collections import defaultdict
+ from functools import partial
+ import os
+ from pathlib import Path
+ import pickle
+ import random
+ import time
+ from typing import Optional, Protocol
+ import warnings
+
+ import jax
+ import numpy as np
+ import ring
+ from ring.utils import import_lib
+ from tree_utils import PyTree
+
+ import wandb
+
+ # An arbitrarily nested dictionary with Array leaves, or strings
+ NestedDict = PyTree
+ STEP_METRIC_NAME = "i_episode"
+
+
+ class Logger(Protocol):
+     def close(self) -> None: ...  # noqa: E704
+
+     def log(self, metrics: NestedDict) -> None: ...  # noqa: E704
+
+     @staticmethod
+     def n_params(params) -> int:
+         "Number of parameters in Pytree `params`."
+         return sum([arr.flatten().size for arr in jax.tree_util.tree_leaves(params)])
+
+
+ class MixinLogger(Logger):
+     def close(self):
+         pass
+
+     def log_image(self, path: str, caption: Optional[str] = None):
+         raise NotImplementedError
+
+     def log_video(
+         self,
+         path: str,
+         fps: int = 25,
+         caption: Optional[str] = None,
+         step: Optional[int] = None,
+     ):
+         raise NotImplementedError
+
+     def log_params(self, path: str):
+         raise NotImplementedError
+
+     def log(self, metrics: NestedDict):
+         step = metrics[STEP_METRIC_NAME] if STEP_METRIC_NAME in metrics else None
+         for key, value in _flatten_convert_filter_nested_dict(metrics).items():
+             self.log_key_value(key, value, step=step)
+
+     def log_key_value(self, key: str, value: str | float, step: Optional[int] = None):
+         raise NotImplementedError
+
+     def log_command_output(self, command: str):
+         path = command.replace(" ", "_") + ".txt"
+         os.system(f"{command} >> {path}")
+         self.log_txt(path, wait=True)
+         os.system(f"rm {path}")
+
+     def log_txt(self, path: str, wait: bool = True):
+         raise NotImplementedError
+
+     def _log_environment(self):
+         self.log_command_output("pip list")
+         self.log_command_output("conda list")
+         self.log_command_output("nvidia-smi")
+
+
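A concrete logger only needs `log_key_value`; `MixinLogger.log` flattens the nested metrics dict and forwards each leaf. A minimal hypothetical subclass (the `PrintLogger` name is ours, not part of the package):

    class PrintLogger(MixinLogger):
        def log_key_value(self, key, value, step=None):
            print(f"[step={step}] {key} = {value}")

    PrintLogger().log({"i_episode": 3, "loss": {"train": 0.12}})
    # [step=3] i_episode = 3.0
    # [step=3] loss_train = 0.12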
+ class DictLogger(MixinLogger):
+     def __init__(self, output_path: Optional[str] = None):
+         self._logs = defaultdict(lambda: [])
+         self._output_path = output_path
+
+     def log_key_value(self, key: str, value: str | float, step: int | None = None):
+         self._logs[key].append(value)
+
+     def close(self):
+         if self._output_path is None:
+             return
+         self.save(self._output_path)
+
+     def save(self, path: str):
+         path = Path(path).with_suffix(".pickle").expanduser()
+         # create the parent folder, not the file itself
+         path.parent.mkdir(parents=True, exist_ok=True)
+         with open(path, "wb") as file:
+             pickle.dump(self.get_logs(), file, protocol=5)
+
+     def get_logs(self):
+         return self._logs
+
+
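`DictLogger` simply accumulates every logged leaf in memory, keyed by the flattened metric name; for example (the path is a placeholder):

    logger = DictLogger(output_path="~/logs/run1")
    logger.log({"i_episode": 0, "loss": 1.5})
    logger.log({"i_episode": 1, "loss": 1.2})
    logger.close()  # pickles {"i_episode": [0.0, 1.0], "loss": [1.5, 1.2]}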
+ class WandbLogger(MixinLogger):
+     def __init__(self):
+         self._log_environment()
+         wandb.run.define_metric(STEP_METRIC_NAME)
+
+     def log_key_value(self, key: str, value: str | float, step: Optional[int] = None):
+         data = {key: value}
+         if step is not None:
+             data.update({STEP_METRIC_NAME: step})
+         wandb.log(data)
+
+     def log_params(self, path: str):
+         wandb.save(path, policy="now")
+
+     def log_video(
+         self,
+         path: str,
+         fps: int = 25,
+         caption: Optional[str] = None,
+         step: Optional[int] = None,
+     ):
+         # TODO >>>
+         wandb.save(path, policy="now")
+         return
+         # <<<
+         data = {"video": wandb.Video(path, caption=caption, fps=fps)}
+         if step is not None:
+             data.update({STEP_METRIC_NAME: step})
+         wandb.log(data)
+
+     def log_image(self, path: str, caption: Optional[str] = None):
+         # wandb.log({"image": wandb.Image(path, caption=caption)})
+         wandb.save(path, policy="now")
+
+     def log_txt(self, path: str, wait: bool = True):
+         wandb.save(path, policy="now")
+         # TODO: `wandb` is not async at all?
+         if wait:
+             time.sleep(3)
+
+     def close(self):
+         wandb.run.finish()
+
+
+ def _flatten_convert_filter_nested_dict(
+     metrices: NestedDict, filter_nan_inf: bool = True
+ ):
+     metrices = _flatten_dict(metrices)
+     metrices = jax.tree_map(_to_float_if_not_string, metrices)
+
+     if not filter_nan_inf:
+         return metrices
+
+     filtered_metrices = {}
+     for key, value in metrices.items():
+         if not isinstance(value, str) and (np.isnan(value) or np.isinf(value)):
+             warning = f"Warning: Value of metric {key} is {value}. We skip it."
+             warnings.warn(warning)
+             continue
+         filtered_metrices[key] = value
+     return filtered_metrices
+
+
+ def _flatten_dict(d, parent_key="", sep="_"):
+     items = []
+     for k, v in d.items():
+         k = str(k) if isinstance(k, int) else k
+         new_key = parent_key + sep + k if parent_key else k
+         if isinstance(v, dict):
+             items.extend(_flatten_dict(v, new_key, sep=sep).items())
+         else:
+             items.append((new_key, v))
+     return dict(items)
+
+
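`_flatten_dict` joins nested keys with `sep`, which is what produces flattened metric names such as `rmse_link0`; a quick example of the expected behaviour:

    _flatten_dict({"rmse": {"link0": 0.1, "link1": 0.2}, "lr": 1e-3})
    # -> {"rmse_link0": 0.1, "rmse_link1": 0.2, "lr": 0.001}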
+ def _to_float_if_not_string(value):
+     if isinstance(value, str):
+         return value
+     else:
+         return float(value)
+
+
+ def on_cluster() -> bool:
+     """Return `True` if executed on cluster."""
+     return os.environ.get("ON_CLUSTER", None) is not None
+
+
+ def unique_id() -> str:
+     return ring._UNIQUE_ID
+
+
+ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
+     from jax.experimental import jax2tf
+
+     tf = import_lib("tensorflow", "the function `save_model_tf`")
+
+     def _create_module(jax_func, input):
+         signature = jax.tree_map(
+             lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
+         )
+
+         class RingTFModule(tf.Module):
+             def __init__(self, jax_func):
+                 super().__init__()
+                 self.tf_func = jax2tf.convert(jax_func, with_gradient=False)
+
+             @partial(
+                 tf.function,
+                 autograph=False,
+                 jit_compile=True,
+                 input_signature=signature,
+             )
+             def __call__(self, *args):
+                 return self.tf_func(*args)
+
+         return RingTFModule(jax_func)
+
+     model = _create_module(jax_func, input)
+     tf.saved_model.save(
+         model,
+         path,
+         options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
+     )
+     if validate:
+         output_jax = jax_func(*input)
+         output_tf = tf.saved_model.load(path)(*input)
+         # raise if the restored model does not reproduce the JAX outputs
+         jax.tree_map(
+             lambda a1, a2: np.testing.assert_allclose(a1, a2, atol=1e-5, rtol=1e-5),
+             output_jax,
+             output_tf,
+         )
+
+
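A hypothetical call, assuming TensorFlow is installed; the function and path are placeholders:

    import numpy as np

    square = lambda x: x ** 2
    save_model_tf(square, "/tmp/ring_square_tf", np.ones((1, 3), dtype=np.float32))
    # the SavedModel at /tmp/ring_square_tf can then be served from pure TensorFlow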
+ def train_val_split(
+     tps: list[str],
+     bs: int,
+     n_batches_for_val: int = 1,
+     transform_gen=None,
+     tree_transform=None,
+ ):
+     "Uses `random` module for shuffling."
+     if transform_gen is None:
+         transform_gen = lambda gen: gen
+
+     len_val = n_batches_for_val * bs
+
+     _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
+     include_samples = list(range(N))
+     random.shuffle(include_samples)
+
+     train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
+     X_val, y_val = transform_gen(
+         ring.RCMG.eager_gen_from_paths(
+             tps, len_val, val_data, tree_transform=tree_transform
+         )[0]
+     )(jax.random.PRNGKey(420))
+
+     generator = transform_gen(
+         ring.RCMG.eager_gen_from_paths(
+             tps,
+             bs,
+             train_data,
+             load_all_into_memory=True,
+             tree_transform=tree_transform,
+         )[0]
+     )
+
+     return generator, (X_val, y_val)
+
+
+ def _unknown_link_names(N: int):
+     return [f"link{i}" for i in range(N)]