imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
    
        ring/utils/normalizer.py
    ADDED
    
    | 
         @@ -0,0 +1,56 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import math
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import Callable, TypeVar
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 5 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 6 
     | 
    
         
            +
            from ring.algorithms.generator import types
         
     | 
| 
      
 7 
     | 
    
         
            +
            import tree_utils
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            # Fixed PRNG keys so the normalization statistics are reproducible between
            # calls: KEY drives the generator samples, KEY_PERMUTATION the shuffling.
            KEY, KEY_PERMUTATION = jax.random.PRNGKey(777), jax.random.PRNGKey(888)


            # A normalizer maps a value of type X to a value of the same type X.
            X = TypeVar("X")
            Normalizer = Callable[[X], X]
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            def make_normalizer_from_generator(
                generator: types.BatchedGenerator,
                approx_with_large_batchsize: int = 512,
                verbose: bool = False,
            ) -> Normalizer:
                """Return a pure function that standardizes pytrees shaped like the generator's `X`.

                The generator is queried until at least `approx_with_large_batchsize`
                samples are collected. The samples are shuffled along axis 0 and sliced to
                `approx_with_large_batchsize`. Then a per-leaf mean and standard deviation
                are taken over axes 0 and 1. The returned function computes
                ``(x - mean) / (std + eps)`` leaf by leaf.

                Args:
                    generator: Batched generator, called as ``generator(key)``. Its result
                        is unpacked as a pair whose first element is the pytree `X`.
                    approx_with_large_batchsize: Number of samples used to estimate the
                        statistics. Must be positive.
                    verbose: If True, print the estimated mean and std.

                Returns:
                    A function that maps a pytree with the same structure as `X` to its
                    standardized version.

                Raises:
                    ValueError: If `approx_with_large_batchsize` is not positive.
                """
                if approx_with_large_batchsize <= 0:
                    raise ValueError(
                        "`approx_with_large_batchsize` must be positive, got "
                        f"{approx_with_large_batchsize}"
                    )

                # probe generator for its batchsize
                X, _ = generator(KEY)
                bs = tree_utils.tree_shape(X)
                assert tree_utils.tree_ndim(X) == 3, "`generator` must be batched."

                # how often do we have to query the generator
                number_of_gen_calls = math.ceil(approx_with_large_batchsize / bs)

                Xs, key = [], KEY
                for _ in range(number_of_gen_calls):
                    key, consume = jax.random.split(key)
                    Xs.append(generator(consume)[0])
                # NOTE(review): presumably concatenates the batches along the existing
                # first axis (second argument True) -- TODO confirm against tree_utils.
                Xs = tree_utils.tree_batch(Xs, True, "jax")
                # Permute the 0-th axis. The batchsize of the generator might be larger
                # than `approx_with_large_batchsize`, and slicing without shuffling would
                # then not give a representative subsample. Every leaf uses the same key,
                # so leaves with equal leading dimension are permuted identically and the
                # samples stay aligned across leaves.
                Xs = jax.tree_util.tree_map(
                    lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs
                )
                Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)

                # obtain per-leaf statistics over the first two axes
                mean = jax.tree_util.tree_map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
                std = jax.tree_util.tree_map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)

                if verbose:
                    print("Mean: ", mean)
                    print("Std: ", std)

                # avoids division by zero for features with zero standard deviation
                eps = 1e-8

                def normalizer(X):
                    return jax.tree_util.tree_map(
                        lambda a, b, c: (a - b) / (c + eps), X, mean, std
                    )

                return normalizer
         
     | 
    
        ring/utils/path.py
    ADDED
    
    | 
         @@ -0,0 +1,44 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import os
         
     | 
| 
      
 2 
     | 
    
         
            +
            from pathlib import Path
         
     | 
| 
      
 3 
     | 
    
         
            +
            from typing import Optional
         
     | 
