imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,56 @@
1
+ import math
2
+ from typing import Callable, TypeVar
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from ring.algorithms.generator import types
7
+ import tree_utils
8
+
9
# Fixed PRNG keys: `KEY` seeds the generator queries, `KEY_PERMUTATION` the
# shuffle of the batch axis in `make_normalizer_from_generator`.
KEY, KEY_PERMUTATION = jax.random.PRNGKey(777), jax.random.PRNGKey(888)


# A normalizer maps a value of type `X` to a value of the same type.
X = TypeVar("X")
Normalizer = Callable[[X], X]
15
+
16
+
17
def make_normalizer_from_generator(
    generator: types.BatchedGenerator,
    approx_with_large_batchsize: int = 512,
    verbose: bool = False,
) -> Normalizer:
    """Returns a pure function that normalizes `X`.

    The statistics are estimated from samples drawn from `generator`. It is
    called (with keys split from the module-level `KEY`) often enough to
    collect at least `approx_with_large_batchsize` sequences. The batch axis is
    shuffled with `KEY_PERMUTATION`, and the first `approx_with_large_batchsize`
    entries are kept via `tree_slice`. Mean and standard deviation are then
    computed for every leaf over axes 0 and 1.

    Args:
        generator: Batched generator, called as `generator(key)`; its return
            value is unpacked as `(X, _)` and only `X` is used.
        approx_with_large_batchsize: Number of sequences used for the statistics.
        verbose: If True, print the estimated mean and std.

    Returns:
        `normalizer(X)`, which maps every leaf to `(leaf - mean) / (std + 1e-8)`.
    """
    # `jax.tree_map` is a deprecated alias that was removed in recent JAX
    # releases; `jax.tree_util.tree_map` is the supported spelling.
    tree_map = jax.tree_util.tree_map

    # probe generator for its batchsize (this sample is not used for statistics)
    X, _ = generator(KEY)
    # NOTE(review): assumes `tree_shape` returns the leading (batch) size
    bs = tree_utils.tree_shape(X)
    assert tree_utils.tree_ndim(X) == 3, "`generator` must be batched."

    # how often do we have to query the generator
    number_of_gen_calls = math.ceil(approx_with_large_batchsize / bs)

    Xs, key = [], KEY
    for _ in range(number_of_gen_calls):
        key, consume = jax.random.split(key)
        Xs.append(generator(consume)[0])
    # presumably concatenates the collected batches along their first axis
    Xs = tree_utils.tree_batch(Xs, True, "jax")
    # permute 0-th axis, since batchsize of generator might be larger than
    # `approx_with_large_batchsize`, then we would not get a representative
    # subsample otherwise. The same key is used for every leaf, so leaves with
    # equal leading size are shuffled identically and stay aligned.
    Xs = tree_map(lambda arr: jax.random.permutation(KEY_PERMUTATION, arr), Xs)
    Xs = tree_utils.tree_slice(Xs, start=0, slice_size=approx_with_large_batchsize)

    # obtain statistics
    mean = tree_map(lambda arr: jnp.mean(arr, axis=(0, 1)), Xs)
    std = tree_map(lambda arr: jnp.std(arr, axis=(0, 1)), Xs)

    if verbose:
        print("Mean: ", mean)
        print("Std: ", std)

    # keeps the division finite for features whose std is zero
    eps = 1e-8

    def normalizer(X):
        return tree_map(lambda a, b, c: (a - b) / (c + eps), X, mean, std)

    return normalizer
