unienv 0.0.1b4__py3-none-any.whl → 0.0.1b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: unienv
3
- Version: 0.0.1b4
3
+ Version: 0.0.1b5
4
4
  Summary: Unified robot environment framework supporting multiple tensor and simulation backends
5
5
  License-Expression: MIT
6
6
  Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
@@ -1,27 +1,27 @@
1
- unienv-0.0.1b4.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
1
+ unienv-0.0.1b5.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
2
2
  unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
3
3
  unienv_data/base/__init__.py,sha256=w-I8A-z7YYArkHc2ZOVGrfzfThsaDBg7aD7qMFprNM8,186
4
4
  unienv_data/base/common.py,sha256=EYOzuYmvsCy1uJftsw6cXeycPIr8P7GWZ3_q4wgoNeo,12879
5
- unienv_data/base/storage.py,sha256=s99PYEZGa76kf-Enx57kOyVkwjb-tpU-vTHcGc5Dcew,5415
5
+ unienv_data/base/storage.py,sha256=afICsO_7Zbm9azV0Jxho_z9F7JM30TUDjJM1NHETDHM,5495
6
6
  unienv_data/batches/__init__.py,sha256=Vi92f8ddgFYCqwv7xO2Pi3oJePnioJ4XrJbQVV7eIvk,234
7
7
  unienv_data/batches/backend_compat.py,sha256=7Juf7nU2jYHohRzNzmGfqMMpJtFM-3oTzzLu6EbC77E,8168
8
8
  unienv_data/batches/combined_batch.py,sha256=aua1H86sa_qWrCtAAp5JIZMGtFiiKFPFkU0y5JoyElM,15325
9
9
  unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
10
- unienv_data/batches/slicestack_batch.py,sha256=J2EhARcPA-zz6EBnV7OLzm4yyvnZ06vrdUoPD5RkJ-o,16672
10
+ unienv_data/batches/slicestack_batch.py,sha256=Q3-gsJTvMjKTeZAHWNBTGRsws0HctsfMMTw0vylNxvA,16785
11
11
  unienv_data/batches/transformations.py,sha256=b4HqX3wZ6TuRgQ2q81Jv43PmeHGmP8cwURK_ULjGNgs,5647
12
12
  unienv_data/integrations/pytorch.py,sha256=pW5rXBXagfzwJjM_VGgg8CPXEs3e2fKgg4nY7M3dpOc,2350
13
13
  unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
14
- unienv_data/replay_buffer/replay_buffer.py,sha256=nhbC-7aHGIYhcCdmaaDdhB2U9ODAZrbKMq8dP8ffOv0,10344
15
- unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=fxV6FIqAHhN8opYs2WjAJMPqNRWD3iIku-4WlaydyG4,20737
14
+ unienv_data/replay_buffer/replay_buffer.py,sha256=tdpKXzztd830FXzOE7SRYQ9Hu4cXmigy313EuF6Q-9c,12713
15
+ unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=cqRmzdewFS8IvJcMwxxQgwZf7TvvrViym87OaCOes3Y,24009
16
16
  unienv_data/samplers/__init__.py,sha256=e7uunWN3r-g_2fDaMsYMe8cZcF4N-okCxqBPweQnE0s,97
17
17
  unienv_data/samplers/multiprocessing_sampler.py,sha256=FEBK8pMTnkpA0xuMkbvlv4aIdVTTubeT8BjL60BJL5o,13254
18
18
  unienv_data/samplers/step_sampler.py,sha256=ZCcrx9WbILtaR6izhIP3DhtmFcP7KQBdaYaSZ7vWwRk,3010
