unienv 0.0.1b4__tar.gz → 0.0.1b5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b4/unienv.egg-info → unienv-0.0.1b5}/PKG-INFO +1 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/pyproject.toml +1 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5/unienv.egg-info}/PKG-INFO +1 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv.egg-info/SOURCES.txt +3 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/base/storage.py +2 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/batches/slicestack_batch.py +1 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/replay_buffer/replay_buffer.py +136 -65
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/storages/dict_storage.py +39 -7
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/storages/flattened.py +8 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/storages/hdf5.py +6 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/storages/pytorch.py +1 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/storages/transformation.py +16 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/transformations/image_compress.py +22 -9
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/func_wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space_utils/flatten_utils.py +8 -2
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/tuple.py +4 -4
- unienv-0.0.1b5/unienv_interface/transformations/image_resize.py +106 -0
- unienv-0.0.1b5/unienv_interface/transformations/iter_transform.py +92 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/utils/symbol_util.py +7 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b4 → unienv-0.0.1b5}/LICENSE +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/README.md +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/setup.cfg +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv.egg-info/dependency_links.txt +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv.egg-info/requires.txt +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv.egg-info/top_level.txt +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/base/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/base/common.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/batches/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/batches/backend_compat.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/batches/combined_batch.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/batches/framestack_batch.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/batches/transformations.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/integrations/pytorch.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/replay_buffer/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/samplers/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/samplers/multiprocessing_sampler.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_data/samplers/step_sampler.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/backends/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/backends/base.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/backends/jax.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/backends/numpy.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/backends/pytorch.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/backends/serialization.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/env_base/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/env_base/env.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/env_base/funcenv.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/env_base/funcenv_wrapper.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/env_base/vec_env.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/env_base/wrapper.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/func_wrapper/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/func_wrapper/transformation.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space_utils/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space_utils/batch_utils.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space_utils/construct_utils.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space_utils/gym_utils.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/space_utils/serialization_utils.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/batched.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/binary.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/box.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/dict.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/dynamic_box.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/graph.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/text.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/space/spaces/union.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/batch_and_unbatch.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/chained_transform.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/dict_transform.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/filter_dict.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/rescale.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/transformations/transformation.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/utils/control_util.py +0 -0
- /unienv-0.0.1b4/unienv_interface/utils/data_queue.py → /unienv-0.0.1b5/unienv_interface/utils/framestack_queue.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/utils/seed_util.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/utils/stateclass.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/utils/vec_util.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/combined_funcnode.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/combined_node.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/funcnode.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/funcworld.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/node.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/world/world.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/__init__.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/action_rescale.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/backend_compat.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/batch_and_unbatch.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/control_frequency_limit.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/flatten.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/gym_compat.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/time_limit.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/transformation.py +0 -0
- {unienv-0.0.1b4 → unienv-0.0.1b5}/unienv_interface/wrapper/video_record.py +0 -0
|
@@ -68,10 +68,12 @@ unienv_interface/transformations/batch_and_unbatch.py
|
|
|
68
68
|
unienv_interface/transformations/chained_transform.py
|
|
69
69
|
unienv_interface/transformations/dict_transform.py
|
|
70
70
|
unienv_interface/transformations/filter_dict.py
|
|
71
|
+
unienv_interface/transformations/image_resize.py
|
|
72
|
+
unienv_interface/transformations/iter_transform.py
|
|
71
73
|
unienv_interface/transformations/rescale.py
|
|
72
74
|
unienv_interface/transformations/transformation.py
|
|
73
75
|
unienv_interface/utils/control_util.py
|
|
74
|
-
unienv_interface/utils/
|
|
76
|
+
unienv_interface/utils/framestack_queue.py
|
|
75
77
|
unienv_interface/utils/seed_util.py
|
|
76
78
|
unienv_interface/utils/stateclass.py
|
|
77
79
|
unienv_interface/utils/symbol_util.py
|
|
@@ -20,6 +20,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
20
20
|
*args,
|
|
21
21
|
capacity : Optional[int],
|
|
22
22
|
cache_path : Optional[Union[str, os.PathLike]] = None,
|
|
23
|
+
multiprocessing : bool = False,
|
|
23
24
|
**kwargs
|
|
24
25
|
) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
25
26
|
raise NotImplementedError
|
|
@@ -32,6 +33,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
|
|
|
32
33
|
*,
|
|
33
34
|
capacity : Optional[int] = None,
|
|
34
35
|
read_only : bool = True,
|
|
36
|
+
multiprocessing : bool = False,
|
|
35
37
|
**kwargs
|
|
36
38
|
) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
37
39
|
raise NotImplementedError
|
|
@@ -33,6 +33,7 @@ class SliceStackedBatch(BatchBase[
|
|
|
33
33
|
fill_invalid_data : bool = True,
|
|
34
34
|
stack_metadata : bool = False,
|
|
35
35
|
):
|
|
36
|
+
assert batch.backend.dtype_is_real_integer(fixed_offset.dtype), "Fixed offset must be an integer tensor"
|
|
36
37
|
assert len(fixed_offset.shape) == 1, "Fixed offset must be a 1D tensor"
|
|
37
38
|
assert fixed_offset.shape[0] > 0, "Fixed offset must have a positive length"
|
|
38
39
|
assert batch.backend.any(fixed_offset == 0), "There should be at least one zero in the fixed offset"
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
import os
|
|
3
3
|
import dataclasses
|
|
4
|
+
import multiprocessing as mp
|
|
5
|
+
from contextlib import nullcontext
|
|
6
|
+
|
|
4
7
|
from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type
|
|
5
8
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
9
|
|
|
@@ -51,7 +54,6 @@ def index_with_offset(
|
|
|
51
54
|
return data_index
|
|
52
55
|
|
|
53
56
|
class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
|
|
54
|
-
is_mutable = True
|
|
55
57
|
# =========== Class Attributes ==========
|
|
56
58
|
@staticmethod
|
|
57
59
|
def create(
|
|
@@ -60,6 +62,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
60
62
|
*args,
|
|
61
63
|
cache_path : Optional[Union[str, os.PathLike]] = None,
|
|
62
64
|
capacity : Optional[int] = None,
|
|
65
|
+
multiprocessing : bool = False,
|
|
63
66
|
**kwargs
|
|
64
67
|
) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
65
68
|
storage_path_relative = "storage" + (storage_cls.single_file_ext or "")
|
|
@@ -70,6 +73,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
70
73
|
*args,
|
|
71
74
|
cache_path=None if cache_path is None else os.path.join(cache_path, storage_path_relative),
|
|
72
75
|
capacity=capacity,
|
|
76
|
+
multiprocessing=multiprocessing,
|
|
73
77
|
**kwargs
|
|
74
78
|
)
|
|
75
79
|
return ReplayBuffer(
|
|
@@ -77,7 +81,8 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
77
81
|
storage_path_relative,
|
|
78
82
|
0,
|
|
79
83
|
0,
|
|
80
|
-
cache_path=cache_path
|
|
84
|
+
cache_path=cache_path,
|
|
85
|
+
multiprocessing=multiprocessing
|
|
81
86
|
)
|
|
82
87
|
|
|
83
88
|
@staticmethod
|
|
@@ -97,6 +102,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
97
102
|
backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
|
|
98
103
|
device: Optional[BDeviceType] = None,
|
|
99
104
|
read_only : bool = True,
|
|
105
|
+
multiprocessing : bool = False,
|
|
100
106
|
**storage_kwargs
|
|
101
107
|
) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
102
108
|
with open(os.path.join(path, "metadata.json"), "r") as f:
|
|
@@ -118,52 +124,103 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
118
124
|
single_instance_space,
|
|
119
125
|
capacity=capacity,
|
|
120
126
|
read_only=read_only,
|
|
127
|
+
multiprocessing=multiprocessing,
|
|
121
128
|
**storage_kwargs
|
|
122
129
|
)
|
|
123
|
-
return ReplayBuffer(
|
|
130
|
+
return ReplayBuffer(
|
|
131
|
+
storage,
|
|
132
|
+
metadata["storage_path_relative"],
|
|
133
|
+
count,
|
|
134
|
+
offset,
|
|
135
|
+
cache_path=path,
|
|
136
|
+
multiprocessing=multiprocessing
|
|
137
|
+
)
|
|
124
138
|
|
|
125
139
|
# =========== Instance Attributes and Methods ==========
|
|
126
140
|
def dumps(self, path : Union[str, os.PathLike]):
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
+
with self._lock_scope():
|
|
142
|
+
os.makedirs(path, exist_ok=True)
|
|
143
|
+
storage_path = os.path.join(path, self.storage_path_relative)
|
|
144
|
+
self.storage.dumps(storage_path)
|
|
145
|
+
metadata = {
|
|
146
|
+
"type": __class__.__name__,
|
|
147
|
+
"count": self.count,
|
|
148
|
+
"offset": self.offset,
|
|
149
|
+
"capacity": self.storage.capacity,
|
|
150
|
+
"storage_cls": get_full_class_name(type(self.storage)),
|
|
151
|
+
"storage_path_relative": self.storage_path_relative,
|
|
152
|
+
"single_instance_space": bsu.space_to_json(self.storage.single_instance_space),
|
|
153
|
+
}
|
|
154
|
+
with open(os.path.join(path, "metadata.json"), "w") as f:
|
|
155
|
+
json.dump(metadata, f)
|
|
141
156
|
|
|
142
157
|
def __init__(
|
|
143
158
|
self,
|
|
144
159
|
storage : SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType],
|
|
145
|
-
storage_path_relative :
|
|
160
|
+
storage_path_relative : str,
|
|
146
161
|
count : int = 0,
|
|
147
162
|
offset : int = 0,
|
|
148
163
|
cache_path : Optional[Union[str, os.PathLike]] = None,
|
|
164
|
+
multiprocessing : bool = False,
|
|
149
165
|
):
|
|
150
166
|
self.storage = storage
|
|
151
|
-
self.
|
|
152
|
-
self.offset = offset
|
|
153
|
-
self.storage_path_relative = storage_path_relative
|
|
167
|
+
self._storage_path_relative = storage_path_relative
|
|
154
168
|
self._cache_path = cache_path
|
|
169
|
+
self._multiprocessing = multiprocessing
|
|
170
|
+
if multiprocessing:
|
|
171
|
+
assert storage.is_multiprocessing_safe, "Storage is not multiprocessing safe"
|
|
172
|
+
self._lock = mp.Lock()
|
|
173
|
+
self._count_value = mp.Value("q", int(count))
|
|
174
|
+
self._offset_value = mp.Value("q", int(offset))
|
|
175
|
+
else:
|
|
176
|
+
self._lock = None
|
|
177
|
+
self._count_value = int(count)
|
|
178
|
+
self._offset_value = int(offset)
|
|
179
|
+
|
|
155
180
|
super().__init__(
|
|
156
181
|
storage.single_instance_space,
|
|
157
182
|
None
|
|
158
183
|
)
|
|
159
184
|
|
|
185
|
+
def _lock_scope(self):
|
|
186
|
+
if self._lock is not None:
|
|
187
|
+
return self._lock
|
|
188
|
+
else:
|
|
189
|
+
return nullcontext()
|
|
190
|
+
|
|
160
191
|
@property
|
|
161
192
|
def cache_path(self) -> Optional[Union[str, os.PathLike]]:
|
|
162
193
|
return self._cache_path
|
|
163
194
|
|
|
195
|
+
@property
|
|
196
|
+
def storage_path_relative(self) -> str:
|
|
197
|
+
return self._storage_path_relative
|
|
198
|
+
|
|
164
199
|
def __len__(self) -> int:
|
|
165
200
|
return self.count
|
|
201
|
+
|
|
202
|
+
@property
|
|
203
|
+
def count(self) -> int:
|
|
204
|
+
return self._count_value.value if self._multiprocessing else self._count_value
|
|
205
|
+
|
|
206
|
+
@count.setter
|
|
207
|
+
def count(self, value: int) -> None:
|
|
208
|
+
if self._multiprocessing:
|
|
209
|
+
self._count_value.value = int(value)
|
|
210
|
+
else:
|
|
211
|
+
self._count_value = int(value)
|
|
166
212
|
|
|
213
|
+
@property
|
|
214
|
+
def offset(self) -> int:
|
|
215
|
+
return self._offset_value.value if self._multiprocessing else self._offset_value
|
|
216
|
+
|
|
217
|
+
@offset.setter
|
|
218
|
+
def offset(self, value: int) -> None:
|
|
219
|
+
if self._multiprocessing:
|
|
220
|
+
self._offset_value.value = int(value)
|
|
221
|
+
else:
|
|
222
|
+
self._offset_value = int(value)
|
|
223
|
+
|
|
167
224
|
@property
|
|
168
225
|
def capacity(self) -> Optional[int]:
|
|
169
226
|
return self.storage.capacity
|
|
@@ -176,12 +233,21 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
176
233
|
def device(self) -> Optional[BDeviceType]:
|
|
177
234
|
return self.storage.device
|
|
178
235
|
|
|
236
|
+
@property
|
|
237
|
+
def is_mutable(self) -> bool:
|
|
238
|
+
return self.storage.is_mutable
|
|
239
|
+
|
|
240
|
+
@property
|
|
241
|
+
def is_multiprocessing_safe(self) -> bool:
|
|
242
|
+
return self._multiprocessing
|
|
243
|
+
|
|
179
244
|
def get_flattened_at(self, idx):
|
|
180
245
|
return self.get_flattened_at_with_metadata(idx)[0]
|
|
181
246
|
|
|
182
247
|
def get_flattened_at_with_metadata(self, idx: Union[IndexableType, BArrayType]) -> BArrayType:
|
|
183
248
|
if hasattr(self.storage, "get_flattened"):
|
|
184
|
-
|
|
249
|
+
with self._lock_scope():
|
|
250
|
+
data = self.storage.get_flattened(idx)
|
|
185
251
|
return data, None
|
|
186
252
|
|
|
187
253
|
data, metadata = self.get_at_with_metadata(idx)
|
|
@@ -195,19 +261,21 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
195
261
|
return self.get_at_with_metadata(idx)[0]
|
|
196
262
|
|
|
197
263
|
def get_at_with_metadata(self, idx):
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
264
|
+
with self._lock_scope():
|
|
265
|
+
data_index = index_with_offset(
|
|
266
|
+
self.backend,
|
|
267
|
+
idx,
|
|
268
|
+
self.count,
|
|
269
|
+
self.offset,
|
|
270
|
+
self.device
|
|
271
|
+
)
|
|
272
|
+
data = self.storage.get(data_index)
|
|
206
273
|
return data, None
|
|
207
274
|
|
|
208
275
|
def set_flattened_at(self, idx: Union[IndexableType, BArrayType], value: BArrayType) -> None:
|
|
209
276
|
if hasattr(self.storage, "set_flattened"):
|
|
210
|
-
self.
|
|
277
|
+
with self._lock_scope():
|
|
278
|
+
self.storage.set_flattened(idx, value)
|
|
211
279
|
return
|
|
212
280
|
|
|
213
281
|
if isinstance(idx, int):
|
|
@@ -217,13 +285,14 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
217
285
|
self.set_at(idx, value)
|
|
218
286
|
|
|
219
287
|
def set_at(self, idx, value):
|
|
220
|
-
self.
|
|
221
|
-
self.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
288
|
+
with self._lock_scope():
|
|
289
|
+
self.storage.set(index_with_offset(
|
|
290
|
+
self.backend,
|
|
291
|
+
idx,
|
|
292
|
+
self.count,
|
|
293
|
+
self.offset,
|
|
294
|
+
self.device
|
|
295
|
+
), value)
|
|
227
296
|
|
|
228
297
|
def extend_flattened(
|
|
229
298
|
self,
|
|
@@ -233,35 +302,37 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
233
302
|
self.extend(unflattened_data)
|
|
234
303
|
|
|
235
304
|
def extend(self, value):
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
305
|
+
with self._lock_scope():
|
|
306
|
+
B = sbu.batch_size_data(value)
|
|
307
|
+
if B == 0:
|
|
308
|
+
return
|
|
309
|
+
if self.capacity is None:
|
|
310
|
+
assert self.offset == 0, "Offset must be 0 when capacity is None"
|
|
311
|
+
self.storage.extend_length(B)
|
|
312
|
+
self.storage.set(slice(self.count, self.count + B), value)
|
|
313
|
+
self.count += B
|
|
314
|
+
return
|
|
315
|
+
|
|
316
|
+
# We have a fixed capacity, only keep the last `capacity` elements
|
|
317
|
+
if B >= self.capacity:
|
|
318
|
+
self.storage.set(Ellipsis, sbu.get_at(self._batched_space, value, slice(-self.capacity, None)))
|
|
319
|
+
self.count = self.capacity
|
|
320
|
+
self.offset = 0
|
|
321
|
+
return
|
|
322
|
+
|
|
323
|
+
# Otherwise, perform round-robin writes
|
|
324
|
+
indexes = (self.backend.arange(B, device=self.device) + self.offset + self.count) % self.capacity
|
|
325
|
+
self.storage.set(indexes, value)
|
|
326
|
+
outflow = max(0, self.count + B - self.capacity)
|
|
327
|
+
if outflow > 0:
|
|
328
|
+
self.offset = (self.offset + outflow) % self.capacity
|
|
329
|
+
self.count = min(self.count + B, self.capacity)
|
|
260
330
|
|
|
261
331
|
def clear(self):
|
|
262
|
-
self.
|
|
263
|
-
|
|
264
|
-
|
|
332
|
+
with self._lock_scope():
|
|
333
|
+
self.count = 0
|
|
334
|
+
self.offset = 0
|
|
335
|
+
self.storage.clear()
|
|
265
336
|
|
|
266
337
|
def close(self) -> None:
|
|
267
338
|
self.storage.close()
|