imt-ring 1.2.1 (py3-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
 - imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
    
ring/ml/callbacks.py
ADDED

@@ -0,0 +1,434 @@
+from collections import deque
+from functools import partial
+import os
+from pathlib import Path
+import time
+from typing import Callable, NamedTuple, Optional
+import warnings
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import tree_utils
+
+import ring
+from ring.ml import base
+from ring.ml import ml_utils
+from ring.ml import training_loop
+from ring.utils import distribute_batchsize
+from ring.utils import expand_batchsize
+from ring.utils import merge_batchsize
+from ring.utils import parse_path
+from ring.utils import pickle_save
+import wandb
+
+
+def _build_eval_fn2(
+    eval_metrices: dict[str, Callable],
+    filter: base.AbstractFilter,
+    X: jax.Array,
+    y: jax.Array,
+    lam: tuple[int] | None,
+    link_names: list[str] | None,
+):
+    filter = filter.nojit()
+    assert X.ndim == 5
+    assert y.ndim == 5
+    y_4d = merge_batchsize(y, X.shape[0], X.shape[1])
+
+    if link_names is None:
+        link_names = ml_utils._unknown_link_names(y.shape[-2])
+
+    @partial(jax.pmap, in_axes=(None, 0, 0))
+    def pmap_vmap_apply(params, X, y):
+        return filter.apply(X=X, params=params, lam=lam, y=y)[0]
+
+    def eval_fn(params):
+        yhat = pmap_vmap_apply(params, X, y)
+        yhat = merge_batchsize(yhat, X.shape[0], X.shape[1])
+
+        values = {}
+        for metric_name, metric_fn in eval_metrices.items():
+            assert (
+                metric_name not in values
+            ), f"The metric identitifier {metric_name} is not unique"
+            value = jax.vmap(metric_fn, in_axes=(2, 2))(y_4d, yhat)
+            assert value.ndim == 1, f"{value.shape}"
+            value = {name: value[i] for i, name in enumerate(link_names)}
+            values[metric_name] = value
+        return values
+
+    return eval_fn
+
+
+class EvalXyTrainingLoopCallback(training_loop.TrainingLoopCallback):
+    def __init__(
+        self,
+        filter: base.AbstractFilter,
+        eval_metrices: dict[str, Callable],
+        X: jax.Array,
+        y: jax.Array,
+        lam: tuple[int] | None,
+        metric_identifier: str,
+        eval_every: int = 5,
+        link_names: Optional[list[str]] = None,
+    ):
+        """X, y can be batched or unbatched.
+        Args:
+            eval_metrices: "(B, T, 1) -> () and links N are vmapped."
+        """
+        if X.ndim == 3:
+            X, y = X[None], y[None]
+        B = X.shape[0]
+        X, y = expand_batchsize((X, y), *distribute_batchsize(B))
+        self.eval_fn = _build_eval_fn2(
+            eval_metrices,
+            filter,
+            X,
+            y,
+            lam,
+            link_names,
+        )
+        self.eval_every = eval_every
+        self.metric_identifier = metric_identifier
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ):
+        if self.eval_every == -1:
+            return
+
+        if (i_episode % self.eval_every) == 0:
+            point_estimates = self.eval_fn(params)
+            self.last_metrices = {self.metric_identifier: point_estimates}
+        metrices.update(self.last_metrices)
+
+
+class AverageMetricesTLCB(training_loop.TrainingLoopCallback):
+    def __init__(self, metrices_names: list[list[str]], name: str):
+        self.zoom_ins = metrices_names
+        self.name = name
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ) -> None:
+        value = 0
+        N = 0
+        for zoom_in in self.zoom_ins:
+            value_zoom_in = _zoom_into_metrices(metrices, zoom_in)
+
+            if np.isnan(value_zoom_in) or np.isinf(value_zoom_in):
+                warning = (
+                    f"Value of zoom_in={zoom_in} is {value_zoom_in}. "
+                    + f"It is not added to the metric {self.name}"
+                )
+                warnings.warn(warning)
+                continue
+
+            value += value_zoom_in
+            N += 1
+
+        if N > 0:
+            metrices.update({self.name: value / N})
+
+
+class QueueElement(NamedTuple):
+    value: float
+    params: dict
+    episode: int
+
+
+class Queue:
+    def __init__(self, maxlen: int = 1):
+        self._storage: list[QueueElement] = []
+        self.maxlen = maxlen
+
+    def __len__(self) -> int:
+        return len(self._storage)
+
+    def insert(self, ele: QueueElement) -> None:
+        sort = True
+        if len(self) < self.maxlen:
+            self._storage.append(ele)
+        elif ele.value < self._storage[-1].value:
+            self._storage[-1] = ele
+        else:
+            sort = False
+
+        if sort:
+            self._storage.sort(key=lambda ele: ele.value)
+
+    def __iter__(self):
+        return iter(self._storage)
+
+
+def _zoom_into_metrices(metrices: dict, zoom_in: list[str]) -> float:
+    zoomed_out = metrices
+    for key in zoom_in:
+        zoomed_out = zoomed_out[key]
+    return float(zoomed_out)
+
+
+class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
+    def __init__(
+        self,
+        path_to_file: str,
+        upload: bool = True,
+        last_n_params: int = 1,
+        track_metrices: Optional[list[list[str]]] = None,
+        track_metrices_eval_every: int = 5,
+        cleanup: bool = False,
+    ):
+        self.path_to_file = path_to_file
+        self.upload = upload
+        self._queue = Queue(maxlen=last_n_params)
+        self._loggers = []
+        self._track_metrices = track_metrices
+        self._value = 0.0
+        self._cleanup = cleanup
+        self._track_metrices_eval_every = track_metrices_eval_every
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger | ml_utils.MixinLogger],
+        opt_state,
+    ) -> None:
+        if self._track_metrices is None:
+            self._value -= 1.0
+            value = self._value
+        else:
+            if (i_episode % self._track_metrices_eval_every) == 0:
+                value = 0.0
+                N = 0
+                for combination in self._track_metrices:
+                    value += _zoom_into_metrices(metrices, combination)
+                    N += 1
+                value /= N
+            else:
+                # some very large loss such that it doesn't get added because
+                # we have already added this parameter set
+                value = 1e16
+
+        ele = QueueElement(value, params, i_episode)
+        self._queue.insert(ele)
+
+        self._loggers = loggers
+
+    def close(self):
+        filenames = []
+        for ele in self._queue:
+            if len(self._queue) == 1:
+                filename = parse_path(self.path_to_file, extension="pickle")
+            else:
+                value = "{:.2f}".format(ele.value).replace(".", ",")
+                filename = parse_path(
+                    self.path_to_file + f"_episode={ele.episode}_value={value}",
+                    extension="pickle",
+                )
+
+            pickle_save(ele.params, filename, overwrite=True)
+            if self.upload:
+                success = False
+                for logger in self._loggers:
+                    try:
+                        logger.log_params(filename)
+                        success = True
+                    except NotImplementedError:
+                        pass
+                    if not success:
+                        warnings.warn(
+                            "Upload of parameters was requested but no `ml_utils.Logger"
+                            "` that implements `logger.log_params` was found."
+                        )
+
+            filenames.append(filename)
+
+        if self._cleanup:
+            # wait for upload
+            time.sleep(3)
+
+            for filename in filenames:
+                os.system(f"rm {filename}")
+
+            # delete folder
+            os.system(f"rmdir {str(Path(filename).parent)}")
+
+
+class LogGradsTrainingLoopCallBack(training_loop.TrainingLoopCallback):
+    def __init__(
+        self,
+        kill_if_larger: Optional[float] = None,
+        consecutive_larger: int = 1,
+    ) -> None:
+        self.kill_if_larger = kill_if_larger
+        self.consecutive_larger = consecutive_larger
+        self.last_larger = deque(maxlen=consecutive_larger)
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ) -> None:
+        gradient_log = {}
+        for i, grads_tbp in enumerate(grads):
+            grads_flat = tree_utils.batch_concat(grads_tbp, num_batch_dims=0)
+            grads_max = jnp.max(jnp.abs(grads_flat))
+            grads_norm = jnp.linalg.norm(grads_flat)
+            if self.kill_if_larger is not None:
+                if grads_norm > self.kill_if_larger:
+                    self.last_larger.append(True)
+                else:
+                    self.last_larger.append(False)
+                if all(self.last_larger):
+                    training_loop.send_kill_run_signal()
+            gradient_log[f"grads_tbp_{i}_max"] = grads_max
+            gradient_log[f"grads_tbp_{i}_l2norm"] = grads_norm
+
+        metrices.update(gradient_log)
+
+
+class NanKillRunCallback(training_loop.TrainingLoopCallback):
+    def __init__(
+        self,
+        print: bool = True,
+    ) -> None:
+        self.print = print
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ) -> None:
+        params_fast_flat = tree_utils.batch_concat(params, num_batch_dims=0)
+        params_is_nan = jnp.any(jnp.isnan(params_fast_flat))
+
+        if params_is_nan:
+            training_loop.send_kill_run_signal()
+
+        if params_is_nan and self.print:
+            print(
+                f"Parameters have converged to NaN at step {i_episode}. Exiting run.."
+            )
+
+
+class LogEpisodeTrainingLoopCallback(training_loop.TrainingLoopCallback):
+    def __init__(self, kill_after_episode: Optional[int] = None) -> None:
+        self.kill_after_episode = kill_after_episode
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ) -> None:
+        if self.kill_after_episode is not None and (
+            i_episode >= self.kill_after_episode
+        ):
+            training_loop.send_kill_run_signal()
+        metrices.update({"i_episode": i_episode})
+
+
+class TimingKillRunCallback(training_loop.TrainingLoopCallback):
+    def __init__(self, max_run_time_seconds: float) -> None:
+        self.max_run_time_seconds = max_run_time_seconds
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ) -> None:
+        runtime = time.time() - ring._TRAIN_TIMING_START
+        if runtime > self.max_run_time_seconds:
+            runtime_h = runtime / 3600
+            print(f"Run is killed due to timing. Current runtime is {runtime_h}h.")
+            training_loop.send_kill_run_signal()
+
+
+class CheckpointCallback(training_loop.TrainingLoopCallback):
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state: tree_utils.PyTree,
+    ) -> None:
+        self.params = params
+        self.opt_state = opt_state
+
+    def close(self):
+        # only checkpoint if run has been killed
+        if training_loop.recv_kill_run_signal():
+            path = parse_path(
+                "~/.xxy_checkpoints", ml_utils.unique_id(), extension="pickle"
+            )
+            data = {"params": self.params, "opt_state": self.opt_state}
+            pickle_save(
+                obj=jax.device_get(data),
+                path=path,
+                overwrite=True,
+            )
+
+
+class WandbKillRun(training_loop.TrainingLoopCallback):
+    def __init__(self, stop_tag: str = "stop"):
+        self.stop_tag = stop_tag
+
+    def after_training_step(
+        self,
+        i_episode: int,
+        metrices: dict,
+        params: dict,
+        grads: list[dict],
+        sample_eval: dict,
+        loggers: list[ml_utils.Logger],
+        opt_state,
+    ) -> None:
+        if wandb.run is not None:
+            tags = (
+                wandb.Api(timeout=99)
+                .run(path=f"{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}")
+                .tags
+            )
+            if self.stop_tag in tags:
+                training_loop.send_kill_run_signal()
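For orientation, a brief usage sketch (illustrative only, not part of the package contents above): the callbacks all implement the same after_training_step hook, and SaveParamsTrainingLoopCallback retains its best parameter sets in the small sorted Queue defined in this file. Assuming the wheel is installed and importable as ring, that bookkeeping behaves as follows:

    from ring.ml.callbacks import Queue, QueueElement

    # Keep the two lowest-value (i.e. best) parameter sets seen so far.
    queue = Queue(maxlen=2)
    for episode, loss in enumerate([0.9, 0.5, 0.7, 0.3]):
        queue.insert(QueueElement(value=loss, params={"step": episode}, episode=episode))

    # Queue.insert keeps its storage sorted by ascending value, so iteration
    # yields the best entries first.
    print([ele.value for ele in queue])  # -> [0.3, 0.5]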
    
        ring/ml/ml_utils.py
    ADDED
    
    | 
         @@ -0,0 +1,272 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from collections import defaultdict
         
     | 
| 
      
 2 
     | 
    
         
            +
            from functools import partial
         
     | 
| 
      
 3 
     | 
    
         
            +
            import os
         
     | 
| 
      
 4 
     | 
    
         
            +
            from pathlib import Path
         
     | 
| 
      
 5 
     | 
    
         
            +
            import pickle
         
     | 
| 
      
 6 
     | 
    
         
            +
            import random
         
     | 
| 
      
 7 
     | 
    
         
            +
            import time
         
     | 
| 
      
 8 
     | 
    
         
            +
            from typing import Optional, Protocol
         
     | 
| 
      
 9 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 12 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 13 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 14 
     | 
    
         
            +
            from ring.utils import import_lib
         
     | 
| 
      
 15 
     | 
    
         
            +
            from tree_utils import PyTree
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            import wandb
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            # An arbitrarily nested dictionary with Array leaves; Or strings
         
     | 
| 
      
 20 
     | 
    
         
            +
            NestedDict = PyTree
         
     | 
| 
      
 21 
     | 
    
         
            +
            STEP_METRIC_NAME = "i_episode"
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
            class Logger(Protocol):
         
     | 
| 
      
 25 
     | 
    
         
            +
                def close(self) -> None: ...  # noqa: E704
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
                def log(self, metrics: NestedDict) -> None: ...  # noqa: E704
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 30 
     | 
    
         
            +
                def n_params(params) -> int:
         
     | 
| 
      
 31 
     | 
    
         
            +
                    "Number of parameters in Pytree `params`."
         
     | 
| 
      
 32 
     | 
    
         
            +
                    return sum([arr.flatten().size for arr in jax.tree_util.tree_leaves(params)])
         
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
      
 35 
     | 
    
         
+class MixinLogger(Logger):
+    def close(self):
+        pass
+
+    def log_image(self, path: str, caption: Optional[str] = None):
+        raise NotImplementedError
+
+    def log_video(
+        self,
+        path: str,
+        fps: int = 25,
+        caption: Optional[str] = None,
+        step: Optional[int] = None,
+    ):
+        raise NotImplementedError
+
+    def log_params(self, path: str):
+        raise NotImplementedError
+
+    def log(self, metrics: NestedDict):
+        step = metrics[STEP_METRIC_NAME] if STEP_METRIC_NAME in metrics else None
+        for key, value in _flatten_convert_filter_nested_dict(metrics).items():
+            self.log_key_value(key, value, step=step)
+
+    def log_key_value(self, key: str, value: str | float, step: Optional[int] = None):
+        raise NotImplementedError
+
+    def log_command_output(self, command: str):
+        path = command.replace(" ", "_") + ".txt"
+        os.system(f"{command} >> {path}")
+        self.log_txt(path, wait=True)
+        os.system(f"rm {path}")
+
+    def log_txt(self, path: str, wait: bool = True):
+        raise NotImplementedError
+
+    def _log_environment(self):
+        self.log_command_output("pip list")
+        self.log_command_output("conda list")
+        self.log_command_output("nvidia-smi")
+
+
+class DictLogger(MixinLogger):
+    def __init__(self, output_path: Optional[str] = None):
+        self._logs = defaultdict(lambda: [])
+        self._output_path = output_path
+
+    def log_key_value(self, key: str, value: str | float, step: int | None = None):
+        self._logs[key].append(value)
+
+    def close(self):
+        if self._output_path is None:
+            return
+        self.save(self._output_path)
+
+    def save(self, path: str):
+        path = Path(path).with_suffix(".pickle").expanduser()
+        # create the parent directory so the pickle file can be written below
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with open(path, "wb") as file:
+            pickle.dump(self.get_logs(), file, protocol=5)
+
+    def get_logs(self):
+        return self._logs
+
+
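A brief, illustrative sketch of how `DictLogger` might be driven (the metric names and values are made up):

    logger = DictLogger(output_path="~/logs/run1")
    logger.log({"train": {"mae_deg": 12.3}, "i_episode": 0})
    logger.log({"train": {"mae_deg": 9.8}, "i_episode": 1})
    logger.get_logs()  # -> {"train_mae_deg": [12.3, 9.8], "i_episode": [0.0, 1.0]}
    logger.close()     # writes ~/logs/run1.pickle because `output_path` was given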
    
         
+class WandbLogger(MixinLogger):
+    def __init__(self):
+        self._log_environment()
+        wandb.run.define_metric(STEP_METRIC_NAME)
+
+    def log_key_value(self, key: str, value: str | float, step: Optional[int] = None):
+        data = {key: value}
+        if step is not None:
+            data.update({STEP_METRIC_NAME: step})
+        wandb.log(data)
+
+    def log_params(self, path: str):
+        wandb.save(path, policy="now")
+
+    def log_video(
+        self,
+        path: str,
+        fps: int = 25,
+        caption: Optional[str] = None,
+        step: Optional[int] = None,
+    ):
+        # TODO >>>
+        wandb.save(path, policy="now")
+        return
+        # <<<
+        data = {"video": wandb.Video(path, caption=caption, fps=fps)}
+        if step is not None:
+            data.update({STEP_METRIC_NAME: step})
+        wandb.log(data)
+
+    def log_image(self, path: str, caption: Optional[str] = None):
+        # wandb.log({"image": wandb.Image(path, caption=caption)})
+        wandb.save(path, policy="now")
+
+    def log_txt(self, path: str, wait: bool = True):
+        wandb.save(path, policy="now")
+        # TODO: `wandb` is not async at all?
+        if wait:
+            time.sleep(3)
+
+    def close(self):
+        wandb.run.finish()
+
+
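A hypothetical usage sketch; it assumes a wandb run has already been started (e.g. via `wandb.init(...)`), since `WandbLogger.__init__` accesses `wandb.run`:

    import wandb

    wandb.init(project="ring-demo")  # project name is illustrative
    logger = WandbLogger()
    logger.log({"val": {"loss": 0.07}, "i_episode": 10})
    logger.close()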
    
         
+def _flatten_convert_filter_nested_dict(
+    metrices: NestedDict, filter_nan_inf: bool = True
+):
+    metrices = _flatten_dict(metrices)
+    metrices = jax.tree_map(_to_float_if_not_string, metrices)
+
+    if not filter_nan_inf:
+        return metrices
+
+    filtered_metrices = {}
+    for key, value in metrices.items():
+        if not isinstance(value, str) and (np.isnan(value) or np.isinf(value)):
+            warning = f"Warning: Value of metric {key} is {value}. We skip it."
+            warnings.warn(warning)
+            continue
+        filtered_metrices[key] = value
+    return filtered_metrices
+
+
+def _flatten_dict(d, parent_key="", sep="_"):
+    items = []
+    for k, v in d.items():
+        k = str(k) if isinstance(k, int) else k
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, dict):
+            items.extend(_flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def _to_float_if_not_string(value):
+    if isinstance(value, str):
+        return value
+    else:
+        return float(value)
+
+
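To illustrate what these helpers do (example values are made up): nested keys are joined with `"_"`, leaves are cast to `float` unless they are strings, and NaN/inf values are dropped with a warning:

    metrics = {"train": {"loss": 0.5, "grad_norm": float("nan")}, "note": "ok"}
    _flatten_convert_filter_nested_dict(metrics)
    # -> {"train_loss": 0.5, "note": "ok"}   (the NaN entry is skipped)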
    
         
+def on_cluster() -> bool:
+    """Return `true` if executed on cluster."""
+    env_var = os.environ.get("ON_CLUSTER", None)
+    return False if env_var is None else True
+
+
+def unique_id() -> str:
+    return ring._UNIQUE_ID
+
+
+def save_model_tf(jax_func, path: str, *input, validate: bool = True):
+    from jax.experimental import jax2tf
+
+    tf = import_lib("tensorflow", "the function `save_model_tf`")
+
+    def _create_module(jax_func, input):
+        signature = jax.tree_map(
+            lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
+        )
+
+        class RingTFModule(tf.Module):
+            def __init__(self, jax_func):
+                super().__init__()
+                self.tf_func = jax2tf.convert(jax_func, with_gradient=False)
+
+            @partial(
+                tf.function,
+                autograph=False,
+                jit_compile=True,
+                input_signature=signature,
+            )
+            def __call__(self, *args):
+                return self.tf_func(*args)
+
+        return RingTFModule(jax_func)
+
+    model = _create_module(jax_func, input)
+    tf.saved_model.save(
+        model,
+        path,
+        options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
+    )
+    if validate:
+        output_jax = jax_func(*input)
+        output_tf = tf.saved_model.load(path)(*input)
+        jax.tree_map(
+            lambda a1, a2: np.allclose(a1, a2, atol=1e-5, rtol=1e-5),
+            output_jax,
+            output_tf,
+        )
+
+
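A hedged usage sketch of `save_model_tf`, assuming TensorFlow is installed and using a toy JAX function (names, shapes, and the output path are illustrative only):

    import jax.numpy as jnp

    def net(x):
        return jnp.tanh(x) * 2.0

    x = jnp.ones((1, 3))
    save_model_tf(net, "/tmp/ring_tf_model", x, validate=True)
    # The SavedModel can later be reloaded with tf.saved_model.load("/tmp/ring_tf_model").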
    
         
+def train_val_split(
+    tps: list[str],
+    bs: int,
+    n_batches_for_val: int = 1,
+    transform_gen=None,
+    tree_transform=None,
+):
+    "Uses `random` module for shuffling."
+    if transform_gen is None:
+        transform_gen = lambda gen: gen
+
+    len_val = n_batches_for_val * bs
+
+    _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
+    include_samples = list(range(N))
+    random.shuffle(include_samples)
+
+    train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
+    X_val, y_val = transform_gen(
+        ring.RCMG.eager_gen_from_paths(
+            tps, len_val, val_data, tree_transform=tree_transform
+        )[0]
+    )(jax.random.PRNGKey(420))
+
+    generator = transform_gen(
+        ring.RCMG.eager_gen_from_paths(
+            tps,
+            bs,
+            train_data,
+            load_all_into_memory=True,
+            tree_transform=tree_transform,
+        )[0]
+    )
+
+    return generator, (X_val, y_val)
+
+
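An illustrative call, assuming `tps` points at previously dumped RCMG data files (the paths below are placeholders):

    train_gen, (X_val, y_val) = train_val_split(
        tps=["~/data/motion_part1.pickle", "~/data/motion_part2.pickle"],
        bs=32,
        n_batches_for_val=1,
    )
    # `train_gen` yields training batches of size 32; the held-out
    # validation batch (X_val, y_val) contains 1 * 32 sequences.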
    
         
+def _unknown_link_names(N: int):
+    return [f"link{i}" for i in range(N)]