19
- unienv_data/storages/dict_storage.py,sha256=SqCGcGT9Y4l0thdmx23XSxRMzIEIuldA6m8Cd9HrpnA,12588
20
- unienv_data/storages/flattened.py,sha256=Fu01TjrzvmyNhXEGtC4FiBTb7cqXDtVkErc1QNwLvcI,6704
21
- unienv_data/storages/hdf5.py,sha256=F_mkrmX6SGT2HamJAyYopBmj_Nf5NzJiyvVN9irtiiM,26260
22
- unienv_data/storages/pytorch.py,sha256=ftO8cND7PFV0J1B1o2YOWqj4U_pyWsJvWv9lC9A7LJg,6953
23
- unienv_data/storages/transformation.py,sha256=9BIwrvdruiTRduqC03e5UbSjBT1jLSxLCkNfrsVDP7I,7577
24
- unienv_data/transformations/image_compress.py,sha256=dINrvmpTWy3sbqruHk0kPZG2XNyJI90ERgErXV7GamE,9131
19
+ unienv_data/storages/dict_storage.py,sha256=DSqRIgo3m1XtUcLtyjYSqqpi01mr_nJOLg5BCddwPcg,13862
20
+ unienv_data/storages/flattened.py,sha256=-1NBoCXSBMcE_kd5GRD_EvTyDupvpSzEdZka9DIWacU,7046
21
+ unienv_data/storages/hdf5.py,sha256=uNhL7ji0Zzp4qRSXPS1H7Q74W1I4terUV8dNYKak06k,26730
22
+ unienv_data/storages/pytorch.py,sha256=YczjJ9OVM-QqAYWuluSRSv80vZnsg1nP-b6vSk-6D2Y,6953
23
+ unienv_data/storages/transformation.py,sha256=-9_jPZNpx6RXY_ojv_1UCSTa4Z9apI9V9jit8nG93oM,8133
24
+ unienv_data/transformations/image_compress.py,sha256=pp9Q5CoVcZwu-oajnpFvtMuv7e3w2ZZV50S4pshoaj4,10000
25
25
  unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
26
26
  unienv_interface/backends/__init__.py,sha256=L7CFwCChHVL-2Dpz34pTGC37WgodfJEeDQwXscyM7FM,198
27
27
  unienv_interface/backends/base.py,sha256=1_hji1qwNAhcEtFQdAuzaNey9g5bWYj38t1sQxjnggc,132
@@ -36,14 +36,14 @@ unienv_interface/env_base/funcenv_wrapper.py,sha256=chw1iJ1RhAFMv4JAk67cttJvI9ag
36
36
  unienv_interface/env_base/vec_env.py,sha256=bcv6NdOxt0Xp1fRMXqzFtmVw6LQ-pDj_Jvj-qaW6otQ,16116
37
37
  unienv_interface/env_base/wrapper.py,sha256=7hf4Rr2wouS0igPoahhvb2tzYY3bCaWL0NlgwpYZwQs,9734
38
38
  unienv_interface/func_wrapper/__init__.py,sha256=6BPF8O25WkIBpODVTwnOE9HGSm3KRKX6iPwFGWESlxA,123
39
- unienv_interface/func_wrapper/frame_stack.py,sha256=52CqAHDqwgHtOwMwxzB3Syup9kA19zdlvXCH4mI7MNU,6819
39
+ unienv_interface/func_wrapper/frame_stack.py,sha256=wuGsrluoz60FTczRuo8sHPfpl_Yl4GVTRBb2QDzYPrA,6825
40
40
  unienv_interface/func_wrapper/transformation.py,sha256=7mdzcpjLjqtpbtXoqbkGtTMPQxoMmMsqzDWHcZLbrhs,5939
41
41
  unienv_interface/space/__init__.py,sha256=6-wLoD9mKDAfz7IuQs_Rn9DMDfDwTZ0tEhQ924libpg,99
42
42
  unienv_interface/space/space.py,sha256=mFlCcDvMgEPTXlwo_iwBlm6Eg4Bn2rrecgsfIVstdq0,4067
43
43
  unienv_interface/space/space_utils/__init__.py,sha256=GAsPoZC8YNabx3Gw5m2o4zsnG8zmA3mcuM9_lNKhiGo,121
44
44
  unienv_interface/space/space_utils/batch_utils.py,sha256=qXK7kERPXKGIYozz7lpjzVz56S9GkH6ZASfIRzCYXHY,36993
45
45
  unienv_interface/space/space_utils/construct_utils.py,sha256=Y4RpV9obY8XQ85O3r_NC1HrBk-Nm941ffRNXNL7nHgA,8323
