unienv 0.0.1b4__py3-none-any.whl → 0.0.1b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,9 @@
 import abc
 import os
 import dataclasses
+import multiprocessing as mp
+from contextlib import nullcontext
+
 from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type, Mapping
 from typing_extensions import TypedDict
 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
@@ -22,8 +25,6 @@ class TrajectoryData(TypedDict, Generic[BatchT, EpisodeBatchT]):
     episode_data : Optional[EpisodeBatchT]
 
 class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BArrayType, BDeviceType, BDtypeType, BRNGType], Generic[BatchT, EpisodeBatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
-    is_mutable = True
-
     # =========== Class Attributes ==========
     @staticmethod
     def create(
@@ -39,6 +40,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         episode_data_capacity : Optional[int] = None,
         episode_data_storage_kwargs : Dict[str, Any] = {},
         cache_path : Optional[Union[str, os.PathLike]] = None,
+        multiprocessing : bool = False,
         **kwargs,
     ) -> "TrajectoryReplayBuffer[BatchT, EpisodeBatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
         backend = step_data_instance_space.backend
@@ -47,6 +49,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             step_data_instance_space,
             cache_path=None if cache_path is None else os.path.join(cache_path, "step_data"),
             capacity=step_data_capacity,
+            multiprocessing=multiprocessing,
             **kwargs
         )
         step_episode_id_kwargs = step_episode_id_storage_kwargs if step_episode_id_storage_cls is not None else kwargs
@@ -62,6 +65,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             ),
             cache_path=None if cache_path is None else os.path.join(cache_path, "step_episode_ids"),
             capacity=step_episode_id_capacity,
+            multiprocessing=multiprocessing,
             **step_episode_id_kwargs
         )
         if episode_data_instance_space is not None:
@@ -75,6 +79,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
                 episode_data_instance_space,
                 cache_path=None if cache_path is None else os.path.join(cache_path, "episode_data"),
                 capacity=episode_data_capacity,
+                multiprocessing=multiprocessing,
                 **episode_data_storage_kwargs
             )
         else:
@@ -85,6 +90,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             episode_data_buffer,
             current_episode_id=0,
             episode_id_to_index_map={} if episode_data_buffer is not None else None,
+            multiprocessing=multiprocessing,
         )
 
     @staticmethod
@@ -103,6 +109,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         *,
         backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
         device: Optional[BDeviceType] = None,
+        read_only: bool = False,
+        multiprocessing: bool = False,
         step_storage_kwargs: Dict[str, Any] = {},
         step_episode_id_storage_kwargs: Dict[str, Any] = {},
         episode_storage_kwargs: Dict[str, Any] = {},
@@ -118,6 +126,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             os.path.join(path, "step_data"),
             backend=backend,
             device=device,
+            read_only=read_only,
+            multiprocessing=multiprocessing,
             **step_storage_kwargs
         )
         step_episode_id_storage_kwargs.update(storage_kwargs)
@@ -125,6 +135,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             os.path.join(path, "step_episode_ids"),
             backend=backend,
             device=device,
+            read_only=read_only,
+            multiprocessing=multiprocessing,
             **step_episode_id_storage_kwargs
         )
         episode_data_buffer = None
@@ -134,6 +146,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
                 os.path.join(path, "episode_data"),
                 backend=backend,
                 device=device,
+                read_only=read_only,
+                multiprocessing=multiprocessing,
                 **episode_storage_kwargs
             )
         else:
@@ -144,7 +158,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         if episode_data_buffer is not None:
             raw_map = metadata.get("episode_id_to_index_map")
             episode_id_to_index_map = (
-                {int(k): v for k, v in raw_map.items()}
+                {int(k): int(v) for k, v in raw_map.items()}
                 if raw_map is not None else {}
             )
         else:
@@ -156,35 +170,37 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
             episode_data_buffer,
             current_episode_id=metadata["current_episode_id"],
             episode_id_to_index_map=episode_id_to_index_map,
+            multiprocessing=multiprocessing,
         )
 
     # ========== Instance Attributes and Methods ==========
     def dumps(self, path : Union[str, os.PathLike]):
-        os.makedirs(path, exist_ok=True)
-        step_data_path = os.path.join(path, "step_data")
-        self.step_data_buffer.dumps(step_data_path)
-        step_episode_id_path = os.path.join(path, "step_episode_ids")
-        self.step_episode_id_buffer.dumps(step_episode_id_path)
-        if self.episode_data_buffer is not None:
-            episode_data_path = os.path.join(path, "episode_data")
-            self.episode_data_buffer.dumps(episode_data_path)
-
-        # Convert episode_id_to_index_map keys to strings for JSON serialization
-        if self.episode_id_to_index_map is not None:
-            episode_id_to_index_map = {
-                str(ep_id): idx
-                for ep_id, idx in self.episode_id_to_index_map.items()
-            }
-        else:
-            episode_id_to_index_map = None
+        with self._lock_scope():
+            os.makedirs(path, exist_ok=True)
+            step_data_path = os.path.join(path, "step_data")
+            self.step_data_buffer.dumps(step_data_path)
+            step_episode_id_path = os.path.join(path, "step_episode_ids")
+            self.step_episode_id_buffer.dumps(step_episode_id_path)
+            if self.episode_data_buffer is not None:
+                episode_data_path = os.path.join(path, "episode_data")
+                self.episode_data_buffer.dumps(episode_data_path)
+
+            # Convert episode_id_to_index_map keys to strings for JSON serialization
+            if self.episode_id_to_index_map is not None:
+                episode_id_to_index_map = {
+                    str(ep_id): idx
+                    for ep_id, idx in self.episode_id_to_index_map.items()
+                }
+            else:
+                episode_id_to_index_map = None
 
-        metadata = {
-            "type": __class__.__name__,
-            "current_episode_id": self.current_episode_id,
-            "episode_id_to_index_map": episode_id_to_index_map,
-        }
-        with open(os.path.join(path, "metadata.json"), "w") as f:
-            json.dump(metadata, f)
+            metadata = {
+                "type": __class__.__name__,
+                "current_episode_id": self.current_episode_id,
+                "episode_id_to_index_map": episode_id_to_index_map,
+            }
+            with open(os.path.join(path, "metadata.json"), "w") as f:
+                json.dump(metadata, f)
 
     def __init__(
         self,
@@ -193,6 +209,7 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         episode_data_buffer : Optional[ReplayBuffer[EpisodeBatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]],
         current_episode_id: int = 0,
         episode_id_to_index_map : Optional[Dict[int, int]] = None,
+        multiprocessing: bool = False,
     ):
         assert step_data_buffer.backend == step_episode_id_buffer.backend, \
             "Step data buffer and step episode ID buffer must have the same backend."
@@ -232,8 +249,20 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self.step_data_buffer = step_data_buffer
         self.step_episode_id_buffer = step_episode_id_buffer
         self.episode_data_buffer = episode_data_buffer
-        self.current_episode_id = current_episode_id
-        self.episode_id_to_index_map = episode_id_to_index_map
+        if multiprocessing:
+            self._current_episode_id = mp.Value('i', current_episode_id)
+            self.episode_id_to_index_map = mp.Manager().dict(episode_id_to_index_map) if episode_id_to_index_map is not None else None
+            self._lock = mp.RLock()
+        else:
+            self._current_episode_id = current_episode_id
+            self.episode_id_to_index_map = episode_id_to_index_map
+            self._lock = None
+
+    def _lock_scope(self):
+        if self._lock is not None:
+            return self._lock
+        else:
+            return nullcontext()
 
     def __len__(self) -> int:
         return len(self.step_episode_id_buffer)
@@ -242,6 +271,8 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self,
         episode_id : Union[int, BArrayType],
     ) -> Union[int, BArrayType]:
+        assert self.episode_data_buffer is not None, \
+            "Episode data buffer is not set. Cannot get episode data index."
         if isinstance(episode_id, int):
             return self.episode_id_to_index_map[episode_id]
         else:
@@ -268,30 +299,59 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
     def device(self) -> Optional[BDeviceType]:
         return self.step_data_buffer.device
 
+    @property
+    def is_mutable(self) -> bool:
+        return (
+            self.step_data_buffer.is_mutable and
+            self.step_episode_id_buffer.is_mutable and
+            (self.episode_data_buffer.is_mutable if self.episode_data_buffer is not None else True)
+        )
+
+    @property
+    def is_multiprocessing_safe(self) -> bool:
+        return (
+            self._lock is not None and
+            self.step_data_buffer.is_multiprocessing_safe and
+            self.step_episode_id_buffer.is_multiprocessing_safe and
+            (self.episode_data_buffer.is_multiprocessing_safe if self.episode_data_buffer is not None else True)
+        )
+
+    @property
+    def current_episode_id(self) -> int:
+        return self._current_episode_id if isinstance(self._current_episode_id, int) else self._current_episode_id.value
+
+    @current_episode_id.setter
+    def current_episode_id(self, value: int):
+        if isinstance(self._current_episode_id, int):
+            self._current_episode_id = value
+        else:
+            self._current_episode_id.value = value
+
     def get_flattened_at(self, idx):
         return self.get_flattened_at_with_metadata(idx)[0]
 
     def get_flattened_at_with_metadata(self, idx):
-        episode_ids = self.step_episode_id_buffer.get_at(idx)
-        step_data_flat, metadata = self.step_data_buffer.get_flattened_at_with_metadata(idx)
-
-        # We do some tricks knowing how flat data is layed out for dictionary space
-        if self.episode_data_buffer is not None:
-            episode_data_flat = self.episode_data_buffer.get_flattened_at(
-                self.episode_id_to_episode_data_index(episode_ids)
-            )
-            if isinstance(idx, int):
-                data_flat = self.backend.concat([
-                    step_data_flat,
-                    episode_data_flat
-                ], axis=0)
+        with self._lock_scope():
+            episode_ids = self.step_episode_id_buffer.get_at(idx)
+            step_data_flat, metadata = self.step_data_buffer.get_flattened_at_with_metadata(idx)
+
+            # We do some tricks knowing how flat data is layed out for dictionary space
+            if self.episode_data_buffer is not None:
+                episode_data_flat = self.episode_data_buffer.get_flattened_at(
+                    self.episode_id_to_episode_data_index(episode_ids)
+                )
+                if isinstance(idx, int):
+                    data_flat = self.backend.concat([
+                        step_data_flat,
+                        episode_data_flat
+                    ], axis=0)
+                else:
+                    data_flat = self.backend.concat([
+                        step_data_flat,
+                        episode_data_flat
+                    ], axis=1)
             else:
-                data_flat = self.backend.concat([
-                    step_data_flat,
-                    episode_data_flat
-                ], axis=1)
-        else:
-            data_flat = step_data_flat
+                data_flat = step_data_flat
 
         metadata = {} if metadata is None else copy.copy(metadata)
         metadata["episode_ids"] = episode_ids
@@ -301,17 +361,18 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         return self.get_at_with_metadata(idx)[0]
 
     def get_at_with_metadata(self, idx):
-        episode_ids = self.step_episode_id_buffer.get_at(idx)
-        step_data, metadata = self.step_data_buffer.get_at_with_metadata(idx)
-
-        data = {
-            "step_data": step_data,
-        }
-        if self.episode_data_buffer is not None:
-            episode_data = self.episode_data_buffer.get_at(
-                self.episode_id_to_episode_data_index(episode_ids)
-            )
-            data["episode_data"] = episode_data
+        with self._lock_scope():
+            episode_ids = self.step_episode_id_buffer.get_at(idx)
+            step_data, metadata = self.step_data_buffer.get_at_with_metadata(idx)
+
+            data = {
+                "step_data": step_data,
+            }
+            if self.episode_data_buffer is not None:
+                episode_data = self.episode_data_buffer.get_at(
+                    self.episode_id_to_episode_data_index(episode_ids)
+                )
+                data["episode_data"] = episode_data
 
         metadata = {} if metadata is None else copy.copy(metadata)
         metadata["episode_ids"] = episode_ids
@@ -328,25 +389,26 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self.set_at(idx, unflat_data)
 
     def set_at(self, idx, value):
-        if "episode_ids" in value:
-            self.step_episode_id_buffer.set_at(idx, value['episode_ids'])
+        with self._lock_scope():
+            if "episode_ids" in value:
+                self.step_episode_id_buffer.set_at(idx, value['episode_ids'])
+
+            if "step_data" in value:
+                step_data = value["step_data"]
+                self.step_data_buffer.set_at(idx, step_data)
 
-        if "step_data" in value:
-            step_data = value["step_data"]
-            self.step_data_buffer.set_at(idx, step_data)
-
-        if "episode_data" in value and self.episode_data_buffer is not None:
-            episode_ids = value["episode_ids"] if "episode_ids" in value else self.step_episode_id_buffer.get_at(idx)
-            episode_data = value["episode_data"]
-            episode_ids_unique, unique_indices, _, _ = self.backend.unique_all(episode_ids)
-            self.set_episode_data_at(
-                episode_ids_unique,
-                sbu.get_at(
-                    self._batched_space['episode_data'],
-                    episode_data,
-                    unique_indices
+            if "episode_data" in value and self.episode_data_buffer is not None:
+                episode_ids = value["episode_ids"] if "episode_ids" in value else self.step_episode_id_buffer.get_at(idx)
+                episode_data = value["episode_data"]
+                episode_ids_unique, unique_indices, _, _ = self.backend.unique_all(episode_ids)
+                self.set_episode_data_at(
+                    episode_ids_unique,
+                    sbu.get_at(
+                        self._batched_space['episode_data'],
+                        episode_data,
+                        unique_indices
+                    )
                 )
-            )
 
     def set_episode_data_at(
         self,
@@ -355,58 +417,60 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
     ) -> None:
         assert self.episode_data_buffer is not None, \
             "Episode data buffer is not set. Cannot set episode data."
-        if isinstance(episode_id, int):
-            if episode_id in self.episode_id_to_index_map:
-                index = self.episode_id_to_index_map[episode_id]
-            else:
-                assert self.episode_data_buffer.capacity is None or len(self.episode_data_buffer) < self.episode_data_buffer.capacity, \
-                    "Episode data buffer is full. Cannot set episode data."
-                index = len(self.episode_id_to_index_map)
-                self.episode_id_to_index_map[episode_id] = index
-                self.episode_data_buffer.extend(
-                    sbu.concatenate(
-                        self._batched_space['episode_data'],
-                        [value]
-                    )
-                )
-        else:
-            assert self.backend.is_backendarray(episode_id), \
-                "Episode ID must be an integer or a backend array."
-            assert len(episode_id.shape) == 1, \
-                "Episode ID must be a 1D array."
-
-            valid_ids = [] # Stores (rb_index, index_in_batch) tuples
-            new_ids = [] # Stores (episode_id, index_in_batch) tuples
-            for i in range(episode_id.shape[0]):
-                ep_id = episode_id[i]
-                if ep_id in self.episode_id_to_index_map:
-                    valid_ids.append((self.episode_id_to_index_map[ep_id], i))
+
+        with self._lock_scope():
+            if isinstance(episode_id, int):
+                if episode_id in self.episode_id_to_index_map:
+                    index = self.episode_id_to_index_map[episode_id]
                 else:
-                    new_ids.append((ep_id, i))
-            if len(new_ids) > 0:
-                assert self.episode_data_buffer.capacity is None or len(self.episode_data_buffer) + len(new_ids) <= self.episode_data_buffer.capacity, \
-                    "Episode data buffer is full. Cannot set episode data."
-                start_index = len(self.episode_id_to_index_map)
-                for ep_id, i in new_ids:
-                    self.episode_id_to_index_map[ep_id] = start_index
-                    start_index += 1
-                self.episode_data_buffer.extend(
-                    sbu.concatenate(
-                        self._batched_space['episode_data'],
-                        [value[i] for _, i in new_ids]
+                    assert self.episode_data_buffer.capacity is None or len(self.episode_data_buffer) < self.episode_data_buffer.capacity, \
+                        "Episode data buffer is full. Cannot set episode data."
+                    index = len(self.episode_id_to_index_map)
+                    self.episode_id_to_index_map[episode_id] = index
+                    self.episode_data_buffer.extend(
+                        sbu.concatenate(
+                            self._batched_space['episode_data'],
+                            [value]
+                        )
                     )
-                )
-            if len(valid_ids) > 0:
-                rb_indices = self.backend.asarray([i for i, _ in valid_ids], device=self.device)
-                indices_in_batch = self.backend.asarray([i for _, i in valid_ids], device=self.device)
-                self.episode_data_buffer.set_at(
-                    rb_indices,
-                    sbu.get_at(
-                        self._batched_space['episode_data'],
-                        value,
-                        indices_in_batch
+            else:
+                assert self.backend.is_backendarray(episode_id), \
+                    "Episode ID must be an integer or a backend array."
+                assert len(episode_id.shape) == 1, \
+                    "Episode ID must be a 1D array."
+
+                valid_ids = [] # Stores (rb_index, index_in_batch) tuples
+                new_ids = [] # Stores (episode_id, index_in_batch) tuples
+                for i in range(episode_id.shape[0]):
+                    ep_id = episode_id[i]
+                    if ep_id in self.episode_id_to_index_map:
+                        valid_ids.append((self.episode_id_to_index_map[ep_id], i))
+                    else:
+                        new_ids.append((ep_id, i))
+                if len(new_ids) > 0:
+                    assert self.episode_data_buffer.capacity is None or len(self.episode_data_buffer) + len(new_ids) <= self.episode_data_buffer.capacity, \
+                        "Episode data buffer is full. Cannot set episode data."
+                    start_index = len(self.episode_id_to_index_map)
+                    for ep_id, i in new_ids:
+                        self.episode_id_to_index_map[ep_id] = start_index
+                        start_index += 1
+                    self.episode_data_buffer.extend(
+                        sbu.concatenate(
+                            self._batched_space['episode_data'],
+                            [value[i] for _, i in new_ids]
+                        )
+                    )
+                if len(valid_ids) > 0:
+                    rb_indices = self.backend.asarray([i for i, _ in valid_ids], device=self.device)
+                    indices_in_batch = self.backend.asarray([i for _, i in valid_ids], device=self.device)
+                    self.episode_data_buffer.set_at(
+                        rb_indices,
+                        sbu.get_at(
+                            self._batched_space['episode_data'],
+                            value,
+                            indices_in_batch
+                        )
                     )
-                )
-            )
 
     def extend_flattened(self, value):
         try:
@@ -419,38 +483,39 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         self.extend(unflattened_data)
 
     def extend(self, value):
-        B = sbu.batch_size_data(value)
-        if B == 0:
-            return
-
-        if not isinstance(value, Mapping) or "step_data" not in value:
-            # If the value is not a mapping or does not contain "step_data", we assume it's a single step data
-            value = {
-                "step_data": value
-            }
+        with self._lock_scope():
+            B = sbu.batch_size_data(value)
+            if B == 0:
+                return
+
+            if not isinstance(value, Mapping) or "step_data" not in value:
+                # If the value is not a mapping or does not contain "step_data", we assume it's a single step data
+                value = {
+                    "step_data": value
+                }
 
-        if "episode_ids" in value:
-            episode_ids = value["episode_ids"]
-        else:
-            episode_ids = self.backend.full(
-                (B,),
-                self.current_episode_id,
-                dtype=self.backend.default_integer_dtype,
-                device=self.device
-            )
-        self.step_episode_id_buffer.extend(episode_ids)
-        self.step_data_buffer.extend(value["step_data"])
-
-        if "episode_data" in value and self.episode_data_buffer is not None:
-            episode_ids_unique, unique_indices, _, _ = self.backend.unique_all(episode_ids)
-            self.set_episode_data_at(
-                episode_ids_unique,
-                sbu.get_at(
-                    self._batched_space['episode_data'],
-                    value["episode_data"],
-                    unique_indices
+            if "episode_ids" in value:
+                episode_ids = value["episode_ids"]
+            else:
+                episode_ids = self.backend.full(
+                    (B,),
+                    self.current_episode_id,
+                    dtype=self.backend.default_integer_dtype,
+                    device=self.device
+                )
+            self.step_episode_id_buffer.extend(episode_ids)
+            self.step_data_buffer.extend(value["step_data"])
+
+            if "episode_data" in value and self.episode_data_buffer is not None:
+                episode_ids_unique, unique_indices, _, _ = self.backend.unique_all(episode_ids)
+                self.set_episode_data_at(
+                    episode_ids_unique,
+                    sbu.get_at(
+                        self._batched_space['episode_data'],
+                        value["episode_data"],
+                        unique_indices
+                    )
                 )
-            )
 
     def set_current_episode_data(
         self,
@@ -462,18 +527,20 @@ class TrajectoryReplayBuffer(BatchBase[TrajectoryData[BatchT, EpisodeBatchT], BA
         )
 
     def mark_episode_end(self) -> None:
-        self.current_episode_id += 1
+        with self._lock_scope():
+            self.current_episode_id += 1
 
     def clear(self):
-        self.step_data_buffer.clear()
-        self.step_episode_id_buffer.clear()
-        if self.episode_data_buffer is not None:
-            self.episode_data_buffer.clear()
-        self.episode_id_to_index_map = {}
-        self.current_episode_id = 0
+        with self._lock_scope():
+            self.step_data_buffer.clear()
+            self.step_episode_id_buffer.clear()
+            if self.episode_data_buffer is not None:
+                self.episode_data_buffer.clear()
+            self.episode_id_to_index_map = {}
+            self.current_episode_id = 0
 
     def close(self):
         self.step_data_buffer.close()
         self.step_episode_id_buffer.close()
        if self.episode_data_buffer is not None:
-            self.episode_data_buffer.close()
+            self.episode_data_buffer.close()
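
The recurring pattern in this release is the `_lock_scope()` helper: every method that reads or mutates buffer state wraps its body in `with self._lock_scope():`, which yields a shared `multiprocessing.RLock` when the buffer was constructed with `multiprocessing=True` and a no-op `contextlib.nullcontext()` otherwise, so single-process callers pay no locking cost. A minimal standalone sketch of the same idea, using only the standard library (the `SharedCounter` class below is illustrative, not part of unienv):

    import multiprocessing as mp
    from contextlib import nullcontext

    class SharedCounter:
        """Illustrative sketch: lock only when multiprocessing is requested."""

        def __init__(self, multiprocessing: bool = False):
            if multiprocessing:
                # 'i' is a C signed int; mp.Value places it in shared memory.
                self._value = mp.Value('i', 0)
                self._lock = mp.RLock()
            else:
                self._value = 0
                self._lock = None

        def _lock_scope(self):
            # An RLock is itself a context manager; nullcontext() keeps
            # call sites uniform when no lock is needed.
            return self._lock if self._lock is not None else nullcontext()

        @property
        def value(self) -> int:
            return self._value if isinstance(self._value, int) else self._value.value

        def increment(self) -> None:
            # The compound read-modify-write runs under the lock in the
            # multiprocessing case and is a plain increment otherwise.
            with self._lock_scope():
                if isinstance(self._value, int):
                    self._value += 1
                else:
                    self._value.value += 1

    counter = SharedCounter(multiprocessing=True)
    counter.increment()
    assert counter.value == 1

Because `RLock` is reentrant, a locked public method (such as `set_at` above) can call another locked method (such as `set_episode_data_at`) from the same process without deadlocking, which is presumably why the diff uses `mp.RLock()` rather than `mp.Lock()`.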