unienv 0.0.1b4__py3-none-any.whl → 0.0.1b5__py3-none-any.whl
This diff shows the changes between the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/METADATA +1 -1
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/RECORD +23 -21
- unienv_data/base/storage.py +2 -0
- unienv_data/batches/slicestack_batch.py +1 -0
- unienv_data/replay_buffer/replay_buffer.py +136 -65
- unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
- unienv_data/storages/dict_storage.py +39 -7
- unienv_data/storages/flattened.py +8 -1
- unienv_data/storages/hdf5.py +6 -0
- unienv_data/storages/pytorch.py +1 -1
- unienv_data/storages/transformation.py +16 -1
- unienv_data/transformations/image_compress.py +22 -9
- unienv_interface/func_wrapper/frame_stack.py +1 -1
- unienv_interface/space/space_utils/flatten_utils.py +8 -2
- unienv_interface/space/spaces/tuple.py +4 -4
- unienv_interface/transformations/image_resize.py +106 -0
- unienv_interface/transformations/iter_transform.py +92 -0
- unienv_interface/utils/symbol_util.py +7 -1
- unienv_interface/wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/WHEEL +0 -0
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/top_level.txt +0 -0
- /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
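Of the files listed above, only the diff of `unienv_data/replay_buffer/trajectory_replay_buffer.py` is reproduced below. The bulk of the change is multiprocessing support: a `multiprocessing` flag is threaded through `create`, the static loader, and `__init__`; `current_episode_id` and `episode_id_to_index_map` move into `multiprocessing.Value` and `multiprocessing.Manager().dict()` when the flag is set; and every method that touches this shared state wraps its body in `_lock_scope()`, which yields an `mp.RLock()` in the multiprocessing case and `contextlib.nullcontext()` otherwise. A `read_only` flag is also forwarded to the underlying storages when loading from disk, `is_mutable` changes from a class attribute to a property that delegates to the wrapped buffers, and a new `is_multiprocessing_safe` property is exposed. Note that the registry view truncates most of the deleted (pre-0.0.1b5) method bodies, so several removed lines below appear as bare `-` markers.

A minimal sketch of the lock-scope pattern in isolation (illustrative names only, not part of the unienv API):

```python
import multiprocessing as mp
from contextlib import nullcontext

class LockScopeExample:
    """Condensed version of the pattern added in this release."""

    def __init__(self, start: int = 0, use_multiprocessing: bool = False):
        if use_multiprocessing:
            self._value = mp.Value("i", start)   # shared, typed integer
            self._lock = mp.RLock()              # re-entrant, process-shared lock
        else:
            self._value = start
            self._lock = None

    def _lock_scope(self):
        # A lock is itself a context manager; fall back to a no-op context otherwise.
        return self._lock if self._lock is not None else nullcontext()

    @property
    def value(self) -> int:
        return self._value if isinstance(self._value, int) else self._value.value

    def increment(self) -> None:
        with self._lock_scope():
            if isinstance(self._value, int):
                self._value += 1
            else:
                self._value.value += 1
```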
unienv_data/replay_buffer/trajectory_replay_buffer.py

@@ -1,6 +1,9 @@
 import abc
 import os
 import dataclasses
+import multiprocessing as mp
+from contextlib import nullcontext
+
 from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type, Mapping
 from typing_extensions import TypedDict
 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
@@ -22,8 +25,6 @@ class TrajectoryData(TypedDict, Generic[BatchT, EpisodeBatchT]):
     episode_data : Optional[EpisodeBatchT]

 class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BArrayType, BDeviceType, BDtypeType, BRNGType], Generic[BatchT, EpisodeBatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
-    is_mutable = True
-
     # =========== Class Attributes ==========
     @staticmethod
     def create(
@@ -39,6 +40,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         episode_data_capacity : Optional[int] = None,
         episode_data_storage_kwargs : Dict[str, Any] = {},
         cache_path : Optional[Union[str, os.PathLike]] = None,
+        multiprocessing : bool = False,
         **kwargs,
     ) -> "TrajectoryReplayBuffer[BatchT, EpisodeBatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
         backend = step_data_instance_space.backend
@@ -47,6 +49,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             step_data_instance_space,
             cache_path=None if cache_path is None else os.path.join(cache_path, "step_data"),
             capacity=step_data_capacity,
+            multiprocessing=multiprocessing,
             **kwargs
         )
         step_episode_id_kwargs = step_episode_id_storage_kwargs if step_episode_id_storage_cls is not None else kwargs
@@ -62,6 +65,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             ),
             cache_path=None if cache_path is None else os.path.join(cache_path, "step_episode_ids"),
             capacity=step_episode_id_capacity,
+            multiprocessing=multiprocessing,
             **step_episode_id_kwargs
         )
         if episode_data_instance_space is not None:
@@ -75,6 +79,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
                 episode_data_instance_space,
                 cache_path=None if cache_path is None else os.path.join(cache_path, "episode_data"),
                 capacity=episode_data_capacity,
+                multiprocessing=multiprocessing,
                 **episode_data_storage_kwargs
             )
         else:
@@ -85,6 +90,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             episode_data_buffer,
             current_episode_id=0,
             episode_id_to_index_map={} if episode_data_buffer is not None else None,
+            multiprocessing=multiprocessing,
         )

     @staticmethod
@@ -103,6 +109,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         *,
         backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
         device: Optional[BDeviceType] = None,
+        read_only: bool = False,
+        multiprocessing: bool = False,
         step_storage_kwargs: Dict[str, Any] = {},
         step_episode_id_storage_kwargs: Dict[str, Any] = {},
         episode_storage_kwargs: Dict[str, Any] = {},
@@ -118,6 +126,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             os.path.join(path, "step_data"),
             backend=backend,
             device=device,
+            read_only=read_only,
+            multiprocessing=multiprocessing,
             **step_storage_kwargs
         )
         step_episode_id_storage_kwargs.update(storage_kwargs)
@@ -125,6 +135,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             os.path.join(path, "step_episode_ids"),
             backend=backend,
             device=device,
+            read_only=read_only,
+            multiprocessing=multiprocessing,
             **step_episode_id_storage_kwargs
         )
         episode_data_buffer = None
@@ -134,6 +146,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
                 os.path.join(path, "episode_data"),
                 backend=backend,
                 device=device,
+                read_only=read_only,
+                multiprocessing=multiprocessing,
                 **episode_storage_kwargs
             )
         else:
@@ -144,7 +158,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         if episode_data_buffer is not None:
             raw_map = metadata.get("episode_id_to_index_map")
             episode_id_to_index_map = (
-                {int(k): v for k, v in raw_map.items()}
+                {int(k): int(v) for k, v in raw_map.items()}
                 if raw_map is not None else {}
             )
         else:
@@ -156,35 +170,37 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             episode_data_buffer,
             current_episode_id=metadata["current_episode_id"],
             episode_id_to_index_map=episode_id_to_index_map,
+            multiprocessing=multiprocessing,
         )

     # ========== Instance Attributes and Methods ==========
     def dumps(self, path : Union[str, os.PathLike]):
-
-
-
-
-
-
-
-
-
-
-
-            episode_id_to_index_map
-
-
-
-
-
+        with self._lock_scope():
+            os.makedirs(path, exist_ok=True)
+            step_data_path = os.path.join(path, "step_data")
+            self.step_data_buffer.dumps(step_data_path)
+            step_episode_id_path = os.path.join(path, "step_episode_ids")
+            self.step_episode_id_buffer.dumps(step_episode_id_path)
+            if self.episode_data_buffer is not None:
+                episode_data_path = os.path.join(path, "episode_data")
+                self.episode_data_buffer.dumps(episode_data_path)
+
+            # Convert episode_id_to_index_map keys to strings for JSON serialization
+            if self.episode_id_to_index_map is not None:
+                episode_id_to_index_map = {
+                    str(ep_id): idx
+                    for ep_id, idx in self.episode_id_to_index_map.items()
+                }
+            else:
+                episode_id_to_index_map = None

-
-
-
-
-
-
-
+            metadata = {
+                "type": __class__.__name__,
+                "current_episode_id": self.current_episode_id,
+                "episode_id_to_index_map": episode_id_to_index_map,
+            }
+            with open(os.path.join(path, "metadata.json"), "w") as f:
+                json.dump(metadata, f)

     def __init__(
         self,
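For orientation, the new `dumps` implementation above produces the following on-disk layout under the given path (the directory and file names are exactly those passed to `os.path.join` in the hunk; what each sub-buffer writes inside its directory depends on its own `dumps`):

```text
<path>/
├── step_data/           # written by self.step_data_buffer.dumps(...)
├── step_episode_ids/    # written by self.step_episode_id_buffer.dumps(...)
├── episode_data/        # only when an episode_data_buffer is attached
└── metadata.json        # {"type", "current_episode_id", "episode_id_to_index_map"}
```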
@@ -193,6 +209,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         episode_data_buffer : Optional[ReplayBuffer[EpisodeBatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]],
         current_episode_id: int = 0,
         episode_id_to_index_map : Optional[Dict[int, int]] = None,
+        multiprocessing: bool = False,
     ):
         assert step_data_buffer.backend == step_episode_id_buffer.backend, \
             "Step data buffer and step episode ID buffer must have the same backend."
@@ -232,8 +249,20 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self.step_data_buffer = step_data_buffer
         self.step_episode_id_buffer = step_episode_id_buffer
         self.episode_data_buffer = episode_data_buffer
-
-
+        if multiprocessing:
+            self._current_episode_id = mp.Value('i', current_episode_id)
+            self.episode_id_to_index_map = mp.Manager().dict(episode_id_to_index_map) if episode_id_to_index_map is not None else None
+            self._lock = mp.RLock()
+        else:
+            self._current_episode_id = current_episode_id
+            self.episode_id_to_index_map = episode_id_to_index_map
+            self._lock = None
+
+    def _lock_scope(self):
+        if self._lock is not None:
+            return self._lock
+        else:
+            return nullcontext()

     def __len__(self) -> int:
         return len(self.step_episode_id_buffer)
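The two shared containers introduced above are what make the buffer's bookkeeping visible across processes: `mp.Value('i', ...)` is a C-typed integer in shared memory and `mp.Manager().dict(...)` is a proxy to a dict living in a manager process. A standalone standard-library demonstration of that behaviour (`work`, `counter`, and `index_map` are illustrative names, not unienv API):

```python
import multiprocessing as mp

def work(counter, index_map):
    # Updates made in a child process are visible to the parent because the
    # integer lives in shared memory and the dict lives in a manager process.
    with counter.get_lock():
        counter.value += 1
        index_map[counter.value] = counter.value * 10

if __name__ == "__main__":
    counter = mp.Value("i", 0)        # plays the role of _current_episode_id
    index_map = mp.Manager().dict()   # plays the role of episode_id_to_index_map
    workers = [mp.Process(target=work, args=(counter, index_map)) for _ in range(4)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    print(counter.value, dict(index_map))   # 4 {1: 10, 2: 20, 3: 30, 4: 40}
```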
@@ -242,6 +271,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self,
         episode_id : Union[int, BArrayType],
     ) -> Union[int, BArrayType]:
+        assert self.episode_data_buffer is not None, \
+            "Episode data buffer is not set. Cannot get episode data index."
         if isinstance(episode_id, int):
             return self.episode_id_to_index_map[episode_id]
         else:
@@ -268,30 +299,59 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
     def device(self) -> Optional[BDeviceType]:
         return self.step_data_buffer.device

+    @property
+    def is_mutable(self) -> bool:
+        return (
+            self.step_data_buffer.is_mutable and
+            self.step_episode_id_buffer.is_mutable and
+            (self.episode_data_buffer.is_mutable if self.episode_data_buffer is not None else True)
+        )
+
+    @property
+    def is_multiprocessing_safe(self) -> bool:
+        return (
+            self._lock is not None and
+            self.step_data_buffer.is_multiprocessing_safe and
+            self.step_episode_id_buffer.is_multiprocessing_safe and
+            (self.episode_data_buffer.is_multiprocessing_safe if self.episode_data_buffer is not None else True)
+        )
+
+    @property
+    def current_episode_id(self) -> int:
+        return self._current_episode_id if isinstance(self._current_episode_id, int) else self._current_episode_id.value
+
+    @current_episode_id.setter
+    def current_episode_id(self, value: int):
+        if isinstance(self._current_episode_id, int):
+            self._current_episode_id = value
+        else:
+            self._current_episode_id.value = value
+
     def get_flattened_at(self, idx):
         return self.get_flattened_at_with_metadata(idx)[0]

     def get_flattened_at_with_metadata(self, idx):
-
-
-
-
-
-
-        self.
-
-
-
-
-
-
+        with self._lock_scope():
+            episode_ids = self.step_episode_id_buffer.get_at(idx)
+            step_data_flat, metadata = self.step_data_buffer.get_flattened_at_with_metadata(idx)
+
+            # We do some tricks knowing how flat data is layed out for dictionary space
+            if self.episode_data_buffer is not None:
+                episode_data_flat = self.episode_data_buffer.get_flattened_at(
+                    self.episode_id_to_episode_data_index(episode_ids)
+                )
+                if isinstance(idx, int):
+                    data_flat = self.backend.concat([
+                        step_data_flat,
+                        episode_data_flat
+                    ], axis=0)
+                else:
+                    data_flat = self.backend.concat([
+                        step_data_flat,
+                        episode_data_flat
+                    ], axis=1)
             else:
-            data_flat =
-                step_data_flat,
-                episode_data_flat
-            ], axis=1)
-        else:
-            data_flat = step_data_flat
+                data_flat = step_data_flat

             metadata = {} if metadata is None else copy.copy(metadata)
             metadata["episode_ids"] = episode_ids
@@ -301,17 +361,18 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         return self.get_at_with_metadata(idx)[0]

     def get_at_with_metadata(self, idx):
-
-
-
-
-
-
-
-
-        self.
-
-
+        with self._lock_scope():
+            episode_ids = self.step_episode_id_buffer.get_at(idx)
+            step_data, metadata = self.step_data_buffer.get_at_with_metadata(idx)
+
+            data = {
+                "step_data": step_data,
+            }
+            if self.episode_data_buffer is not None:
+                episode_data = self.episode_data_buffer.get_at(
+                    self.episode_id_to_episode_data_index(episode_ids)
+                )
+                data["episode_data"] = episode_data

             metadata = {} if metadata is None else copy.copy(metadata)
             metadata["episode_ids"] = episode_ids
@@ -328,25 +389,26 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self.set_at(idx, unflat_data)

     def set_at(self, idx, value):
-
-
+        with self._lock_scope():
+            if "episode_ids" in value:
+                self.step_episode_id_buffer.set_at(idx, value['episode_ids'])
+
+            if "step_data" in value:
+                step_data = value["step_data"]
+                self.step_data_buffer.set_at(idx, step_data)

-
-
-
-
-
-
-
-
-
-
-
-                self._batched_space['episode_data'],
-                episode_data,
-                unique_indices
+            if "episode_data" in value and self.episode_data_buffer is not None:
+                episode_ids = value["episode_ids"] if "episode_ids" in value else self.step_episode_id_buffer.get_at(idx)
+                episode_data = value["episode_data"]
+                episode_ids_unique, unique_indices, _, _ = self.backend.unique_all(episode_ids)
+                self.set_episode_data_at(
+                    episode_ids_unique,
+                    sbu.get_at(
+                        self._batched_space['episode_data'],
+                        episode_data,
+                        unique_indices
+                    )
                 )
-            )

     def set_episode_data_at(
         self,
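In the `set_at` hunk above (and again in `extend` further down), the batch of per-step `episode_ids` is deduplicated with `self.backend.unique_all(...)` before episode-level data is written, so each episode's record is stored or updated once per batch. Assuming `backend.unique_all` follows the Array API convention of returning `(values, indices, inverse_indices, counts)`, a NumPy equivalent of the values/indices pair looks like this (illustrative only):

```python
import numpy as np

# Per-step episode ids for one batch of transitions.
episode_ids = np.array([5, 5, 6, 6, 6, 7])

# First-occurrence indices select one representative row per episode, mirroring
# sbu.get_at(self._batched_space['episode_data'], episode_data, unique_indices).
unique_ids, unique_indices = np.unique(episode_ids, return_index=True)
print(unique_ids)       # [5 6 7]
print(unique_indices)   # [0 2 5]
```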
@@ -355,58 +417,60 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
     ) -> None:
         assert self.episode_data_buffer is not None, \
             "Episode data buffer is not set. Cannot set episode data."
-
-
-
-
-
-                "Episode data buffer is full. Cannot set episode data."
-            index = len(self.episode_id_to_index_map)
-            self.episode_id_to_index_map[episode_id] = index
-            self.episode_data_buffer.extend(
-                sbu.concatenate(
-                    self._batched_space['episode_data'],
-                    [value]
-                )
-            )
-        else:
-            assert self.backend.is_backendarray(episode_id), \
-                "Episode ID must be an integer or a backend array."
-            assert len(episode_id.shape) == 1, \
-                "Episode ID must be a 1D array."
-
-            valid_ids = [] # Stores (rb_index, index_in_batch) tuples
-            new_ids = [] # Stores (episode_id, index_in_batch) tuples
-            for i in range(episode_id.shape[0]):
-                ep_id = episode_id[i]
-                if ep_id in self.episode_id_to_index_map:
-                    valid_ids.append((self.episode_id_to_index_map[ep_id], i))
+
+        with self._lock_scope():
+            if isinstance(episode_id, int):
+                if episode_id in self.episode_id_to_index_map:
+                    index = self.episode_id_to_index_map[episode_id]
                 else:
-
-
-
-
-
-
-
-
-
-                    sbu.concatenate(
-                        self._batched_space['episode_data'],
-                        [value[i] for _, i in new_ids]
+                    assert self.episode_data_buffer.capacity is None or len(self.episode_data_buffer) < self.episode_data_buffer.capacity, \
+                        "Episode data buffer is full. Cannot set episode data."
+                    index = len(self.episode_id_to_index_map)
+                    self.episode_id_to_index_map[episode_id] = index
+                    self.episode_data_buffer.extend(
+                        sbu.concatenate(
+                            self._batched_space['episode_data'],
+                            [value]
+                        )
                     )
-
-
-
-
-
-
-
-
-
-
+            else:
+                assert self.backend.is_backendarray(episode_id), \
+                    "Episode ID must be an integer or a backend array."
+                assert len(episode_id.shape) == 1, \
+                    "Episode ID must be a 1D array."
+
+                valid_ids = [] # Stores (rb_index, index_in_batch) tuples
+                new_ids = [] # Stores (episode_id, index_in_batch) tuples
+                for i in range(episode_id.shape[0]):
+                    ep_id = episode_id[i]
+                    if ep_id in self.episode_id_to_index_map:
+                        valid_ids.append((self.episode_id_to_index_map[ep_id], i))
+                    else:
+                        new_ids.append((ep_id, i))
+                if len(new_ids) > 0:
+                    assert self.episode_data_buffer.capacity is None or len(self.episode_data_buffer) + len(new_ids) <= self.episode_data_buffer.capacity, \
+                        "Episode data buffer is full. Cannot set episode data."
+                    start_index = len(self.episode_id_to_index_map)
+                    for ep_id, i in new_ids:
+                        self.episode_id_to_index_map[ep_id] = start_index
+                        start_index += 1
+                    self.episode_data_buffer.extend(
+                        sbu.concatenate(
+                            self._batched_space['episode_data'],
+                            [value[i] for _, i in new_ids]
+                        )
+                    )
+                if len(valid_ids) > 0:
+                    rb_indices = self.backend.asarray([i for i, _ in valid_ids], device=self.device)
+                    indices_in_batch = self.backend.asarray([i for _, i in valid_ids], device=self.device)
+                    self.episode_data_buffer.set_at(
+                        rb_indices,
+                        sbu.get_at(
+                            self._batched_space['episode_data'],
+                            value,
+                            indices_in_batch
+                        )
                     )
-                )

     def extend_flattened(self, value):
         try:
@@ -419,38 +483,39 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self.extend(unflattened_data)

     def extend(self, value):
-
-
-
-
-
-
-
-
-
+        with self._lock_scope():
+            B = sbu.batch_size_data(value)
+            if B == 0:
+                return
+
+            if not isinstance(value, Mapping) or "step_data" not in value:
+                # If the value is not a mapping or does not contain "step_data", we assume it's a single step data
+                value = {
+                    "step_data": value
+                }

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if "episode_ids" in value:
+                episode_ids = value["episode_ids"]
+            else:
+                episode_ids = self.backend.full(
+                    (B,),
+                    self.current_episode_id,
+                    dtype=self.backend.default_integer_dtype,
+                    device=self.device
+                )
+            self.step_episode_id_buffer.extend(episode_ids)
+            self.step_data_buffer.extend(value["step_data"])
+
+            if "episode_data" in value and self.episode_data_buffer is not None:
+                episode_ids_unique, unique_indices, _, _ = self.backend.unique_all(episode_ids)
+                self.set_episode_data_at(
+                    episode_ids_unique,
+                    sbu.get_at(
+                        self._batched_space['episode_data'],
+                        value["episode_data"],
+                        unique_indices
+                    )
                 )
-            )

     def set_current_episode_data(
         self,
@@ -462,18 +527,20 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         )

     def mark_episode_end(self) -> None:
-        self.
+        with self._lock_scope():
+            self.current_episode_id += 1

     def clear(self):
-        self.
-
-
-        self.episode_data_buffer
-
-
+        with self._lock_scope():
+            self.step_data_buffer.clear()
+            self.step_episode_id_buffer.clear()
+            if self.episode_data_buffer is not None:
+                self.episode_data_buffer.clear()
+            self.episode_id_to_index_map = {}
+            self.current_episode_id = 0

     def close(self):
         self.step_data_buffer.close()
         self.step_episode_id_buffer.close()
         if self.episode_data_buffer is not None:
-            self.episode_data_buffer.close()
+            self.episode_data_buffer.close()