46
- unienv_interface/space/space_utils/flatten_utils.py,sha256=kkHkjrsk43NDbg3Q5VAhVoIXStuRayYFO-7knsDzx4A,12289
46
+ unienv_interface/space/space_utils/flatten_utils.py,sha256=6ObJgVq4yhOq_7N5E5pQZS6WmmeKu-MyRFJ_x-gqmNg,12607
47
47
  unienv_interface/space/space_utils/gym_utils.py,sha256=nH8EKruOKCXNrIMPUd9F4XGKCfFkhxsTmx4I1BeSgn0,15079
48
48
  unienv_interface/space/space_utils/serialization_utils.py,sha256=LWYSFN7E6tEFe8ULWm42LkFUxP_0dfTGkCcx0yl4Y8s,9530
49
49
  unienv_interface/space/spaces/__init__.py,sha256=Jap768TlwHFDDiTzHZ0qaHEFEVC1cKA2QzLlSZVQnjI,535
@@ -54,20 +54,22 @@ unienv_interface/space/spaces/dict.py,sha256=G5_iYC1Bj5DqeJ7aFlq6eRJbnpATbIRIyRu
54
54
  unienv_interface/space/spaces/dynamic_box.py,sha256=HvMNgzfYwIVc5VVgCtq-8lQbNI1V1dZMI-w60AwYss4,19591
55
55
  unienv_interface/space/spaces/graph.py,sha256=KocRFLtYP5VWYpwbP6HybXH5R4jTIYJdNePKb6vhnYE,15163
56
56
  unienv_interface/space/spaces/text.py,sha256=ePGGJdiD3q-BAX6IHLO7HMe0OH4VrzF043K02eb0zXI,4443
57
- unienv_interface/space/spaces/tuple.py,sha256=rgZQz3EB3CLbIk9UlHBQbM6w9gssbA1izm-Qq-_Chqs,4267
57
+ unienv_interface/space/spaces/tuple.py,sha256=mmJab6kl5VtQStyn754pmk0RLPSQW06Mu15Hp3Qad80,4287
58
58
  unienv_interface/space/spaces/union.py,sha256=Qisd-DdmPcGRmdhZFGiQw8_AOjYWqkuQ4Hwd-I8tdSI,4375
59
59
  unienv_interface/transformations/__init__.py,sha256=g19uGnDHMywvDAXRaqFgoWAF1vCPrbJENEpaEgtIrOw,353
60
60
  unienv_interface/transformations/batch_and_unbatch.py,sha256=ELCnNtwmgA5wpTBJZasfNSHmtf4vzydzLPmO6IGbT9o,1164
61
61
  unienv_interface/transformations/chained_transform.py,sha256=TDnUvxUKK6bXGc_sfr6ZCvvVWw7P5KX2sA9i7i2lx14,2075
62
62
  unienv_interface/transformations/dict_transform.py,sha256=ynrJrloVUix2I27Ir1mL86crT0vY5DvpiBAVxPBJup4,5357
63
63
  unienv_interface/transformations/filter_dict.py,sha256=DzR-hgHoHJObTipxwB2UrKVlTxbfIrJohaOgqdAICLY,5871
64
+ unienv_interface/transformations/image_resize.py,sha256=QyPnpMvdx3IvQyW5_iRq7LMnQQuq7XpOv3x6qQHuNeI,4454
65
+ unienv_interface/transformations/iter_transform.py,sha256=lK7fopeiZJrO0WUXFoUmAOhmYkdHXDnChsQ9TJGV8hU,3688
64
66
  unienv_interface/transformations/rescale.py,sha256=fM5ukWUvNvPeDO48_PRU0KyyvGhBIDxaN9XZyQ1VaQQ,4364
65
67
  unienv_interface/transformations/transformation.py,sha256=u4_9H1tvophhgG0p0F3xfkMMsRuaKY2TQmVeGoeQsJ0,1652
66
68
  unienv_interface/utils/control_util.py,sha256=lY_1EknglY3cNekWX9rYWt0ZUglaPMtIt4M5D9y0WfE,2351