ring/utils/path.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+ import warnings
5
+
6
+
7
def parse_path(
    path: str,
    *join_paths: str,
    extension: Optional[str] = None,
    file_exists_ok: bool = True,
    mkdir: bool = True,
    require_is_file: bool = False,
) -> str:
    """Expand, join and validate a filesystem path.

    Args:
        path: Base path; a leading `~` is expanded to the user's home directory.
        *join_paths: Further path components appended to `path` in order.
        extension: If given, replaces the suffix of the final component. A
            leading "." is added when missing, and "" strips the current suffix.
            A warning is emitted when a different existing suffix is replaced.
        file_exists_ok: If False, raise when the resulting path already exists.
        mkdir: If True, create the parent directory (including its parents).
        require_is_file: If True, assert that the resulting path is a file.

    Returns:
        The resulting path as a string.

    Raises:
        FileExistsError: If `file_exists_ok` is False and the path exists.
        AssertionError: If `require_is_file` is True and the path is no file.
    """
    result = Path(os.path.expanduser(path)).joinpath(*join_paths)

    if extension is not None:
        if extension and not extension.startswith("."):
            extension = "." + extension

        # check for paths that contain a dot "." in their filename (through a
        # number, e.g. `lr_0.5`) or that already have an extension
        old_suffix = result.suffix
        if old_suffix not in ("", extension):
            warnings.warn(
                f"The path ({result}) already has an extension (`{old_suffix}`), but "
                f"it gets replaced by the extension=`{extension}`."
            )

        result = result.with_suffix(extension)

    # `FileExistsError` subclasses `Exception`, so existing broad handlers
    # keep working while callers can now catch the specific case.
    if not file_exists_ok and os.path.exists(result):
        raise FileExistsError(f"File {result} already exists but shouldn't")

    if mkdir:
        result.parent.mkdir(parents=True, exist_ok=True)

    if require_is_file:
        assert result.is_file(), f"Not a file: {result}"

    return str(result)