| 
      
 4 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            def parse_path(
                path: str,
                *join_paths: str,
                extension: Optional[str] = None,
                file_exists_ok: bool = True,
                mkdir: bool = True,
                require_is_file: bool = False,
            ) -> str:
                """Build a path string from components, optionally enforcing an extension.

                Args:
                    path: Base path. A leading ``~`` is expanded with `os.path.expanduser`.
                    *join_paths: Further components appended to `path` in order.
                    extension: If given, replaces the final suffix of the path. The leading
                        dot is optional (``"pickle"`` and ``".pickle"`` are equivalent).
                        ``""`` strips the current suffix. A warning is emitted whenever an
                        existing, different suffix is replaced.
                    file_exists_ok: If False, raise when the resulting path already exists.
                    mkdir: If True, create the parent directory (including its parents).
                    require_is_file: If True, assert that the result is an existing file.

                Returns:
                    The resulting path as a string.

                Raises:
                    FileExistsError: If `file_exists_ok` is False and the path exists.
                    AssertionError: If `require_is_file` is True and the path is not a file.
                """
                parsed = Path(os.path.expanduser(path)).joinpath(*join_paths)

                if extension is not None:
                    if extension and not extension.startswith("."):
                        extension = "." + extension

                    # A filename containing a dot (e.g. "run_0.5") or carrying another
                    # extension loses that suffix below, so make the replacement visible.
                    old_suffix = parsed.suffix
                    if old_suffix not in ("", extension):
                        warnings.warn(
                            f"The path ({parsed}) already has an extension (`{old_suffix}`), but "
                            f"it gets replaced by the extension=`{extension}`.",
                            stacklevel=2,
                        )

                    parsed = parsed.with_suffix(extension)

                if not file_exists_ok and os.path.exists(parsed):
                    # FileExistsError subclasses Exception, so existing handlers still work.
                    raise FileExistsError(f"File {parsed} already exists but shouldn't")

                if mkdir:
                    parsed.parent.mkdir(parents=True, exist_ok=True)

                if require_is_file:
                    assert parsed.is_file(), f"Not a file: {parsed}"

                return str(parsed)
         
     | 
    
        ring/utils/utils.py
    ADDED
    
    | 
         @@ -0,0 +1,161 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from importlib import import_module as _import_module
         
     | 
| 
      
 2 
     | 
    
         
            +
            import io
         
     | 
| 
      
 3 
     | 
    
         
            +
            import pickle
         
     | 
| 
      
 4 
     | 
    
         
            +
            from typing import Optional
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 7 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 8 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            from ring.base import _Base
         
     | 
| 
      
 11 
     | 
    
         
            +
            from ring.base import Geometry
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            from .path import parse_path
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            def tree_equal(a, b):
                """Recursively compare two pytrees for equality (copied from Marcel / Thomas).

                Objects of different types are never equal. ``_Base`` objects are compared
                through their ``__dict__``. Dicts must have the same keys, and lists/tuples
                the same length. Arrays must have the same shape and are then compared with
                ``jnp.allclose``. Anything else falls back to ``==``.
                """
                if type(a) is not type(b):
                    return False
                if isinstance(a, _Base):
                    return tree_equal(a.__dict__, b.__dict__)
                if isinstance(a, dict):
                    if a.keys() != b.keys():
                        return False
                    return all(tree_equal(a[k], b[k]) for k in a)
                if isinstance(a, (tuple, list)):
                    if len(a) != len(b):
                        return False
                    return all(tree_equal(x, y) for x, y in zip(a, b))
                if isinstance(a, (jax.Array, np.ndarray)):
                    # `allclose` broadcasts. Without this check, shapes (3,) and (1,) could
                    # compare equal, and incompatible shapes would raise instead of False.
                    if a.shape != b.shape:
                        return False
                    return jnp.allclose(a, b)
                return a == b
         
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
      
 35 
     | 
    
         
            +
            def _sys_compare_unsafe(sys1, sys2, verbose: bool, prefix: str) -> bool:
                """Compare two objects attribute by attribute, recursing into sub-objects.

                Attributes that are ``_Base`` instances, or non-empty lists of
                ``Geometry``, are compared recursively. All other attributes are compared
                with `tree_equal`.

                Args:
                    sys1: First object; its ``__dict__`` keys drive the comparison.
                    sys2: Second object; must have every attribute `sys1` has.
                    verbose: If True, print the attribute where a difference is detected.
                    prefix: Dotted attribute path of `sys1` relative to the root, used only
                        in the verbose message.

                Returns:
                    True if no differing attribute was found, otherwise False.
                """
                d1 = sys1.__dict__
                d2 = sys2.__dict__
                for key in d1:
                    value1, value2 = d1[key], d2[key]
                    if isinstance(value1, _Base):
                        if not _sys_compare_unsafe(value1, value2, verbose, prefix + "." + key):
                            return False
                    elif (
                        isinstance(value1, list)
                        and len(value1) > 0  # an empty list would fail on value1[0]
                        and isinstance(value1[0], Geometry)
                    ):
                        # `zip` stops at the shorter list, so compare lengths first;
                        # otherwise extra geometries in `sys2` would go unnoticed.
                        if not isinstance(value2, list) or len(value1) != len(value2):
                            if verbose:
                                print(f"Systems different in attribute `sys{prefix}.{key}`")
                                print(f"{repr(value1)} NOT EQUAL {repr(value2)}")
                            return False
                        for ele1, ele2 in zip(value1, value2):
                            if not _sys_compare_unsafe(ele1, ele2, verbose, prefix + "." + key):
                                return False
                    elif not tree_equal(value1, value2):
                        if verbose:
                            print(f"Systems different in attribute `sys{prefix}.{key}`")
                            print(f"{repr(value1)} NOT EQUAL {repr(value2)}")
                        return False
                return True
         
     | 