67
- unienv_interface/utils/data_queue.py,sha256=UZiuQDOn39DB9Heu6xinrwuzAL3X8jHlDkFoSC5Phtc,5707
69
+ unienv_interface/utils/framestack_queue.py,sha256=UZiuQDOn39DB9Heu6xinrwuzAL3X8jHlDkFoSC5Phtc,5707
68
70
  unienv_interface/utils/seed_util.py,sha256=Up3nBXj7L8w-S9W5Q1U2d9accMhMf0TmHPaN6JXDVWs,677
69
71
  unienv_interface/utils/stateclass.py,sha256=xjzicPGX1UuI7q3ZAxhBCCoouKfNtLywUzQtLaT0yS4,1390
70
- unienv_interface/utils/symbol_util.py,sha256=NAERK-D_2MaTg2eYW-L75tbzPQN5YJIiKtM9zuQ89Sw,383
72
+ unienv_interface/utils/symbol_util.py,sha256=EKC5cVyuXaP5n68-bSbk1A3jCCJCrX90BF7c8mFQYrU,562
71
73
  unienv_interface/utils/vec_util.py,sha256=EIK680ReCl_rr-qKP8co5hwz8Dx-gks8SHf-CLOZSOA,373
72
74
  unienv_interface/world/__init__.py,sha256=aGuYTz8XFzW32RGkdi2b2LJ1sa0kgFrQyOR3JXDEwLQ,230
73
75
  unienv_interface/world/combined_funcnode.py,sha256=O9qWxhtMJkDVtWuGyaeEj3nKMgIyRAPqF9-5LU6yna8,10853
@@ -82,12 +84,12 @@ unienv_interface/wrapper/backend_compat.py,sha256=T6hosgu2hrZvg3xtnyELmR6Exlz-zt
82
84
  unienv_interface/wrapper/batch_and_unbatch.py,sha256=HpmnppgOKmshNlfmJYkGQYtEU7_U7q3mEdY5n4UaqEY,3457
83
85
  unienv_interface/wrapper/control_frequency_limit.py,sha256=B0E2aUbaUr2p2yIN6wT3q4rAbPYsVroioqma2qKMoC0,2322
84
86
  unienv_interface/wrapper/flatten.py,sha256=NWA5xne5j_L34oq_wT85wGvp6iHwdCSeGsk1DMugvRw,5837
85
- unienv_interface/wrapper/frame_stack.py,sha256=lZZh_T_AmxsRWeYSLsTU321lVgIt12MX1eWl_yRNlWg,6002
87
+ unienv_interface/wrapper/frame_stack.py,sha256=07rt8SuUQmniu0HRAzAuSrW9K1ri_87UxxsF-WIUzbI,6008
86
88
  unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKilWaQQyBkw,9789
87
89
  unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
88
90
  unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
89
91
  unienv_interface/wrapper/video_record.py,sha256=y_nJRYgo1SeLeO_Ymg9xbbGPKm48AbU3BxZK2wd0gzk,8679
90
- unienv-0.0.1b4.dist-info/METADATA,sha256=R_70XnKo1K6ObRxMmSlW1W_lxfD_rGR6txa3wBHGPOM,3033
91
- unienv-0.0.1b4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
92
- unienv-0.0.1b4.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
93
- unienv-0.0.1b4.dist-info/RECORD,,
92
+ unienv-0.0.1b5.dist-info/METADATA,sha256=IH2-ZqP73SKp1QKGX05k16X0M5bl_JGWrSulUXdrkE4,3033
93
+ unienv-0.0.1b5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
94
+ unienv-0.0.1b5.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
95
+ unienv-0.0.1b5.dist-info/RECORD,,
@@ -20,6 +20,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
20
20
  *args,
21
21
  capacity : Optional[int],
22
22
  cache_path : Optional[Union[str, os.PathLike]] = None,
23
+ multiprocessing : bool = False,
23
24
  **kwargs
24
25
  ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
25
26
  raise NotImplementedError
@@ -32,6 +33,7 @@ class SpaceStorage(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType,
32
33
  *,
33
34
  capacity : Optional[int] = None,
34
35
  read_only : bool = True,
36
+ multiprocessing : bool = False,
35
37
  **kwargs
36
38
  ) -> "SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
37
39
  raise NotImplementedError