ring/utils/utils.py ADDED
@@ -0,0 +1,161 @@
1
+ from importlib import import_module as _import_module
2
+ import io
3
+ import pickle
4
+ from typing import Optional
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+
10
+ from ring.base import _Base
11
+ from ring.base import Geometry
12
+
13
+ from .path import parse_path
14
+
15
+
16
def tree_equal(a, b):
    """Recursively compare two nested objects for equality.

    Copied from Marcel / Thomas.

    Objects of different type are never equal. `_Base` instances are compared
    via their `__dict__`, dicts by keys and values, and tuples/lists
    element-wise. NumPy/JAX arrays must have the same shape and be `allclose`,
    and everything else is compared with `==`.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, _Base):
        return tree_equal(a.__dict__, b.__dict__)
    if isinstance(a, dict):
        if a.keys() != b.keys():
            return False
        return all(tree_equal(a[k], b[k]) for k in a)
    if isinstance(a, (tuple, list)):
        if len(a) != len(b):
            return False
        return all(tree_equal(x, y) for x, y in zip(a, b))
    if isinstance(a, (jax.Array, np.ndarray)):
        # `allclose` broadcasts its inputs, so e.g. shapes (1,) and (3,) would
        # compare equal (and non-broadcastable shapes would raise); arrays of
        # different shape are therefore rejected explicitly.
        if a.shape != b.shape:
            return False
        return jnp.allclose(a, b)
    return a == b
33
+
34
+
35
def _sys_compare_unsafe(sys1, sys2, verbose: bool, prefix: str) -> bool:
    """Compare two systems attribute by attribute, stopping at the first mismatch.

    It recurses into `_Base` attributes and into lists of `Geometry`. Every
    other attribute is compared with `tree_equal`. With `verbose`, the first
    differing attribute is printed as `sys<prefix>.<key>`.

    NOTE(review): "unsafe" presumably because `sys2.__dict__` is indexed with
    every key of `sys1`, which raises KeyError if the attribute sets differ.
    """
    d1 = sys1.__dict__
    d2 = sys2.__dict__
    for key in d1:
        value1, value2 = d1[key], d2[key]
        nested_prefix = prefix + "." + key
        if isinstance(value1, _Base):
            if not _sys_compare_unsafe(value1, value2, verbose, nested_prefix):
                return False
        elif (
            isinstance(value1, list)
            and len(value1) > 0  # empty lists fall through to `tree_equal`
            and isinstance(value1[0], Geometry)
        ):
            # `zip` would silently ignore surplus geometries on either side
            if not isinstance(value2, list) or len(value1) != len(value2):
                if verbose:
                    print(f"Systems different in attribute `sys{prefix}.{key}`")
                    print(f"{repr(value1)} NOT EQUAL {repr(value2)}")
                return False
            for ele1, ele2 in zip(value1, value2):
                if not _sys_compare_unsafe(ele1, ele2, verbose, nested_prefix):
                    return False
        else:
            if not tree_equal(value1, value2):
                if verbose:
                    print(f"Systems different in attribute `sys{prefix}.{key}`")
                    print(f"{repr(value1)} NOT EQUAL {repr(value2)}")
                return False
    return True
53
+
54
+
55
def sys_compare(sys1, sys2, verbose: bool = True):
    """Return whether two systems are equal.

    Two independent comparisons are run: first the attribute-wise
    `_sys_compare_unsafe` (which may print the first difference), then the
    generic `tree_equal`. Their results are asserted to agree as a consistency
    check.
    """
    by_attribute = _sys_compare_unsafe(sys1, sys2, verbose, "")
    by_tree = tree_equal(sys1, sys2)
    assert by_attribute == by_tree
    return by_attribute
60
+
61
+
62
def to_list(obj: object) -> list:
    """Wrap `obj` in a list unless it already is one (``obj -> [obj]``).

    An existing list is returned as the same object, not a copy.
    """
    return obj if isinstance(obj, list) else [obj]
67
+
68
+
69
def dict_union(
    d1: dict[str, jax.Array] | dict[str, dict[str, jax.Array]],
    d2: dict[str, jax.Array] | dict[str, dict[str, jax.Array]],
    overwrite: bool = False,
) -> dict:
    """Build the union of two dictionaries that are nested at most two levels.

    Both inputs are deep-copied first, so neither `d1` nor `d2` is mutated.
    Top-level keys found only in `d2` are added as they are. For a top-level
    key present in both, both values must be dicts; they are merged, and
    `d2`'s entries win. Unless `overwrite` is True, a nested key present on
    both sides is an error.

    Raises:
        Exception: If a shared top-level key does not map to dicts on both sides.
        AssertionError: If `overwrite` is False and a nested key collides.
    """
    # safety copying; otherwise the merge below would mutate the callers' dicts
    merged = pytree_deepcopy(d1)
    other = pytree_deepcopy(d2)

    for key, value in other.items():
        if key not in merged:
            merged[key] = value
            continue

        existing = merged[key]
        if not (isinstance(value, dict) and isinstance(existing, dict)):
            raise Exception(f"d1.keys()={merged.keys()}; d2.keys()={other.keys()}")

        if not overwrite:
            for nested_key in value:
                assert (
                    nested_key not in existing
                ), f"d1.keys()={existing.keys()}; d2.keys()={value.keys()}"

        existing.update(value)
    return merged
94
+
95
+
96
def dict_to_nested(
    d: dict[str, jax.Array], add_key: str
) -> dict[str, dict[str, jax.Array]]:
    """Nest a dictionary one level deeper.

    Every value ``d[k]`` is wrapped in a single-entry dictionary
    ``{add_key: d[k]}``; e.g. ``{"a": 1}`` with ``add_key="x"`` becomes
    ``{"a": {"x": 1}}``.
    """
    nested = {}
    for key, value in d.items():
        nested[key] = {add_key: value}
    return nested
101
+
102
+
103
def save_figure_to_rgba(fig) -> np.ndarray:
    """Render a matplotlib figure into an RGBA pixel array.

    Args:
        fig: A matplotlib figure (used via `savefig` and
            `canvas.get_width_height()`).

    Returns:
        A read-only `uint8` array of shape `(height, width, 4)`; the "raw"
        savefig format is an RGBA bitmap.

    Raises:
        ValueError: If the rendered buffer does not match the canvas size
            (e.g. when the output was cropped via `savefig.bbox = "tight"`).
    """
    with io.BytesIO() as buff:
        # Render at the figure's own dpi. By default `savefig` uses
        # rcParams["savefig.dpi"], which may differ from the dpi behind
        # `get_width_height()` and would silently scramble the reshape below.
        fig.savefig(buff, format="raw", dpi="figure")
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    h, w = int(h), int(w)
    expected = h * w * 4
    if data.size != expected:
        raise ValueError(
            f"Rendered {data.size} bytes but expected {expected} for a "
            f"{w}x{h} RGBA image; the figure may have been resized or cropped."
        )
    return data.reshape((h, w, 4))
111
+
112
+
113
def pytree_deepcopy(tree):
    """Recursively copy a pytree made of lists, tuples and dicts.

    Mutable NumPy arrays are copied. Immutable leaves are returned as they are,
    since sharing them is safe: `None`, Python numbers, `str`, `bytes`, NumPy
    number/bool scalars and JAX arrays. Namedtuples keep their type; dicts are
    rebuilt as plain dicts.

    Raises:
        NotImplementedError: For any other leaf type.
    """
    if tree is None or isinstance(
        tree, (int, float, complex, str, bytes, np.number, np.bool_, jax.Array)
    ):
        return tree
    if isinstance(tree, np.ndarray):
        return tree.copy()
    if isinstance(tree, list):
        return [pytree_deepcopy(ele) for ele in tree]
    if isinstance(tree, tuple):
        copied = [pytree_deepcopy(ele) for ele in tree]
        # namedtuples take their fields positionally; a plain `tuple(...)`
        # would silently drop the namedtuple type
        if hasattr(tree, "_fields"):
            return type(tree)(*copied)
        return tuple(copied)
    if isinstance(tree, dict):
        return {key: pytree_deepcopy(value) for key, value in tree.items()}
    raise NotImplementedError(f"Not implemented for type={type(tree)}")
127
+
128
+
129
def import_lib(
    lib: str,
    required_for: Optional[str] = None,
    lib_pypi: Optional[str] = None,
):
    """Import and return module `lib`, with an actionable error if that fails.

    Args:
        lib: Importable module name.
        required_for: Optional description of what needs `lib`; included in
            the error message.
        lib_pypi: Name of the PyPI distribution to suggest; defaults to `lib`.

    Returns:
        The imported module.

    Raises:
        ImportError: If importing `lib` fails. The original error is chained as
            `__cause__`, so the underlying reason stays visible in tracebacks.
    """
    try:
        return _import_module(lib)
    except ImportError as err:
        _required = ""
        if required_for is not None:
            _required = f" but it is required for {required_for}"
        if lib_pypi is None:
            lib_pypi = lib
        error_msg = (
            f"Could not import `{lib}`{_required}. "
            f"Please install with `pip install {lib_pypi}`"
        )
        raise ImportError(error_msg) from err
147
+
148
+
149
def pickle_save(obj, path, overwrite: bool = False):
    """Pickle `obj` to `path` with protocol 5, forcing a ".pickle" extension.

    The target is resolved with `parse_path`. It refuses to replace an existing
    file unless `overwrite` is True, and it creates missing parent directories.
    """
    target = parse_path(path, extension="pickle", file_exists_ok=overwrite)
    with open(target, "wb") as handle:
        pickle.dump(obj, handle, protocol=5)
153
+
154
+
155
def pickle_load(
    path,
):
    """Load and return the object pickled at `path`.

    The ".pickle" extension is enforced and the file must exist.

    Security: unpickling can execute arbitrary code, so only load files from
    trusted sources.

    Raises:
        AssertionError: If the resolved path is not an existing file.
    """
    # `mkdir=False`: reading must not create directories as a side effect
    # (parse_path defaults to mkdir=True, which would create the parent folders
    # of a non-existent file before the is-file check fails).
    path = parse_path(path, extension="pickle", mkdir=False, require_is_file=True)
    with open(path, "rb") as file:
        return pickle.load(file)