| 
      
 53 
     | 
    
         
            +
             
     | 
| 
      
 54 
     | 
    
         
            +
             
     | 
| 
      
 55 
     | 
    
         
            +
            def sys_compare(sys1, sys2, verbose: bool = True):
                """Return whether two systems are equal.

                The attribute-wise comparison (which reports the differing attribute when
                `verbose` is True) and the generic `tree_equal` are both evaluated, and the
                two verdicts are asserted to agree.
                """
                attributewise_equal = _sys_compare_unsafe(sys1, sys2, verbose, "")
                treewise_equal = tree_equal(sys1, sys2)
                # Sanity check: both comparison strategies must reach the same verdict.
                assert attributewise_equal == treewise_equal
                return attributewise_equal
         
     | 
| 
      
 60 
     | 
    
         
            +
             
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
            def to_list(obj: object) -> list:
                """Wrap `obj` in a single-element list unless it already is a list."""
                return obj if isinstance(obj, list) else [obj]
         
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
            def dict_union(
                d1: dict[str, jax.Array] | dict[str, dict[str, jax.Array]],
                d2: dict[str, jax.Array] | dict[str, dict[str, jax.Array]],
                overwrite: bool = False,
            ) -> dict:
                """Build the union of two dictionaries that are nested at most one level deep.

                A top-level key present in only one input is taken as-is. A key present in
                both must map to a dict in both; those inner dicts are merged with
                `dict.update`, so values from `d2` win. Unless `overwrite` is True, the
                inner dicts must not share any key. Both inputs are copied with
                `pytree_deepcopy` first, so neither is mutated.

                Raises:
                    Exception: If a shared top-level key does not map to dicts in both inputs.
                    AssertionError: If `overwrite` is False and an inner key is in both.
                """
                # Work on private copies so the caller's dictionaries are never mutated.
                result = pytree_deepcopy(d1)
                other = pytree_deepcopy(d2)

                for outer_key in other:
                    if outer_key not in result:
                        result[outer_key] = other[outer_key]
                        continue

                    incoming = other[outer_key]
                    existing = result[outer_key]
                    if not (isinstance(incoming, dict) and isinstance(existing, dict)):
                        raise Exception(f"d1.keys()={result.keys()}; d2.keys()={other.keys()}")

                    if not overwrite:
                        for inner_key in incoming:
                            assert (
                                inner_key not in existing
                            ), f"d1.keys()={existing.keys()}; d2.keys()={incoming.keys()}"

                    existing.update(incoming)
                return result
         
     | 
| 
      
 94 
     | 
    
         
            +
             
     | 
| 
      
 95 
     | 
    
         
            +
             
     | 
| 
      
 96 
     | 
    
         
            +
            def dict_to_nested(
                d: dict[str, jax.Array], add_key: str
            ) -> dict[str, dict[str, jax.Array]]:
                """Nest each value of `d` one level deeper, under the single key `add_key`.

                For example, ``{"a": x}`` with ``add_key="acc"`` becomes ``{"a": {"acc": x}}``.
                """
                nested = {}
                for key in d:
                    nested[key] = {add_key: d[key]}
                return nested
         
     | 
| 
      
 101 
     | 
    
         
            +
             
     | 
| 
      
 102 
     | 
    
         
            +
             
     | 