@@ -33,6 +33,7 @@ class SliceStackedBatch(BatchBase[
33
33
  fill_invalid_data : bool = True,
34
34
  stack_metadata : bool = False,
35
35
  ):
36
+ assert batch.backend.dtype_is_real_integer(fixed_offset.dtype), "Fixed offset must be an integer tensor"
36
37
  assert len(fixed_offset.shape) == 1, "Fixed offset must be a 1D tensor"
37
38
  assert fixed_offset.shape[0] > 0, "Fixed offset must have a positive length"
38
39
  assert batch.backend.any(fixed_offset == 0), "There should be at least one zero in the fixed offset"
@@ -1,6 +1,9 @@
1
1
  import abc
2
2
  import os
3
3
  import dataclasses
4
+ import multiprocessing as mp
5
+ from contextlib import nullcontext
6
+
4
7
  from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type
5
8
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
6
9
 
@@ -51,7 +54,6 @@ def index_with_offset(
51
54
  return data_index
52
55
 
53
56
  class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
54
- is_mutable = True
55
57
  # =========== Class Attributes ==========
56
58
  @staticmethod
57
59
  def create(
@@ -60,6 +62,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
60
62
  *args,
61
63
  cache_path : Optional[Union[str, os.PathLike]] = None,
62
64
  capacity : Optional[int] = None,
65
+ multiprocessing : bool = False,
63
66
  **kwargs
64
67
  ) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
65
68
  storage_path_relative = "storage" + (storage_cls.single_file_ext or "")
@@ -70,6 +73,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
70
73
  *args,
71
74
  cache_path=None if cache_path is None else os.path.join(cache_path, storage_path_relative),
72
75
  capacity=capacity,
76
+ multiprocessing=multiprocessing,
73
77
  **kwargs
74
78
  )
75
79
  return ReplayBuffer(
@@ -77,7 +81,8 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
77
81
  storage_path_relative,
78
82
  0,
79
83
  0,
80
- cache_path=cache_path
84
+ cache_path=cache_path,
85
+ multiprocessing=multiprocessing
81
86
  )
82
87
 
83
88
  @staticmethod
@@ -97,6 +102,7 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
97
102
  backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
98
103
  device: Optional[BDeviceType] = None,
99
104
  read_only : bool = True,
105
+ multiprocessing : bool = False,
100
106
  **storage_kwargs
101
107
  ) -> "ReplayBuffer[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
102
108
  with open(os.path.join(path, "metadata.json"), "r") as f:
@@ -118,52 +124,103 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
118
124
  single_instance_space,
119
125
  capacity=capacity,
120
126
  read_only=read_only,
127
+ multiprocessing=multiprocessing,
121
128
  **storage_kwargs
122
129
  )
123
- return ReplayBuffer(storage, metadata["storage_path_relative"], count, offset, cache_path=path)
130
+ return ReplayBuffer(
131
+ storage,
132
+ metadata["storage_path_relative"],
133
+ count,
134
+ offset,
135
+ cache_path=path,
136
+ multiprocessing=multiprocessing
137
+ )
124
138
 
125
139
  # =========== Instance Attributes and Methods ==========
126
140
  def dumps(self, path : Union[str, os.PathLike]):
127
- os.makedirs(path, exist_ok=True)
128
- storage_path = os.path.join(path, self.storage_path_relative)
129
- self.storage.dumps(storage_path)
130
- metadata = {
131
- "type": __class__.__name__,
132
- "count": self.count,
133
- "offset": self.offset,
134
- "capacity": self.storage.capacity,
135
- "storage_cls": get_full_class_name(type(self.storage)),
136
- "storage_path_relative": self.storage_path_relative,
137
- "single_instance_space": bsu.space_to_json(self.storage.single_instance_space),
138
- }
139
- with open(os.path.join(path, "metadata.json"), "w") as f:
140
- json.dump(metadata, f)
141
+ with self._lock_scope():
142
+ os.makedirs(path, exist_ok=True)
143
+ storage_path = os.path.join(path, self.storage_path_relative)
144
+ self.storage.dumps(storage_path)
145
+ metadata = {
146
+ "type": __class__.__name__,
147
+ "count": self.count,
148
+ "offset": self.offset,
149
+ "capacity": self.storage.capacity,
150
+ "storage_cls": get_full_class_name(type(self.storage)),
151
+ "storage_path_relative": self.storage_path_relative,
152
+ "single_instance_space": bsu.space_to_json(self.storage.single_instance_space),
153
+ }
154
+ with open(os.path.join(path, "metadata.json"), "w") as f:
155
+ json.dump(metadata, f)
141
156
 
