unienv 0.0.1b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. unienv-0.0.1b1/LICENSE +23 -0
  2. unienv-0.0.1b1/PKG-INFO +20 -0
  3. unienv-0.0.1b1/README.md +46 -0
  4. unienv-0.0.1b1/pyproject.toml +31 -0
  5. unienv-0.0.1b1/setup.cfg +4 -0
  6. unienv-0.0.1b1/unienv.egg-info/PKG-INFO +20 -0
  7. unienv-0.0.1b1/unienv.egg-info/SOURCES.txt +88 -0
  8. unienv-0.0.1b1/unienv.egg-info/dependency_links.txt +1 -0
  9. unienv-0.0.1b1/unienv.egg-info/requires.txt +19 -0
  10. unienv-0.0.1b1/unienv.egg-info/top_level.txt +6 -0
  11. unienv-0.0.1b1/unienv_data/__init__.py +4 -0
  12. unienv-0.0.1b1/unienv_data/base/__init__.py +3 -0
  13. unienv-0.0.1b1/unienv_data/base/common.py +228 -0
  14. unienv-0.0.1b1/unienv_data/base/storage.py +132 -0
  15. unienv-0.0.1b1/unienv_data/base/transformations.py +144 -0
  16. unienv-0.0.1b1/unienv_data/batches/__init__.py +4 -0
  17. unienv-0.0.1b1/unienv_data/batches/backend_compat.py +181 -0
  18. unienv-0.0.1b1/unienv_data/batches/combined_batch.py +361 -0
  19. unienv-0.0.1b1/unienv_data/batches/framestack_batch.py +51 -0
  20. unienv-0.0.1b1/unienv_data/batches/slicestack_batch.py +428 -0
  21. unienv-0.0.1b1/unienv_data/integrations/pytorch.py +63 -0
  22. unienv-0.0.1b1/unienv_data/replay_buffer/__init__.py +2 -0
  23. unienv-0.0.1b1/unienv_data/replay_buffer/replay_buffer.py +263 -0
  24. unienv-0.0.1b1/unienv_data/replay_buffer/trajectory_replay_buffer.py +479 -0
  25. unienv-0.0.1b1/unienv_data/samplers/__init__.py +3 -0
  26. unienv-0.0.1b1/unienv_data/samplers/multiprocessing_sampler.py +388 -0
  27. unienv-0.0.1b1/unienv_data/samplers/slice_sampler.py +266 -0
  28. unienv-0.0.1b1/unienv_data/samplers/step_sampler.py +77 -0
  29. unienv-0.0.1b1/unienv_data/storages/common.py +156 -0
  30. unienv-0.0.1b1/unienv_data/storages/hdf5.py +376 -0
  31. unienv-0.0.1b1/unienv_data/storages/pytorch.py +161 -0
  32. unienv-0.0.1b1/unienv_interface/__init__.py +4 -0
  33. unienv-0.0.1b1/unienv_interface/backends/__init__.py +3 -0
  34. unienv-0.0.1b1/unienv_interface/backends/base.py +1 -0
  35. unienv-0.0.1b1/unienv_interface/backends/jax.py +18 -0
  36. unienv-0.0.1b1/unienv_interface/backends/numpy.py +19 -0
  37. unienv-0.0.1b1/unienv_interface/backends/pytorch.py +19 -0
  38. unienv-0.0.1b1/unienv_interface/backends/serialization.py +70 -0
  39. unienv-0.0.1b1/unienv_interface/env_base/__init__.py +4 -0
  40. unienv-0.0.1b1/unienv_interface/env_base/env.py +143 -0
  41. unienv-0.0.1b1/unienv_interface/env_base/funcenv.py +308 -0
  42. unienv-0.0.1b1/unienv_interface/env_base/funcenv_wrapper.py +231 -0
  43. unienv-0.0.1b1/unienv_interface/env_base/wrapper.py +286 -0
  44. unienv-0.0.1b1/unienv_interface/func_wrapper/__init__.py +1 -0
  45. unienv-0.0.1b1/unienv_interface/func_wrapper/transformation.py +127 -0
  46. unienv-0.0.1b1/unienv_interface/space/__init__.py +3 -0
  47. unienv-0.0.1b1/unienv_interface/space/space.py +119 -0
  48. unienv-0.0.1b1/unienv_interface/space/space_utils/__init__.py +3 -0
  49. unienv-0.0.1b1/unienv_interface/space/space_utils/batch_utils.py +867 -0
  50. unienv-0.0.1b1/unienv_interface/space/space_utils/flatten_utils.py +306 -0
  51. unienv-0.0.1b1/unienv_interface/space/space_utils/gym_utils.py +450 -0
  52. unienv-0.0.1b1/unienv_interface/space/space_utils/serialization_utils.py +221 -0
  53. unienv-0.0.1b1/unienv_interface/space/spaces/__init__.py +22 -0
  54. unienv-0.0.1b1/unienv_interface/space/spaces/binary.py +100 -0
  55. unienv-0.0.1b1/unienv_interface/space/spaces/box.py +324 -0
  56. unienv-0.0.1b1/unienv_interface/space/spaces/dict.py +180 -0
  57. unienv-0.0.1b1/unienv_interface/space/spaces/dynamic_box.py +452 -0
  58. unienv-0.0.1b1/unienv_interface/space/spaces/graph.py +309 -0
  59. unienv-0.0.1b1/unienv_interface/space/spaces/text.py +130 -0
  60. unienv-0.0.1b1/unienv_interface/space/spaces/tuple.py +110 -0
  61. unienv-0.0.1b1/unienv_interface/space/spaces/union.py +115 -0
  62. unienv-0.0.1b1/unienv_interface/transformations/__init__.py +6 -0
  63. unienv-0.0.1b1/unienv_interface/transformations/batch_and_unbatch.py +36 -0
  64. unienv-0.0.1b1/unienv_interface/transformations/chained_transform.py +63 -0
  65. unienv-0.0.1b1/unienv_interface/transformations/dict_transform.py +127 -0
  66. unienv-0.0.1b1/unienv_interface/transformations/filter_dict.py +171 -0
  67. unienv-0.0.1b1/unienv_interface/transformations/rescale.py +98 -0
  68. unienv-0.0.1b1/unienv_interface/transformations/transformation.py +46 -0
  69. unienv-0.0.1b1/unienv_interface/utils/seed_util.py +21 -0
  70. unienv-0.0.1b1/unienv_interface/utils/symbol_util.py +13 -0
  71. unienv-0.0.1b1/unienv_interface/world/__init__.py +4 -0
  72. unienv-0.0.1b1/unienv_interface/world/funcnode.py +210 -0
  73. unienv-0.0.1b1/unienv_interface/world/funcworld.py +62 -0
  74. unienv-0.0.1b1/unienv_interface/world/node.py +148 -0
  75. unienv-0.0.1b1/unienv_interface/world/world.py +167 -0
  76. unienv-0.0.1b1/unienv_interface/wrapper/__init__.py +9 -0
  77. unienv-0.0.1b1/unienv_interface/wrapper/action_rescale.py +31 -0
  78. unienv-0.0.1b1/unienv_interface/wrapper/backend_compat.py +187 -0
  79. unienv-0.0.1b1/unienv_interface/wrapper/batch_and_unbatch.py +84 -0
  80. unienv-0.0.1b1/unienv_interface/wrapper/control_frequency_limit.py +63 -0
  81. unienv-0.0.1b1/unienv_interface/wrapper/flatten.py +145 -0
  82. unienv-0.0.1b1/unienv_interface/wrapper/frame_stack.py +242 -0
  83. unienv-0.0.1b1/unienv_interface/wrapper/gym_compat.py +299 -0
  84. unienv-0.0.1b1/unienv_interface/wrapper/time_limit.py +77 -0
  85. unienv-0.0.1b1/unienv_interface/wrapper/transformation.py +187 -0
  86. unienv-0.0.1b1/unienv_interface/wrapper/video_record.py +235 -0
  87. unienv-0.0.1b1/unienv_maniskill/__init__.py +1 -0
  88. unienv-0.0.1b1/unienv_maniskill/wrapper/maniskill_compat.py +235 -0
  89. unienv-0.0.1b1/unienv_mjxplayground/__init__.py +1 -0
  90. unienv-0.0.1b1/unienv_mjxplayground/wrapper/playground_compat.py +256 -0