| 
      
 103 
     | 
    
         
            +
            def save_figure_to_rgba(fig) -> np.ndarray:
                """Render a figure into a ``uint8`` pixel array via ``savefig(format="raw")``.

                Args:
                    fig: Figure object providing ``savefig(buffer, format=..., dpi=...)``,
                        a ``dpi`` attribute and ``canvas.get_width_height()``.

                Returns:
                    A ``uint8`` array of shape ``(height, width, channels)``. The channel
                    count is inferred from the buffer size (presumably 4, i.e. RGBA, for
                    matplotlib's "raw" format).
                """
                with io.BytesIO() as buff:
                    # Render at the figure's own dpi. Otherwise `savefig` uses
                    # rcParams["savefig.dpi"], which may differ from the dpi behind
                    # `canvas.get_width_height()` and would break the reshape below.
                    fig.savefig(buff, format="raw", dpi=fig.dpi)
                    data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
                w, h = fig.canvas.get_width_height()
                return data.reshape((int(h), int(w), -1))
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
             
     | 
| 
      
 113 
     | 
    
         
            +
            def pytree_deepcopy(tree):
                """Recursively copy a pytree built from lists, tuples and dicts.

                Containers are rebuilt and NumPy arrays are copied. Immutable leaves are
                returned as-is because they cannot be modified in place. These are
                ``None``, ``int``/``float``/``complex``/``bool``, ``str``, ``bytes``,
                NumPy scalars and JAX arrays.

                Raises:
                    NotImplementedError: For any other leaf type.
                """
                if tree is None or isinstance(
                    tree, (int, float, complex, str, bytes, np.generic, jax.Array)
                ):
                    return tree
                if isinstance(tree, np.ndarray):
                    return tree.copy()
                if isinstance(tree, list):
                    return [pytree_deepcopy(ele) for ele in tree]
                if isinstance(tree, tuple):
                    return tuple(pytree_deepcopy(ele) for ele in tree)
                if isinstance(tree, dict):
                    return {key: pytree_deepcopy(value) for key, value in tree.items()}
                raise NotImplementedError(f"Not implemented for type={type(tree)}")
         
     | 
| 
      
 127 
     | 
    
         
            +
             
     | 
| 
      
 128 
     | 
    
         
            +
             
     | 
| 
      
 129 
     | 
    
         
            +
            def import_lib(
                lib: str,
                required_for: Optional[str] = None,
                lib_pypi: Optional[str] = None,
            ):
                """Import and return the module `lib`, with an actionable message if it is missing.

                Args:
                    lib: Importable module name.
                    required_for: Optional description of the feature that needs `lib`;
                        it is included in the error message.
                    lib_pypi: Name of the package to ``pip install``. Defaults to `lib`.

                Raises:
                    ImportError: If `lib` cannot be imported. The original error is
                        attached as ``__cause__``.
                """
                try:
                    return _import_module(lib)
                except ImportError as err:
                    _required = ""
                    if required_for is not None:
                        _required = f" but it is required for {required_for}"
                    if lib_pypi is None:
                        lib_pypi = lib
                    error_msg = (
                        f"Could not import `{lib}`{_required}. "
                        f"Please install with `pip install {lib_pypi}`"
                    )
                    # Chain explicitly so the traceback shows the real import failure as
                    # the cause rather than as an error raised "during handling".
                    raise ImportError(error_msg) from err
         
     | 
| 
      
 147 
     | 
    
         
            +
             
     | 
| 
      
 148 
     | 
    
         
            +
             
     | 
| 
      
 149 
     | 
    
         
            +
            def pickle_save(obj, path, overwrite: bool = False):
                """Pickle `obj` (protocol 5) to `path`, forcing a ``.pickle`` extension.

                The path is resolved with `parse_path`, which by default creates the parent
                directory. Unless `overwrite` is True, `parse_path` raises if the target
                already exists.
                """
                target = parse_path(path, extension="pickle", file_exists_ok=overwrite)
                with open(target, "wb") as handle:
                    pickle.dump(obj, handle, protocol=5)
         
     | 
| 
      
 153 
     | 
    
         
            +
             
     | 
| 
      
 154 
     | 
    
         
            +
             
     | 
| 
      
 155 
     | 
    
         
            +
            def pickle_load(
                path,
            ):
                """Load and return the object pickled at `path` (``.pickle`` extension enforced).

                Security: unpickling can execute arbitrary code, so only load files from
                trusted sources.

                Raises:
                    AssertionError: If the resolved path is not an existing file.
                """
                # mkdir=False: `parse_path` would otherwise create the parent directory
                # before checking `require_is_file`. Reading must not create directories,
                # even when the file turns out not to exist.
                path = parse_path(path, extension="pickle", mkdir=False, require_is_file=True)
                with open(path, "rb") as file:
                    obj = pickle.load(file)
                return obj
         
     |