142
157
  def __init__(
143
158
  self,
144
159
  storage : SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType],
145
- storage_path_relative : Union[str, os.PathLike],
160
+ storage_path_relative : str,
146
161
  count : int = 0,
147
162
  offset : int = 0,
148
163
  cache_path : Optional[Union[str, os.PathLike]] = None,
164
+ multiprocessing : bool = False,
149
165
  ):
150
166
  self.storage = storage
151
- self.count = count
152
- self.offset = offset
153
- self.storage_path_relative = storage_path_relative
167
+ self._storage_path_relative = storage_path_relative
154
168
  self._cache_path = cache_path
169
+ self._multiprocessing = multiprocessing
170
+ if multiprocessing:
171
+ assert storage.is_multiprocessing_safe, "Storage is not multiprocessing safe"
172
+ self._lock = mp.Lock()
173
+ self._count_value = mp.Value("q", int(count))
174
+ self._offset_value = mp.Value("q", int(offset))
175
+ else:
176
+ self._lock = None
177
+ self._count_value = int(count)
178
+ self._offset_value = int(offset)
179
+
155
180
  super().__init__(
156
181
  storage.single_instance_space,
157
182
  None
158
183
  )
159
184
 
185
+ def _lock_scope(self):
186
+ if self._lock is not None:
187
+ return self._lock
188
+ else:
189
+ return nullcontext()
190
+
160
191
  @property
161
192
  def cache_path(self) -> Optional[Union[str, os.PathLike]]:
162
193
  return self._cache_path
163
194
 
195
+ @property
196
+ def storage_path_relative(self) -> str:
197
+ return self._storage_path_relative
198
+
164
199
  def __len__(self) -> int:
165
200
  return self.count
201
+
202
+ @property
203
+ def count(self) -> int:
204
+ return self._count_value.value if self._multiprocessing else self._count_value
205
+
206
+ @count.setter
207
+ def count(self, value: int) -> None:
208
+ if self._multiprocessing:
209
+ self._count_value.value = int(value)
210
+ else:
211
+ self._count_value = int(value)
166
212
 
213
+ @property
214
+ def offset(self) -> int:
215
+ return self._offset_value.value if self._multiprocessing else self._offset_value
216
+
217
+ @offset.setter
218
+ def offset(self, value: int) -> None:
219
+ if self._multiprocessing:
220
+ self._offset_value.value = int(value)
221
+ else:
222
+ self._offset_value = int(value)
223
+
167
224
  @property
168
225
  def capacity(self) -> Optional[int]:
169
226
  return self.storage.capacity
@@ -176,12 +233,21 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
176
233
  def device(self) -> Optional[BDeviceType]:
177
234
  return self.storage.device
178
235
 
236
+ @property
237
+ def is_mutable(self) -> bool:
238
+ return self.storage.is_mutable
239
+
240
+ @property
241
+ def is_multiprocessing_safe(self) -> bool:
242
+ return self._multiprocessing
243
+
179
244
  def get_flattened_at(self, idx):
180
245
  return self.get_flattened_at_with_metadata(idx)[0]
181
246
 
182
247
  def get_flattened_at_with_metadata(self, idx: Union[IndexableType, BArrayType]) -> BArrayType:
183
248
  if hasattr(self.storage, "get_flattened"):
184
- data = self.storage.get_flattened(idx)
249
+ with self._lock_scope():
250
+ data = self.storage.get_flattened(idx)
185
251
  return data, None
186
252
 
187
253
  data, metadata = self.get_at_with_metadata(idx)
@@ -195,19 +261,21 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
195
261
  return self.get_at_with_metadata(idx)[0]
196
262
 
197
263
  def get_at_with_metadata(self, idx):