unienv-0.0.1b1/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2016 OpenAI
4
+ Copyright (c) 2022 Farama Foundation
5
+ Copyright (c) 2024 Yunhao Cao
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.4
2
+ Name: unienv
3
+ Version: 0.0.1b1
4
+ Requires-Python: >=3.10
5
+ License-File: LICENSE
6
+ Requires-Dist: numpy
7
+ Requires-Dist: xbarray>=0.0.1a8
8
+ Requires-Dist: pillow
9
+ Requires-Dist: h5py
10
+ Provides-Extra: dev
11
+ Requires-Dist: pytest; extra == "dev"
12
+ Provides-Extra: gymnasium
13
+ Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
14
+ Provides-Extra: video
15
+ Requires-Dist: moviepy>=2.1; extra == "video"
16
+ Provides-Extra: mjx
17
+ Requires-Dist: playground; extra == "mjx"
18
+ Provides-Extra: maniskill
19
+ Requires-Dist: mani_skill>=3.0.0b12; extra == "maniskill"
20
+ Dynamic: license-file
@@ -0,0 +1,46 @@
1
+ # UniEnvPy
2
+
3
+ TLDR: Gymnasium Library replacement with support for multiple tensor backends
4
+
5
+ Provides a universal interface for single / parallel state-based or function-based environments. Also contains a set of utilities (such as replay buffers, wrappers, etc.) to facilitate the training of reinforcement learning agents.
6
+
7
+ ## Cross-backend Support
8
+
9
+ UniEnvPy supports multiple tensor backends with zero-copy translation layers through the DLPack protocol, and allows you to use the same abstract compute backend interface to write custom data transformation layers, environment wrappers and other utilities. This is powered by the [xbarray](https://github.com/realquantumcookie/xbarray) package, which builds on top of the Array API Standard, and supports the following backends:
10
+
11
+ - numpy
12
+ - pytorch
13
+ - jax
14
+
15
+ We also support diverse simulation environments and real robots, built on top of the abstract environment / world interface. This allows you to reuse code across different sim and real robots.
16
+
17
+ Current supported simulation environments:
18
+ - Any Environment defined in Gymnasium interface
19
+ - <s>Mujoco</s> (New code will be added in the future, but I'm currently working on refactoring World-based environments)
20
+ - MJX based on [Mujoco-Playground](https://github.com/google-deepmind/mujoco_playground)
21
+ - [ManiSkill 3](https://github.com/haosulab/ManiSkill/)
22
+
23
+ Current supported real robots:
24
+ - Franka Research 3 + RobotiQ Gripper in Droid Setup
25
+ - OyMotion OHand
26
+
27
+ ## Installation
28
+
29
+ Clone down this repo
30
+
31
+ ```bash
32
+ git clone https://github.com/realquantumcookie/UniEnvPy
33
+ cd UniEnvPy
34
+ ```
35
+
36
+ Install the package with pip
37
+
38
+ ```bash
39
+ pip install -e .
40
+ ```
41
+
42
+ You can install optional dependencies such as `gymnasium`, `mjx`, `maniskill`, `video` by running
43
+
44
+ ```bash
45
+ pip install -e .[gymnasium,mjx,maniskill,video]
46
+ ```
@@ -0,0 +1,31 @@
1
+ [project]
2
+ name = "unienv"
3
+ version = "0.0.1b1"
4
+ requires-python = ">= 3.10"
5
+ dependencies = [
6
+ "numpy",
7
+ "xbarray>=0.0.1a8",
8
+ "pillow",
9
+ "h5py",
10
+ ]
11
+
12
+ [project.optional-dependencies]
13
+ dev = [
14
+ "pytest",
15
+ ]
16
+ gymnasium = [
17
+ "gymnasium>=0.29.0",
18
+ ]
19
+ video = [
20
+ "moviepy>=2.1"
21
+ ]
22
+ mjx = [
23
+ "playground",
24
+ ]
25
+ maniskill = [
26
+ "mani_skill>=3.0.0b12"
27
+ ]
28
+
29
+ [tool.setuptools.packages.find]
30
+ include = ["*"]
31
+ exclude = ["training*", "tests*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.4
2
+ Name: unienv
3
+ Version: 0.0.1b1
4
+ Requires-Python: >=3.10
5
+ License-File: LICENSE
6
+ Requires-Dist: numpy
7
+ Requires-Dist: xbarray>=0.0.1a8
8
+ Requires-Dist: pillow
9
+ Requires-Dist: h5py
10
+ Provides-Extra: dev
11
+ Requires-Dist: pytest; extra == "dev"
12
+ Provides-Extra: gymnasium
13
+ Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
14
+ Provides-Extra: video
15
+ Requires-Dist: moviepy>=2.1; extra == "video"
16
+ Provides-Extra: mjx
17
+ Requires-Dist: playground; extra == "mjx"
18
+ Provides-Extra: maniskill
19
+ Requires-Dist: mani_skill>=3.0.0b12; extra == "maniskill"
20
+ Dynamic: license-file
@@ -0,0 +1,88 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ unienv.egg-info/PKG-INFO
5
+ unienv.egg-info/SOURCES.txt
6
+ unienv.egg-info/dependency_links.txt
7
+ unienv.egg-info/requires.txt
8
+ unienv.egg-info/top_level.txt
9
+ unienv_data/__init__.py
10
+ unienv_data/base/__init__.py
11
+ unienv_data/base/common.py
12
+ unienv_data/base/storage.py
13
+ unienv_data/base/transformations.py
14
+ unienv_data/batches/__init__.py
15
+ unienv_data/batches/backend_compat.py
16
+ unienv_data/batches/combined_batch.py
17
+ unienv_data/batches/framestack_batch.py
18
+ unienv_data/batches/slicestack_batch.py
19
+ unienv_data/integrations/pytorch.py
20
+ unienv_data/replay_buffer/__init__.py
21
+ unienv_data/replay_buffer/replay_buffer.py
22
+ unienv_data/replay_buffer/trajectory_replay_buffer.py
23
+ unienv_data/samplers/__init__.py
24
+ unienv_data/samplers/multiprocessing_sampler.py
25
+ unienv_data/samplers/slice_sampler.py
26
+ unienv_data/samplers/step_sampler.py
27
+ unienv_data/storages/common.py
28
+ unienv_data/storages/hdf5.py
29
+ unienv_data/storages/pytorch.py
30
+ unienv_interface/__init__.py
31
+ unienv_interface/backends/__init__.py
32
+ unienv_interface/backends/base.py
33
+ unienv_interface/backends/jax.py
34
+ unienv_interface/backends/numpy.py
35
+ unienv_interface/backends/pytorch.py
36
+ unienv_interface/backends/serialization.py
37
+ unienv_interface/env_base/__init__.py
38
+ unienv_interface/env_base/env.py
39
+ unienv_interface/env_base/funcenv.py
40
+ unienv_interface/env_base/funcenv_wrapper.py
41
+ unienv_interface/env_base/wrapper.py
42
+ unienv_interface/func_wrapper/__init__.py
43
+ unienv_interface/func_wrapper/transformation.py
44
+ unienv_interface/space/__init__.py
45
+ unienv_interface/space/space.py
46
+ unienv_interface/space/space_utils/__init__.py
47
+ unienv_interface/space/space_utils/batch_utils.py
48
+ unienv_interface/space/space_utils/flatten_utils.py
49
+ unienv_interface/space/space_utils/gym_utils.py
50
+ unienv_interface/space/space_utils/serialization_utils.py
51
+ unienv_interface/space/spaces/__init__.py
52
+ unienv_interface/space/spaces/binary.py
53
+ unienv_interface/space/spaces/box.py
54
+ unienv_interface/space/spaces/dict.py
55
+ unienv_interface/space/spaces/dynamic_box.py
56
+ unienv_interface/space/spaces/graph.py
57
+ unienv_interface/space/spaces/text.py
58
+ unienv_interface/space/spaces/tuple.py
59
+ unienv_interface/space/spaces/union.py
60
+ unienv_interface/transformations/__init__.py
61
+ unienv_interface/transformations/batch_and_unbatch.py
62
+ unienv_interface/transformations/chained_transform.py
63
+ unienv_interface/transformations/dict_transform.py
64
+ unienv_interface/transformations/filter_dict.py
65
+ unienv_interface/transformations/rescale.py
66
+ unienv_interface/transformations/transformation.py
67
+ unienv_interface/utils/seed_util.py
68
+ unienv_interface/utils/symbol_util.py
69
+ unienv_interface/world/__init__.py
70
+ unienv_interface/world/funcnode.py
71
+ unienv_interface/world/funcworld.py
72
+ unienv_interface/world/node.py
73
+ unienv_interface/world/world.py
74
+ unienv_interface/wrapper/__init__.py
75
+ unienv_interface/wrapper/action_rescale.py
76
+ unienv_interface/wrapper/backend_compat.py
77
+ unienv_interface/wrapper/batch_and_unbatch.py
78
+ unienv_interface/wrapper/control_frequency_limit.py
79
+ unienv_interface/wrapper/flatten.py
80
+ unienv_interface/wrapper/frame_stack.py
81
+ unienv_interface/wrapper/gym_compat.py
82
+ unienv_interface/wrapper/time_limit.py
83
+ unienv_interface/wrapper/transformation.py
84
+ unienv_interface/wrapper/video_record.py
85
+ unienv_maniskill/__init__.py
86
+ unienv_maniskill/wrapper/maniskill_compat.py
87
+ unienv_mjxplayground/__init__.py
88
+ unienv_mjxplayground/wrapper/playground_compat.py
@@ -0,0 +1,19 @@
1
+ numpy
2
+ xbarray>=0.0.1a8
3
+ pillow
4
+ h5py
5
+
6
+ [dev]
7
+ pytest
8
+
9
+ [gymnasium]
10
+ gymnasium>=0.29.0
11
+
12
+ [maniskill]
13
+ mani_skill>=3.0.0b12
14
+
15
+ [mjx]
16
+ playground
17
+
18
+ [video]
19
+ moviepy>=2.1
@@ -0,0 +1,6 @@
1
+ dist
2
+ unienv_data
3
+ unienv_interface
4
+ unienv_maniskill
5
+ unienv_mjxplayground
6
+ unienv_mujoco
@@ -0,0 +1,4 @@
1
+ from .base import *
2
+ from .batches import *
3
+ from .samplers import *
4
+ from .replay_buffer import *
@@ -0,0 +1,3 @@
1
+ from .common import BatchT, SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType, BatchBase, BatchSampler, IndexableType
2
+ from .transformations import TransformedBatch
3
+ from .storage import SpaceStorage
@@ -0,0 +1,228 @@
1
+ from typing import List, Tuple, Union, Dict, Any, Optional, Generic, TypeVar, Iterable, Iterator
2
+ from types import EllipsisType
3
+ import os
4
+ import abc
5
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
6
+ from unienv_interface.env_base.env import ContextType, ObsType, ActType
7
+ from unienv_interface.space import Space, BoxSpace, DictSpace
8
+ import dataclasses
9
+
10
+ from unienv_interface.space.space_utils import batch_utils as space_batch_utils, flatten_utils as space_flatten_utils
11
+
12
# Index kinds accepted by batch accessors (backend arrays are accepted in addition, see signatures).
IndexableType = Union[int, slice, EllipsisType]

# Structured (unflattened) value type held by a batch.
BatchT = TypeVar('BatchT')


class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """Abstract indexable collection of samples that all belong to one space.

    Subclasses must implement ``__len__`` and ``get_flattened_at_with_metadata``.
    The mutation hooks (``set_flattened_at``, ``extend_flattened``, ``remove_at``)
    raise ``NotImplementedError`` here and are meant to be overridden by batches
    that support them. The structured accessors (``get_at``, ``set_at``,
    ``extend`` and the ``[]`` operators) are built on top of the flattened ones.
    """

    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]
    device: Optional[BDeviceType] = None

    # When True the stored data may be changed (extend_*, set_*, remove_*, etc.).
    is_mutable: bool = True

    def __init__(
        self,
        single_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
        single_metadata_space : Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]] = None,
    ):
        """Record the per-sample spaces and derive their batched counterparts.

        Args:
            single_space: Space describing one sample.
            single_metadata_space: Optional space describing one sample's metadata.
        """
        self.single_space = single_space
        self.single_metadata_space = single_metadata_space
        # NOTE(review): this annotation passes five type arguments to Space while
        # other uses pass four -- confirm against the Space definition.
        self._batched_space : Space[
            BatchT, Any, BDeviceType, BDtypeType, BRNGType
        ] = space_batch_utils.batch_space(single_space, 1)
        self._batched_metadata_space : Optional[DictSpace[
            BDeviceType, BDtypeType, BRNGType
        ]] = (
            None
            if single_metadata_space is None
            else space_batch_utils.batch_space(single_metadata_space, 1)
        )

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the number of samples in the batch."""
        raise NotImplementedError

    def get_flattened_at(self, idx : Union[IndexableType, BArrayType]) -> BArrayType:
        """Return only the flattened data at ``idx``, discarding the metadata."""
        flattened, _metadata = self.get_flattened_at_with_metadata(idx)
        return flattened

    @abc.abstractmethod
    def get_flattened_at_with_metadata(
        self, idx : Union[IndexableType, BArrayType]
    ) -> Tuple[BArrayType, Optional[Dict[str, Any]]]:
        """Return ``(flattened_data, metadata)`` for ``idx``; metadata may be None."""
        raise NotImplementedError

    def set_flattened_at(self, idx : Union[IndexableType, BArrayType], value : BArrayType) -> None:
        """Overwrite the flattened data at ``idx``; not implemented by the base class."""
        raise NotImplementedError

    def extend_flattened(self, value : BArrayType) -> None:
        """Append flattened data to the batch; not implemented by the base class."""
        raise NotImplementedError

    def get_at(self, idx : Union[IndexableType, BArrayType]) -> BatchT:
        """Return the structured data at ``idx``.

        A plain ``int`` index is unflattened against ``single_space``; every
        other index is unflattened against the batched space with ``start_dim=1``.
        """
        flat = self.get_flattened_at(idx)
        if not isinstance(idx, int):
            return space_flatten_utils.unflatten_data(self._batched_space, flat, start_dim=1)
        return space_flatten_utils.unflatten_data(self.single_space, flat)

    def get_at_with_metadata(
        self, idx : Union[IndexableType, BArrayType]
    ) -> Tuple[BatchT, Optional[Dict[str, Any]]]:
        """Same as ``get_at`` but also returns the metadata reported for ``idx``."""
        flat, metadata = self.get_flattened_at_with_metadata(idx)
        if not isinstance(idx, int):
            structured = space_flatten_utils.unflatten_data(self._batched_space, flat, start_dim=1)
        else:
            structured = space_flatten_utils.unflatten_data(self.single_space, flat)
        return structured, metadata

    def __getitem__(self, idx : Union[IndexableType, BArrayType]) -> BatchT:
        """``batch[idx]`` is shorthand for ``batch.get_at(idx)``."""
        return self.get_at(idx)

    def set_at(self, idx : Union[IndexableType, BArrayType], value : BatchT) -> None:
        """Flatten ``value`` and store it at ``idx`` through ``set_flattened_at``.

        A plain ``int`` index flattens against ``single_space``; every other
        index flattens against the batched space with ``start_dim=1``.
        """
        if not isinstance(idx, int):
            flat = space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
        else:
            flat = space_flatten_utils.flatten_data(self.single_space, value)
        self.set_flattened_at(idx, flat)

    def __setitem__(self, idx : Union[IndexableType, BArrayType], value : BatchT) -> None:
        """``batch[idx] = value`` is shorthand for ``batch.set_at(idx, value)``."""
        self.set_at(idx, value)

    def remove_at(self, idx : Union[IndexableType, BArrayType]) -> None:
        """Remove the data at ``idx``; not implemented by the base class."""
        raise NotImplementedError

    def __delitem__(self, idx : Union[IndexableType, BArrayType]) -> None:
        """``del batch[idx]`` is shorthand for ``batch.remove_at(idx)``."""
        self.remove_at(idx)

    def extend(self, value : BatchT) -> None:
        """Flatten batched ``value`` (``start_dim=1``) and append it via ``extend_flattened``."""
        self.extend_flattened(
            space_flatten_utils.flatten_data(self._batched_space, value, start_dim=1)
        )

    def close(self) -> None:
        """Release resources held by the batch; the base implementation does nothing."""
        pass

    def __del__(self) -> None:
        # Finalization delegates to close() so subclasses only override one hook.
        self.close()
102
+
103
# Type parameters describing the sampler's own output side (which may differ
# from the type parameters of the underlying data batch).
SamplerBatchT = TypeVar('SamplerBatchT')
SamplerArrayType = TypeVar('SamplerArrayType')
SamplerDeviceType = TypeVar('SamplerDeviceType')
SamplerDtypeType = TypeVar('SamplerDtypeType')
SamplerRNGType = TypeVar('SamplerRNGType')


class BatchSampler(abc.ABC, Generic[
    SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType,
    BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType,
]):
    """Abstract sampler that draws index batches from a ``BatchBase`` and fetches them.

    Subclasses must implement ``get_flat_at_with_metadata``; every other
    accessor, the random sampling helpers and the epoch iterators are derived
    from it.

    Random state: ``data_rng`` is used (and updated) when it is not None,
    otherwise ``rng`` is used (and updated).
    """

    batch_size : int
    sampled_space : Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
    sampled_space_flat : BoxSpace[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
    sampled_metadata_space : Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]] = None

    backend : ComputeBackend[SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]
    device : Optional[SamplerDeviceType] = None

    data : BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]

    rng : Optional[SamplerRNGType] = None
    data_rng : Optional[BRNGType] = None

    def get_flat_at(self, idx : SamplerArrayType) -> SamplerArrayType:
        """Return only the flattened samples at ``idx``, discarding the metadata."""
        return self.get_flat_at_with_metadata(idx)[0]

    @abc.abstractmethod
    def get_flat_at_with_metadata(
        self, idx : SamplerArrayType
    ) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
        """Return ``(flattened_samples, metadata)`` for the index array ``idx``."""
        raise NotImplementedError

    def get_at(self, idx : SamplerArrayType) -> SamplerBatchT:
        """Return the samples at ``idx`` unflattened against ``sampled_space`` (``start_dim=1``)."""
        return space_flatten_utils.unflatten_data(self.sampled_space, self.get_flat_at(idx), start_dim=1)

    def get_at_with_metadata(
        self, idx : SamplerArrayType
    ) -> Tuple[SamplerBatchT, Optional[Dict[str, Any]]]:
        """Same as ``get_at`` but also returns the metadata reported for ``idx``."""
        flat_data, metadata = self.get_flat_at_with_metadata(idx)
        return space_flatten_utils.unflatten_data(self.sampled_space, flat_data, start_dim=1), metadata

    def sample_index(self) -> SamplerArrayType:
        """Draw ``batch_size`` random indices into ``self.data``.

        The indices come from ``backend.random.random_discrete_uniform`` called
        with bounds ``0`` and ``len(self.data)`` on the data's device. The
        random state that was consumed is replaced by the returned one.
        """
        new_rng, indices = self.backend.random.random_discrete_uniform( # (B, )
            (self.batch_size,),
            0,
            len(self.data),
            rng=self.data_rng if self.data_rng is not None else self.rng,
            device=self.data.device,
        )
        if self.data_rng is not None:
            self.data_rng = new_rng
        else:
            self.rng = new_rng
        return indices

    def sample_flat(self) -> SamplerArrayType:
        """Sample one random batch and return it flattened."""
        idx = self.sample_index()
        return self.get_flat_at(idx)

    def sample_flat_with_metadata(self) -> Tuple[SamplerArrayType, Optional[Dict[str, Any]]]:
        """Sample one random batch and return it flattened together with its metadata."""
        idx = self.sample_index()
        return self.get_flat_at_with_metadata(idx)

    def sample(self) -> SamplerBatchT:
        """Sample one random batch and return it in structured form."""
        idx = self.sample_index()
        return self.get_at(idx)

    def sample_with_metadata(self) -> Tuple[SamplerBatchT, Optional[Dict[str, Any]]]:
        """Sample one random batch and return it in structured form with its metadata."""
        idx = self.sample_index()
        return self.get_at_with_metadata(idx)

    def __iter__(self) -> Iterator[SamplerBatchT]:
        """Iterating the sampler is equivalent to ``epoch_iter()``."""
        return self.epoch_iter()

    def _epoch_index_batches(self) -> Iterator[SamplerArrayType]:
        """Yield the index slices of one epoch over ``self.data``.

        A random permutation of ``len(self.data)`` is requested from the
        backend (updating ``data_rng`` when set, otherwise ``rng``), then cut
        into consecutive slices of ``batch_size``. When the length is not a
        multiple of ``batch_size`` a final, shorter slice with the leftover
        indices is yielded.

        This is a generator, so the permutation is only drawn once the first
        batch is requested.
        """
        total = len(self.data)
        if self.data_rng is not None:
            self.data_rng, perm = self.backend.random.random_permutation(total, rng=self.data_rng, device=self.data.device)
        else:
            self.rng, perm = self.backend.random.random_permutation(total, rng=self.rng, device=self.data.device)
        n_batches, num_left = divmod(total, self.batch_size)
        for i in range(n_batches):
            yield perm[i*self.batch_size:(i+1)*self.batch_size]
        if num_left > 0:
            yield perm[-num_left:]

    def epoch_iter(self) -> Iterator[SamplerBatchT]:
        """Iterate once over the data in random order, yielding structured batches."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_at(batch_idx)

    def epoch_iter_with_metadata(self) -> Iterator[Tuple[SamplerBatchT, Optional[Dict[str, Any]]]]:
        """Iterate once over the data in random order, yielding ``(batch, metadata)``."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_at_with_metadata(batch_idx)

    def epoch_flat_iter(self) -> Iterator[SamplerArrayType]:
        """Iterate once over the data in random order, yielding flattened batches."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_flat_at(batch_idx)

    def epoch_flat_iter_with_metadata(self) -> Iterator[Tuple[SamplerArrayType, Optional[Dict[str, Any]]]]:
        """Iterate once over the data in random order, yielding ``(flat_batch, metadata)``."""
        for batch_idx in self._epoch_index_batches():
            yield self.get_flat_at_with_metadata(batch_idx)

    def close(self) -> None:
        """Release resources held by the sampler; the base implementation does nothing."""
        pass

    def __del__(self) -> None:
        # Finalization delegates to close() so subclasses only override one hook.
        self.close()
@@ -0,0 +1,132 @@
1
+ import abc
2
+ import os
3
+ from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type
4
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
5
+
6
+ from unienv_interface.space import Space
7
+ from .common import BatchBase, BatchT, IndexableType
8
+
9
class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """Abstract base for storages holding instances of one particular space.

    Defines the shared interface for creating, loading, reading, writing and
    dumping such a storage. Subclasses must implement ``get``, ``set`` and
    ``dumps``.

    A storage that should work with multiprocessing needs to check / implement
    ``__getstate__`` and ``__setstate__`` so that it pickles and unpickles
    correctly.
    """

    # ========== Class Creation and Loading Methods ==========
    @classmethod
    def create(
        cls,
        single_instance_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
        *args,
        capacity : Optional[int],
        cache_path : Optional[Union[str, os.PathLike]] = None,
        **kwargs
    ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """Build a new storage for ``single_instance_space``; not implemented by the base class."""
        raise NotImplementedError

    @classmethod
    def load_from(
        cls,
        path: Union[str, os.PathLike],
        single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
        *,
        capacity : Optional[int] = None,
        **kwargs
    ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """Load a storage previously written to ``path``; not implemented by the base class."""
        raise NotImplementedError

    # ========== Class Attributes ==========

    # File extension (e.g. `.pt`) used when the storage is saved as a single file.
    # None means the storage stores its files in a folder.
    single_file_ext : Optional[str] = None

    # ======== Instance Attributes ==========

    # Total capacity (number of single instances) of the storage.
    # None means the storage has unlimited capacity.
    capacity : Optional[int] = None

    # Cache path for the storage. None means the storage does not use caching.
    cache_filename : Optional[Union[str, os.PathLike]] = None

    def __init__(
        self,
        single_instance_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
    ):
        """Remember the space that describes one stored instance."""
        self.single_instance_space = single_instance_space

    @property
    def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        """Compute backend, taken from ``single_instance_space``."""
        return self.single_instance_space.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Device, taken from ``single_instance_space``."""
        return self.single_instance_space.device

    def extend_length(self, length : int) -> None:
        """Grow an unlimited-capacity (``capacity = None``) storage.

        Fixed-capacity storages simply ignore this call. The base
        implementation does nothing.
        """
        pass

    def shrink_length(self, length : int) -> None:
        """Shrink an unlimited-capacity (``capacity = None``) storage.

        Fixed-capacity storages simply ignore this call. The base
        implementation does nothing.
        """
        pass

    def __len__(self) -> int:
        """Return the number of instances in the storage.

        Fixed-capacity storages report ``capacity``. Unlimited-capacity
        storages must override this to return their current length; the base
        implementation raises ``NotImplementedError`` for them.
        """
        if self.capacity is not None:
            return self.capacity
        raise NotImplementedError(f"__len__ is not implemented for class {type(self).__name__}")

    # Optional flattened accessors are deliberately NOT declared on the base class:
    # `ReplayBuffer` detects support with hasattr(self, "get_flattened") and
    # hasattr(self, "set_flattened"). Implementations that provide them use:
    #   def get_flattened(self, index : Union[IndexableType, BArrayType]) -> BArrayType
    #   def set_flattened(self, index : Union[IndexableType, BArrayType], value : BArrayType) -> None

    @abc.abstractmethod
    def get(self, index : Union[IndexableType, BArrayType]) -> BatchT:
        """Return the stored instance(s) at ``index``."""
        raise NotImplementedError

    @abc.abstractmethod
    def set(self, index : Union[IndexableType, BArrayType], value : BatchT) -> None:
        """Store ``value`` at ``index``."""
        raise NotImplementedError

    def clear(self) -> None:
        """Clear all data inside the storage.

        For unlimited-capacity storages the length is brought to 0 by shrinking
        by the current length. Fixed-capacity storages should reset to their
        initial state; the base implementation does nothing for them.
        """
        if self.capacity is not None:
            return
        self.shrink_length(len(self))

    @abc.abstractmethod
    def dumps(self, path : Union[str, os.PathLike]) -> None:
        """Dump the storage to ``path``.

        This is used for storages that have a single file extension
        (e.g. `.pt` for PyTorch).
        """
        raise NotImplementedError

    def close(self) -> None:
        """Release resources held by the storage; the base implementation does nothing."""
        pass

    def __del__(self) -> None:
        # Finalization delegates to close() so subclasses only override one hook.
        self.close()