unienv 0.0.1b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b1/LICENSE +23 -0
- unienv-0.0.1b1/PKG-INFO +20 -0
- unienv-0.0.1b1/README.md +46 -0
- unienv-0.0.1b1/pyproject.toml +31 -0
- unienv-0.0.1b1/setup.cfg +4 -0
- unienv-0.0.1b1/unienv.egg-info/PKG-INFO +20 -0
- unienv-0.0.1b1/unienv.egg-info/SOURCES.txt +88 -0
- unienv-0.0.1b1/unienv.egg-info/dependency_links.txt +1 -0
- unienv-0.0.1b1/unienv.egg-info/requires.txt +19 -0
- unienv-0.0.1b1/unienv.egg-info/top_level.txt +6 -0
- unienv-0.0.1b1/unienv_data/__init__.py +4 -0
- unienv-0.0.1b1/unienv_data/base/__init__.py +3 -0
- unienv-0.0.1b1/unienv_data/base/common.py +228 -0
- unienv-0.0.1b1/unienv_data/base/storage.py +132 -0
- unienv-0.0.1b1/unienv_data/base/transformations.py +144 -0
- unienv-0.0.1b1/unienv_data/batches/__init__.py +4 -0
- unienv-0.0.1b1/unienv_data/batches/backend_compat.py +181 -0
- unienv-0.0.1b1/unienv_data/batches/combined_batch.py +361 -0
- unienv-0.0.1b1/unienv_data/batches/framestack_batch.py +51 -0
- unienv-0.0.1b1/unienv_data/batches/slicestack_batch.py +428 -0
- unienv-0.0.1b1/unienv_data/integrations/pytorch.py +63 -0
- unienv-0.0.1b1/unienv_data/replay_buffer/__init__.py +2 -0
- unienv-0.0.1b1/unienv_data/replay_buffer/replay_buffer.py +263 -0
- unienv-0.0.1b1/unienv_data/replay_buffer/trajectory_replay_buffer.py +479 -0
- unienv-0.0.1b1/unienv_data/samplers/__init__.py +3 -0
- unienv-0.0.1b1/unienv_data/samplers/multiprocessing_sampler.py +388 -0
- unienv-0.0.1b1/unienv_data/samplers/slice_sampler.py +266 -0
- unienv-0.0.1b1/unienv_data/samplers/step_sampler.py +77 -0
- unienv-0.0.1b1/unienv_data/storages/common.py +156 -0
- unienv-0.0.1b1/unienv_data/storages/hdf5.py +376 -0
- unienv-0.0.1b1/unienv_data/storages/pytorch.py +161 -0
- unienv-0.0.1b1/unienv_interface/__init__.py +4 -0
- unienv-0.0.1b1/unienv_interface/backends/__init__.py +3 -0
- unienv-0.0.1b1/unienv_interface/backends/base.py +1 -0
- unienv-0.0.1b1/unienv_interface/backends/jax.py +18 -0
- unienv-0.0.1b1/unienv_interface/backends/numpy.py +19 -0
- unienv-0.0.1b1/unienv_interface/backends/pytorch.py +19 -0
- unienv-0.0.1b1/unienv_interface/backends/serialization.py +70 -0
- unienv-0.0.1b1/unienv_interface/env_base/__init__.py +4 -0
- unienv-0.0.1b1/unienv_interface/env_base/env.py +143 -0
- unienv-0.0.1b1/unienv_interface/env_base/funcenv.py +308 -0
- unienv-0.0.1b1/unienv_interface/env_base/funcenv_wrapper.py +231 -0
- unienv-0.0.1b1/unienv_interface/env_base/wrapper.py +286 -0
- unienv-0.0.1b1/unienv_interface/func_wrapper/__init__.py +1 -0
- unienv-0.0.1b1/unienv_interface/func_wrapper/transformation.py +127 -0
- unienv-0.0.1b1/unienv_interface/space/__init__.py +3 -0
- unienv-0.0.1b1/unienv_interface/space/space.py +119 -0
- unienv-0.0.1b1/unienv_interface/space/space_utils/__init__.py +3 -0
- unienv-0.0.1b1/unienv_interface/space/space_utils/batch_utils.py +867 -0
- unienv-0.0.1b1/unienv_interface/space/space_utils/flatten_utils.py +306 -0
- unienv-0.0.1b1/unienv_interface/space/space_utils/gym_utils.py +450 -0
- unienv-0.0.1b1/unienv_interface/space/space_utils/serialization_utils.py +221 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/__init__.py +22 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/binary.py +100 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/box.py +324 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/dict.py +180 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/dynamic_box.py +452 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/graph.py +309 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/text.py +130 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/tuple.py +110 -0
- unienv-0.0.1b1/unienv_interface/space/spaces/union.py +115 -0
- unienv-0.0.1b1/unienv_interface/transformations/__init__.py +6 -0
- unienv-0.0.1b1/unienv_interface/transformations/batch_and_unbatch.py +36 -0
- unienv-0.0.1b1/unienv_interface/transformations/chained_transform.py +63 -0
- unienv-0.0.1b1/unienv_interface/transformations/dict_transform.py +127 -0
- unienv-0.0.1b1/unienv_interface/transformations/filter_dict.py +171 -0
- unienv-0.0.1b1/unienv_interface/transformations/rescale.py +98 -0
- unienv-0.0.1b1/unienv_interface/transformations/transformation.py +46 -0
- unienv-0.0.1b1/unienv_interface/utils/seed_util.py +21 -0
- unienv-0.0.1b1/unienv_interface/utils/symbol_util.py +13 -0
- unienv-0.0.1b1/unienv_interface/world/__init__.py +4 -0
- unienv-0.0.1b1/unienv_interface/world/funcnode.py +210 -0
- unienv-0.0.1b1/unienv_interface/world/funcworld.py +62 -0
- unienv-0.0.1b1/unienv_interface/world/node.py +148 -0
- unienv-0.0.1b1/unienv_interface/world/world.py +167 -0
- unienv-0.0.1b1/unienv_interface/wrapper/__init__.py +9 -0
- unienv-0.0.1b1/unienv_interface/wrapper/action_rescale.py +31 -0
- unienv-0.0.1b1/unienv_interface/wrapper/backend_compat.py +187 -0
- unienv-0.0.1b1/unienv_interface/wrapper/batch_and_unbatch.py +84 -0
- unienv-0.0.1b1/unienv_interface/wrapper/control_frequency_limit.py +63 -0
- unienv-0.0.1b1/unienv_interface/wrapper/flatten.py +145 -0
- unienv-0.0.1b1/unienv_interface/wrapper/frame_stack.py +242 -0
- unienv-0.0.1b1/unienv_interface/wrapper/gym_compat.py +299 -0
- unienv-0.0.1b1/unienv_interface/wrapper/time_limit.py +77 -0
- unienv-0.0.1b1/unienv_interface/wrapper/transformation.py +187 -0
- unienv-0.0.1b1/unienv_interface/wrapper/video_record.py +235 -0
- unienv-0.0.1b1/unienv_maniskill/__init__.py +1 -0
- unienv-0.0.1b1/unienv_maniskill/wrapper/maniskill_compat.py +235 -0
- unienv-0.0.1b1/unienv_mjxplayground/__init__.py +1 -0
- unienv-0.0.1b1/unienv_mjxplayground/wrapper/playground_compat.py +256 -0
unienv-0.0.1b1/LICENSE
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2016 OpenAI
|
|
4
|
+
Copyright (c) 2022 Farama Foundation
|
|
5
|
+
Copyright (c) 2024 Yunhao Cao
|
|
6
|
+
|
|
7
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
in the Software without restriction, including without limitation the rights
|
|
10
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
furnished to do so, subject to the following conditions:
|
|
13
|
+
|
|
14
|
+
The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
copies or substantial portions of the Software.
|
|
16
|
+
|
|
17
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
SOFTWARE.
|
unienv-0.0.1b1/PKG-INFO
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: unienv
|
|
3
|
+
Version: 0.0.1b1
|
|
4
|
+
Requires-Python: >=3.10
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Dist: numpy
|
|
7
|
+
Requires-Dist: xbarray>=0.0.1a8
|
|
8
|
+
Requires-Dist: pillow
|
|
9
|
+
Requires-Dist: h5py
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: pytest; extra == "dev"
|
|
12
|
+
Provides-Extra: gymnasium
|
|
13
|
+
Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
|
|
14
|
+
Provides-Extra: video
|
|
15
|
+
Requires-Dist: moviepy>=2.1; extra == "video"
|
|
16
|
+
Provides-Extra: mjx
|
|
17
|
+
Requires-Dist: playground; extra == "mjx"
|
|
18
|
+
Provides-Extra: maniskill
|
|
19
|
+
Requires-Dist: mani_skill>=3.0.0b12; extra == "maniskill"
|
|
20
|
+
Dynamic: license-file
|
unienv-0.0.1b1/README.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# UniEnvPy
|
|
2
|
+
|
|
3
|
+
TL;DR: A Gymnasium library replacement with support for multiple tensor backends.
|
|
4
|
+
|
|
5
|
+
Provides a universal interface for single or parallel, state-based or function-based environments. It also contains a set of utilities (such as replay buffers and wrappers) that facilitate training reinforcement learning agents.
|
|
6
|
+
|
|
7
|
+
## Cross-backend Support
|
|
8
|
+
|
|
9
|
+
UniEnvPy supports multiple tensor backends with zero-copy translation layers through the DLPack protocol, and allows you to use the same abstract compute backend interface to write custom data transformation layers, environment wrappers and other utilities. This is powered by the [xbarray](https://github.com/realquantumcookie/xbarray) package, which builds on top of the Array API Standard and supports the following backends:
|
|
10
|
+
|
|
11
|
+
- numpy
|
|
12
|
+
- pytorch
|
|
13
|
+
- jax
|
|
14
|
+
|
|
15
|
+
We also support diverse simulation environments and real robots, built on top of the abstract environment / world interface. This allows you to reuse code across different sim and real robots.
|
|
16
|
+
|
|
17
|
+
Current supported simulation environments:
|
|
18
|
+
- Any Environment defined in Gymnasium interface
|
|
19
|
+
- <s>Mujoco</s> (new code will be added in the future; I'm currently refactoring World-based environments)
|
|
20
|
+
- MJX based on [Mujoco-Playground](https://github.com/google-deepmind/mujoco_playground)
|
|
21
|
+
- [ManiSkill 3](https://github.com/haosulab/ManiSkill/)
|
|
22
|
+
|
|
23
|
+
Current supported real robots:
|
|
24
|
+
- Franka Research 3 + RobotiQ Gripper in Droid Setup
|
|
25
|
+
- OyMotion OHand
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
Clone this repository:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
git clone https://github.com/realquantumcookie/UniEnvPy
|
|
33
|
+
cd UniEnvPy
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
Install the package with pip
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
pip install -e .
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
You can install optional dependencies (`gymnasium`, `mjx`, `maniskill`, `video`) by running:
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
pip install -e ".[gymnasium,mjx,maniskill,video]"
|
|
46
|
+
```
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "unienv"
|
|
3
|
+
version = "0.0.1b1"
|
|
4
|
+
requires-python = ">= 3.10"
|
|
5
|
+
dependencies = [
|
|
6
|
+
"numpy",
|
|
7
|
+
"xbarray>=0.0.1a8",
|
|
8
|
+
"pillow",
|
|
9
|
+
"h5py",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
[project.optional-dependencies]
|
|
13
|
+
dev = [
|
|
14
|
+
"pytest",
|
|
15
|
+
]
|
|
16
|
+
gymnasium = [
|
|
17
|
+
"gymnasium>=0.29.0",
|
|
18
|
+
]
|
|
19
|
+
video = [
|
|
20
|
+
"moviepy>=2.1"
|
|
21
|
+
]
|
|
22
|
+
mjx = [
|
|
23
|
+
"playground",
|
|
24
|
+
]
|
|
25
|
+
maniskill = [
|
|
26
|
+
"mani_skill>=3.0.0b12"
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[tool.setuptools.packages.find]
|
|
30
|
+
include = ["*"]
|
|
31
|
+
exclude = ["training*", "tests*"]
|
unienv-0.0.1b1/setup.cfg
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: unienv
|
|
3
|
+
Version: 0.0.1b1
|
|
4
|
+
Requires-Python: >=3.10
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Dist: numpy
|
|
7
|
+
Requires-Dist: xbarray>=0.0.1a8
|
|
8
|
+
Requires-Dist: pillow
|
|
9
|
+
Requires-Dist: h5py
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: pytest; extra == "dev"
|
|
12
|
+
Provides-Extra: gymnasium
|
|
13
|
+
Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
|
|
14
|
+
Provides-Extra: video
|
|
15
|
+
Requires-Dist: moviepy>=2.1; extra == "video"
|
|
16
|
+
Provides-Extra: mjx
|
|
17
|
+
Requires-Dist: playground; extra == "mjx"
|
|
18
|
+
Provides-Extra: maniskill
|
|
19
|
+
Requires-Dist: mani_skill>=3.0.0b12; extra == "maniskill"
|
|
20
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
unienv.egg-info/PKG-INFO
|
|
5
|
+
unienv.egg-info/SOURCES.txt
|
|
6
|
+
unienv.egg-info/dependency_links.txt
|
|
7
|
+
unienv.egg-info/requires.txt
|
|
8
|
+
unienv.egg-info/top_level.txt
|
|
9
|
+
unienv_data/__init__.py
|
|
10
|
+
unienv_data/base/__init__.py
|
|
11
|
+
unienv_data/base/common.py
|
|
12
|
+
unienv_data/base/storage.py
|
|
13
|
+
unienv_data/base/transformations.py
|
|
14
|
+
unienv_data/batches/__init__.py
|
|
15
|
+
unienv_data/batches/backend_compat.py
|
|
16
|
+
unienv_data/batches/combined_batch.py
|
|
17
|
+
unienv_data/batches/framestack_batch.py
|
|
18
|
+
unienv_data/batches/slicestack_batch.py
|
|
19
|
+
unienv_data/integrations/pytorch.py
|
|
20
|
+
unienv_data/replay_buffer/__init__.py
|
|
21
|
+
unienv_data/replay_buffer/replay_buffer.py
|
|
22
|
+
unienv_data/replay_buffer/trajectory_replay_buffer.py
|
|
23
|
+
unienv_data/samplers/__init__.py
|
|
24
|
+
unienv_data/samplers/multiprocessing_sampler.py
|
|
25
|
+
unienv_data/samplers/slice_sampler.py
|
|
26
|
+
unienv_data/samplers/step_sampler.py
|
|
27
|
+
unienv_data/storages/common.py
|
|
28
|
+
unienv_data/storages/hdf5.py
|
|
29
|
+
unienv_data/storages/pytorch.py
|
|
30
|
+
unienv_interface/__init__.py
|
|
31
|
+
unienv_interface/backends/__init__.py
|
|
32
|
+
unienv_interface/backends/base.py
|
|
33
|
+
unienv_interface/backends/jax.py
|
|
34
|
+
unienv_interface/backends/numpy.py
|
|
35
|
+
unienv_interface/backends/pytorch.py
|
|
36
|
+
unienv_interface/backends/serialization.py
|
|
37
|
+
unienv_interface/env_base/__init__.py
|
|
38
|
+
unienv_interface/env_base/env.py
|
|
39
|
+
unienv_interface/env_base/funcenv.py
|
|
40
|
+
unienv_interface/env_base/funcenv_wrapper.py
|
|
41
|
+
unienv_interface/env_base/wrapper.py
|
|
42
|
+
unienv_interface/func_wrapper/__init__.py
|
|
43
|
+
unienv_interface/func_wrapper/transformation.py
|
|
44
|
+
unienv_interface/space/__init__.py
|
|
45
|
+
unienv_interface/space/space.py
|
|
46
|
+
unienv_interface/space/space_utils/__init__.py
|
|
47
|
+
unienv_interface/space/space_utils/batch_utils.py
|
|
48
|
+
unienv_interface/space/space_utils/flatten_utils.py
|
|
49
|
+
unienv_interface/space/space_utils/gym_utils.py
|
|
50
|
+
unienv_interface/space/space_utils/serialization_utils.py
|
|
51
|
+
unienv_interface/space/spaces/__init__.py
|
|
52
|
+
unienv_interface/space/spaces/binary.py
|
|
53
|
+
unienv_interface/space/spaces/box.py
|
|
54
|
+
unienv_interface/space/spaces/dict.py
|
|
55
|
+
unienv_interface/space/spaces/dynamic_box.py
|
|
56
|
+
unienv_interface/space/spaces/graph.py
|
|
57
|
+
unienv_interface/space/spaces/text.py
|
|
58
|
+
unienv_interface/space/spaces/tuple.py
|
|
59
|
+
unienv_interface/space/spaces/union.py
|
|
60
|
+
unienv_interface/transformations/__init__.py
|
|
61
|
+
unienv_interface/transformations/batch_and_unbatch.py
|
|
62
|
+
unienv_interface/transformations/chained_transform.py
|
|
63
|
+
unienv_interface/transformations/dict_transform.py
|
|
64
|
+
unienv_interface/transformations/filter_dict.py
|
|
65
|
+
unienv_interface/transformations/rescale.py
|
|
66
|
+
unienv_interface/transformations/transformation.py
|
|
67
|
+
unienv_interface/utils/seed_util.py
|
|
68
|
+
unienv_interface/utils/symbol_util.py
|
|
69
|
+
unienv_interface/world/__init__.py
|
|
70
|
+
unienv_interface/world/funcnode.py
|
|
71
|
+
unienv_interface/world/funcworld.py
|
|
72
|
+
unienv_interface/world/node.py
|
|
73
|
+
unienv_interface/world/world.py
|
|
74
|
+
unienv_interface/wrapper/__init__.py
|
|
75
|
+
unienv_interface/wrapper/action_rescale.py
|
|
76
|
+
unienv_interface/wrapper/backend_compat.py
|
|
77
|
+
unienv_interface/wrapper/batch_and_unbatch.py
|
|
78
|
+
unienv_interface/wrapper/control_frequency_limit.py
|
|
79
|
+
unienv_interface/wrapper/flatten.py
|
|
80
|
+
unienv_interface/wrapper/frame_stack.py
|
|
81
|
+
unienv_interface/wrapper/gym_compat.py
|
|
82
|
+
unienv_interface/wrapper/time_limit.py
|
|
83
|
+
unienv_interface/wrapper/transformation.py
|
|
84
|
+
unienv_interface/wrapper/video_record.py
|
|
85
|
+
unienv_maniskill/__init__.py
|
|
86
|
+
unienv_maniskill/wrapper/maniskill_compat.py
|
|
87
|
+
unienv_mjxplayground/__init__.py
|
|
88
|
+
unienv_mjxplayground/wrapper/playground_compat.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
from typing import List, Tuple, Union, Dict, Any, Optional, Generic, TypeVar, Iterable, Iterator
|
|
2
|
+
from types import EllipsisType
|
|
3
|
+
import os
|
|
4
|
+
import abc
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
7
|
+
from unienv_interface.space import Space, BoxSpace, DictSpace
|
|
8
|
+
import dataclasses
|
|
9
|
+
|
|
10
|
+
from unienv_interface.space.space_utils import batch_utils as space_batch_utils, flatten_utils as space_flatten_utils
|
|
11
|
+
|
|
12
|
+
# Non-array index forms accepted by the batch accessors: one position, a slice, or `...`.
IndexableType = Union[int, slice, EllipsisType]

# Type of the (structured) values stored in / returned from a batch.
BatchT = TypeVar('BatchT')

class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """
    Abstract indexable collection of instances of ``single_space``.

    Subclasses supply ``__len__`` and ``get_flattened_at_with_metadata`` (and, for
    mutable batches, ``set_flattened_at`` / ``extend_flattened`` / ``remove_at``).
    This base class converts between the flat array representation and structured
    values via ``space_flatten_utils``.

    Indexing convention: a plain ``int`` index addresses a single element and is
    (un)flattened with ``single_space``; any other index (slice, ``...``, index
    array) is (un)flattened with the batched space starting at dim 1.
    """
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]
    device: Optional[BDeviceType] = None

    # If the batch is mutable, then the data can be changed (extend_*, set_*, remove_*, etc.)
    is_mutable: bool = True

    def __init__(
        self,
        single_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
        single_metadata_space : Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]] = None,
    ):
        """
        Args:
            single_space: space describing one element of the batch.
            single_metadata_space: optional dict space describing per-element metadata.
        """
        self.single_space = single_space
        self.single_metadata_space = single_metadata_space
        # Batch-size-1 version of `single_space`; used below as the template for
        # (un)flattening batched data with start_dim=1.
        # (Annotation fixed: `Space` takes 4 type parameters, not 5.)
        self._batched_space : Space[
            BatchT, BDeviceType, BDtypeType, BRNGType
        ] = space_batch_utils.batch_space(single_space, 1)
        # Batched metadata space; not used in this class itself
        # (presumably consumed by subclasses).
        self._batched_metadata_space : Optional[DictSpace[
            BDeviceType, BDtypeType, BRNGType
        ]] = (
            space_batch_utils.batch_space(single_metadata_space, 1)
            if single_metadata_space is not None
            else None
        )

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the number of elements currently held by the batch."""
        raise NotImplementedError

    # ---------- private (un)flatten helpers ----------

    def _unflatten_for_index(
        self, idx : Union[IndexableType, BArrayType], flattened_data : BArrayType
    ) -> BatchT:
        """Unflatten data fetched at `idx` using the single or the batched space."""
        if isinstance(idx, int):
            return space_flatten_utils.unflatten_data(self.single_space, flattened_data)
        return space_flatten_utils.unflatten_data(self._batched_space, flattened_data, start_dim=1)

    def _flatten_for_index(
        self, idx : Union[IndexableType, BArrayType], value : BatchT
    ) -> BArrayType:
        """Flatten `value` destined for `idx` using the single or the batched space."""
        if isinstance(idx, int):
            return space_flatten_utils.flatten_data(self.single_space, value)
        return space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)

    # ---------- flat accessors ----------

    def get_flattened_at(self, idx : Union[IndexableType, BArrayType]) -> BArrayType:
        """Return the flattened data at `idx`, discarding metadata."""
        return self.get_flattened_at_with_metadata(idx)[0]

    @abc.abstractmethod
    def get_flattened_at_with_metadata(
        self, idx : Union[IndexableType, BArrayType]
    ) -> Tuple[BArrayType, Optional[Dict[str, Any]]]:
        """Return ``(flattened_data, metadata)`` at `idx`; metadata may be None."""
        raise NotImplementedError

    def set_flattened_at(self, idx : Union[IndexableType, BArrayType], value : BArrayType) -> None:
        """Overwrite the element(s) at `idx` with flattened `value` (not supported by default)."""
        raise NotImplementedError

    def extend_flattened(self, value : BArrayType) -> None:
        """Append a flattened batch `value` (not supported by default)."""
        raise NotImplementedError

    # ---------- structured accessors ----------

    def get_at(self, idx : Union[IndexableType, BArrayType]) -> BatchT:
        """Return the structured value(s) at `idx`."""
        return self._unflatten_for_index(idx, self.get_flattened_at(idx))

    def get_at_with_metadata(
        self, idx : Union[IndexableType, BArrayType]
    ) -> Tuple[BatchT, Optional[Dict[str, Any]]]:
        """Return ``(structured_value, metadata)`` at `idx`; metadata is passed through as-is."""
        flattened_data, metadata = self.get_flattened_at_with_metadata(idx)
        return self._unflatten_for_index(idx, flattened_data), metadata

    def __getitem__(self, idx : Union[IndexableType, BArrayType]) -> BatchT:
        return self.get_at(idx)

    def set_at(self, idx : Union[IndexableType, BArrayType], value : BatchT) -> None:
        """Flatten `value` and store it at `idx` via `set_flattened_at`."""
        self.set_flattened_at(idx, self._flatten_for_index(idx, value))

    def __setitem__(self, idx : Union[IndexableType, BArrayType], value : BatchT) -> None:
        self.set_at(idx, value)

    def remove_at(self, idx : Union[IndexableType, BArrayType]) -> None:
        """Remove the element(s) at `idx` (not supported by default)."""
        raise NotImplementedError

    def __delitem__(self, idx : Union[IndexableType, BArrayType]) -> None:
        self.remove_at(idx)

    def extend(self, value : BatchT) -> None:
        """Append a structured batch `value`; it is always flattened with the batched space."""
        flattened_data = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
        self.extend_flattened(flattened_data)

    def close(self) -> None:
        """Release resources held by the batch; no-op by default."""
        pass

    def __del__(self) -> None:
        self.close()
|
|
102
|
+
|
|
103
|
+
SamplerBatchT = TypeVar('SamplerBatchT')
SamplerArrayType = TypeVar('SamplerArrayType')
SamplerDeviceType = TypeVar('SamplerDeviceType')
SamplerDtypeType = TypeVar('SamplerDtypeType')
SamplerRNGType = TypeVar('SamplerRNGType')

class BatchSampler(abc.ABC, Generic[
    SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType,
    BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType,
]):
    """
    Abstract sampler that draws groups of indices from ``data`` and returns the
    corresponding samples.

    Subclasses implement ``get_flat_at_with_metadata``; the structured variants
    unflatten its result with ``sampled_space`` (start_dim=1). Indices come either
    from ``sample_index`` (``batch_size`` independent uniform draws) or from one
    shuffled pass over the data (the ``epoch_*`` iterators).

    Randomness: ``data_rng`` is used (and advanced) when it is not None, otherwise
    ``rng`` is.
    """
    batch_size : int
    sampled_space : Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
    sampled_space_flat : BoxSpace[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
    sampled_metadata_space : Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]] = None

    backend : ComputeBackend[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
    device : Optional[SamplerDeviceType] = None

    data : BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]

    rng : Optional[SamplerRNGType] = None
    data_rng : Optional[BRNGType] = None

    def get_flat_at(self, idx : SamplerArrayType) -> SamplerArrayType:
        """Return the flattened samples at `idx`, discarding metadata."""
        return self.get_flat_at_with_metadata(idx)[0]

    @abc.abstractmethod
    def get_flat_at_with_metadata(
        self, idx : SamplerArrayType
    ) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
        """Return ``(flattened_samples, metadata)`` at `idx`; metadata may be None."""
        raise NotImplementedError

    def get_at(self, idx : SamplerArrayType) -> SamplerBatchT:
        """Return the structured samples at `idx`."""
        return space_flatten_utils.unflatten_data(self.sampled_space, self.get_flat_at(idx), start_dim=1)

    def get_at_with_metadata(
        self, idx : SamplerArrayType
    ) -> Tuple[SamplerBatchT, Optional[Dict[str, Any]]]:
        """Return ``(structured_samples, metadata)`` at `idx`."""
        flat_data, metadata = self.get_flat_at_with_metadata(idx)
        return space_flatten_utils.unflatten_data(self.sampled_space, flat_data, start_dim=1), metadata

    def sample_index(self) -> SamplerArrayType:
        """Draw `batch_size` indices uniformly from ``[0, len(data))`` on the data's device."""
        new_rng, indices = self.backend.random.random_discrete_uniform( # (B, )
            (self.batch_size,),
            0,
            len(self.data),
            rng=self.data_rng if self.data_rng is not None else self.rng,
            device=self.data.device,
        )
        # Store the advanced RNG back into whichever slot supplied it.
        if self.data_rng is not None:
            self.data_rng = new_rng
        else:
            self.rng = new_rng
        return indices

    def sample_flat(self) -> SamplerArrayType:
        return self.get_flat_at(self.sample_index())

    def sample_flat_with_metadata(self) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
        return self.get_flat_at_with_metadata(self.sample_index())

    def sample(self) -> SamplerBatchT:
        return self.get_at(self.sample_index())

    def sample_with_metadata(self) -> Tuple[SamplerBatchT, Optional[Dict[str, Any]]]:
        return self.get_at_with_metadata(self.sample_index())

    def __iter__(self) -> Iterator[SamplerBatchT]:
        return self.epoch_iter()

    # ---------- epoch iteration ----------

    def _epoch_permutation(self, num_elements : int) -> SamplerArrayType:
        """Return a random permutation of ``range(num_elements)``, advancing the RNG in use."""
        if self.data_rng is not None:
            self.data_rng, perm = self.backend.random.random_permutation(
                num_elements, rng=self.data_rng, device=self.data.device
            )
        else:
            self.rng, perm = self.backend.random.random_permutation(
                num_elements, rng=self.rng, device=self.data.device
            )
        return perm

    def _epoch_index_batches(self) -> Iterator[SamplerArrayType]:
        """
        Yield index chunks covering one shuffled pass over ``data``.

        Every chunk has ``batch_size`` indices except possibly the last, which holds
        the remainder.

        Raises:
            ValueError: if ``batch_size`` is not positive. Previously this surfaced as
                a ZeroDivisionError for 0, or as a silently empty epoch for negative values.
        """
        batch_size = self.batch_size
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        num_elements = len(self.data)
        perm = self._epoch_permutation(num_elements)
        n_full, num_left = divmod(num_elements, batch_size)
        for i in range(n_full):
            yield perm[i * batch_size:(i + 1) * batch_size]
        if num_left > 0:
            yield perm[-num_left:]

    def epoch_iter(self) -> Iterator[SamplerBatchT]:
        """Iterate structured batches over one shuffled pass of the data."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_at(batch_idx)

    def epoch_iter_with_metadata(self) -> Iterator[Tuple[SamplerBatchT, Optional[Dict[str, Any]]]]:
        """Like `epoch_iter`, but yields ``(batch, metadata)`` pairs."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_at_with_metadata(batch_idx)

    def epoch_flat_iter(self) -> Iterator[SamplerArrayType]:
        """Like `epoch_iter`, but yields flattened batches."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_flat_at(batch_idx)

    def epoch_flat_iter_with_metadata(self) -> Iterator[Tuple[SamplerArrayType, Optional[Dict[str, Any]]]]:
        """Like `epoch_flat_iter`, but yields ``(flat_batch, metadata)`` pairs."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_flat_at_with_metadata(batch_idx)

    def close(self) -> None:
        """Release resources held by the sampler; no-op by default."""
        pass

    def __del__(self) -> None:
        self.close()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import os
|
|
3
|
+
from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type
|
|
4
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
5
|
+
|
|
6
|
+
from unienv_interface.space import Space
|
|
7
|
+
from .common import BatchBase, BatchT, IndexableType
|
|
8
|
+
|
|
9
|
+
class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """
    Abstract base for storages that hold instances of one particular space.

    Defines the shared interface for creating, loading, reading, writing and
    dumping stored instances of ``single_instance_space``.

    Multiprocessing note: a storage that should work across processes must pickle
    correctly; check or implement ``__getstate__`` / ``__setstate__`` accordingly.
    """

    # ---------------- construction / loading (provided by subclasses) ----------------

    @classmethod
    def create(
        cls,
        single_instance_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
        *args,
        capacity : Optional[int],
        cache_path : Optional[Union[str, os.PathLike]] = None,
        **kwargs
    ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """
        Build a storage for instances of `single_instance_space`.

        `capacity` is keyword-only and must be given; None means unlimited capacity.
        Not implemented in the base class.
        """
        raise NotImplementedError

    @classmethod
    def load_from(
        cls,
        path: Union[str, os.PathLike],
        single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
        *,
        capacity : Optional[int] = None,
        **kwargs
    ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """Load a storage previously written to `path`. Not implemented in the base class."""
        raise NotImplementedError

    # ---------------- class attributes ----------------

    # File extension (e.g. `.pt`) used when saving; None means the storage keeps
    # its files in a folder instead.
    single_file_ext : Optional[str] = None

    # ---------------- instance attributes ----------------

    # Total number of single instances the storage can hold; None = unlimited.
    capacity : Optional[int] = None

    # Cache path of the storage; None disables caching.
    cache_filename : Optional[Union[str, os.PathLike]] = None

    @property
    def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        """Compute backend, delegated to `single_instance_space`."""
        return self.single_instance_space.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Device, delegated to `single_instance_space`."""
        return self.single_instance_space.device

    def __init__(
        self,
        single_instance_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
    ):
        self.single_instance_space = single_instance_space

    def extend_length(self, length : int) -> None:
        """
        Grow an unlimited-capacity (capacity = None) storage by `length` entries.

        Storages with a fixed capacity ignore this call (the base implementation does nothing).
        """

    def shrink_length(self, length : int) -> None:
        """
        Shrink an unlimited-capacity (capacity = None) storage by `length` entries.

        Storages with a fixed capacity ignore this call (the base implementation does nothing).
        """

    def __len__(self) -> int:
        """
        Number of instances in the storage.

        Fixed-capacity storages report their capacity; unlimited storages must
        override this to report their current length.
        """
        if self.capacity is not None:
            return self.capacity
        raise NotImplementedError(f"__len__ is not implemented for class {type(self).__name__}")

    # Optional flat accessors are intentionally not declared here: `ReplayBuffer`
    # detects them with hasattr(self, "get_flattened") / hasattr(self, "set_flattened").
    # Expected signatures when implemented:
    #   get_flattened(self, index : Union[IndexableType, BArrayType]) -> BArrayType
    #   set_flattened(self, index : Union[IndexableType, BArrayType], value : BArrayType) -> None

    @abc.abstractmethod
    def get(self, index : Union[IndexableType, BArrayType]) -> BatchT:
        """Read the instance(s) at `index`."""
        raise NotImplementedError

    @abc.abstractmethod
    def set(self, index : Union[IndexableType, BArrayType], value : BatchT) -> None:
        """Write `value` to the instance slot(s) at `index`."""
        raise NotImplementedError

    def clear(self) -> None:
        """
        Drop all stored data.

        Unlimited-capacity storages are shrunk to length 0 here; fixed-capacity
        storages should override this to reset themselves to their initial state.
        """
        if self.capacity is not None:
            return
        self.shrink_length(len(self))

    @abc.abstractmethod
    def dumps(self, path : Union[str, os.PathLike]) -> None:
        """
        Dump the storage to `path`.

        Used by storages that have a single file extension (e.g. `.pt` for PyTorch).
        """
        raise NotImplementedError

    def close(self) -> None:
        """Release resources held by the storage; no-op by default."""

    def __del__(self) -> None:
        self.close()
|