198
- data_index = index_with_offset(
199
- self.backend,
200
- idx,
201
- self.count,
202
- self.offset,
203
- self.device
204
- )
205
- data = self.storage.get(data_index)
264
+ with self._lock_scope():
265
+ data_index = index_with_offset(
266
+ self.backend,
267
+ idx,
268
+ self.count,
269
+ self.offset,
270
+ self.device
271
+ )
272
+ data = self.storage.get(data_index)
206
273
  return data, None
207
274
 
208
275
  def set_flattened_at(self, idx: Union[IndexableType, BArrayType], value: BArrayType) -> None:
209
276
  if hasattr(self.storage, "set_flattened"):
210
- self.storage.set_flattened(idx, value)
277
+ with self._lock_scope():
278
+ self.storage.set_flattened(idx, value)
211
279
  return
212
280
 
213
281
  if isinstance(idx, int):
@@ -217,13 +285,14 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
217
285
  self.set_at(idx, value)
218
286
 
219
287
  def set_at(self, idx, value):
220
- self.storage.set(index_with_offset(
221
- self.backend,
222
- idx,
223
- self.count,
224
- self.offset,
225
- self.device
226
- ), value)
288
+ with self._lock_scope():
289
+ self.storage.set(index_with_offset(
290
+ self.backend,
291
+ idx,
292
+ self.count,
293
+ self.offset,
294
+ self.device
295
+ ), value)
227
296
 
228
297
  def extend_flattened(
229
298
  self,
@@ -233,35 +302,37 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
233
302
  self.extend(unflattened_data)
234
303
 
235
304
  def extend(self, value):
236
- B = sbu.batch_size_data(value)
237
- if B == 0:
238
- return
239
- if self.capacity is None:
240
- assert self.offset == 0, "Offset must be 0 when capacity is None"
241
- self.storage.extend_length(B)
242
- self.storage.set(slice(self.count, self.count + B), value)
243
- self.count += B
244
- return
245
-
246
- # We have a fixed capacity, only keep the last `capacity` elements
247
- if B >= self.capacity:
248
- self.storage.set(Ellipsis, sbu.get_at(self._batched_space, value, slice(-self.capacity, None)))
249
- self.count = self.capacity
250
- self.offset = 0
251
- return
252
-
253
- # Otherwise, perform round-robin writes
254
- indexes = (self.backend.arange(B, device=self.device) + self.offset + self.count) % self.capacity
255
- self.storage.set(indexes, value)
256
- outflow = max(0, self.count + B - self.capacity)
257
- if outflow > 0:
258
- self.offset = (self.offset + outflow) % self.capacity
259
- self.count = min(self.count + B, self.capacity)
305
+ with self._lock_scope():
306
+ B = sbu.batch_size_data(value)
307
+ if B == 0:
308
+ return
309
+ if self.capacity is None:
310
+ assert self.offset == 0, "Offset must be 0 when capacity is None"
311
+ self.storage.extend_length(B)
312
+ self.storage.set(slice(self.count, self.count + B), value)
313
+ self.count += B
314
+ return
315
+
316
+ # We have a fixed capacity, only keep the last `capacity` elements
317
+ if B >= self.capacity:
318
+ self.storage.set(Ellipsis, sbu.get_at(self._batched_space, value, slice(-self.capacity, None)))
319
+ self.count = self.capacity
320
+ self.offset = 0
321
+ return
322
+
323
+ # Otherwise, perform round-robin writes
324
+ indexes = (self.backend.arange(B, device=self.device) + self.offset + self.count) % self.capacity
325
+ self.storage.set(indexes, value)
326
+ outflow = max(0, self.count + B - self.capacity)
327
+ if outflow > 0:
328
+ self.offset = (self.offset + outflow) % self.capacity
329
+ self.count = min(self.count + B, self.capacity)
260
330
 
261
331
  def clear(self):
262
- self.count = 0
263
- self.offset = 0
264
- self.storage.clear()
332
+ with self._lock_scope():
333
+ self.count = 0
334
+ self.offset = 0
335
+ self.storage.clear()
265
336
 
266
337
  def close(self) -> None:
267
338
  self.storage.close()