unienv 0.0.1b5__py3-none-any.whl → 0.0.1b7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/METADATA +3 -2
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/RECORD +30 -21
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/WHEEL +1 -1
- unienv_data/base/common.py +25 -10
- unienv_data/batches/backend_compat.py +1 -1
- unienv_data/batches/combined_batch.py +1 -1
- unienv_data/replay_buffer/replay_buffer.py +51 -8
- unienv_data/storages/_episode_storage.py +438 -0
- unienv_data/storages/_list_storage.py +136 -0
- unienv_data/storages/backend_compat.py +268 -0
- unienv_data/storages/flattened.py +3 -3
- unienv_data/storages/hdf5.py +7 -2
- unienv_data/storages/image_storage.py +144 -0
- unienv_data/storages/npz_storage.py +135 -0
- unienv_data/storages/pytorch.py +16 -9
- unienv_data/storages/video_storage.py +297 -0
- unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
- unienv_data/transformations/image_compress.py +81 -18
- unienv_interface/space/space_utils/batch_utils.py +5 -1
- unienv_interface/space/spaces/dict.py +6 -0
- unienv_interface/transformations/__init__.py +3 -1
- unienv_interface/transformations/batch_and_unbatch.py +43 -4
- unienv_interface/transformations/chained_transform.py +9 -8
- unienv_interface/transformations/crop.py +69 -0
- unienv_interface/transformations/dict_transform.py +8 -2
- unienv_interface/transformations/identity.py +16 -0
- unienv_interface/transformations/rescale.py +24 -5
- unienv_interface/wrapper/backend_compat.py +1 -1
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: unienv
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.1b7
|
|
4
4
|
Summary: Unified robot environment framework supporting multiple tensor and simulation backends
|
|
5
5
|
License-Expression: MIT
|
|
6
6
|
Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
|
|
@@ -12,9 +12,10 @@ Requires-Python: >=3.10
|
|
|
12
12
|
Description-Content-Type: text/markdown
|
|
13
13
|
License-File: LICENSE
|
|
14
14
|
Requires-Dist: numpy
|
|
15
|
-
Requires-Dist: xbarray>=0.0.
|
|
15
|
+
Requires-Dist: xbarray>=0.0.1a13
|
|
16
16
|
Requires-Dist: pillow
|
|
17
17
|
Requires-Dist: cloudpickle
|
|
18
|
+
Requires-Dist: pyvers
|
|
18
19
|
Provides-Extra: dev
|
|
19
20
|
Requires-Dist: pytest; extra == "dev"
|
|
20
21
|
Provides-Extra: gymnasium
|
|
@@ -1,27 +1,34 @@
|
|
|
1
|
-
unienv-0.0.
|
|
1
|
+
unienv-0.0.1b7.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
|
|
2
2
|
unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
|
|
3
3
|
unienv_data/base/__init__.py,sha256=w-I8A-z7YYArkHc2ZOVGrfzfThsaDBg7aD7qMFprNM8,186
|
|
4
|
-
unienv_data/base/common.py,sha256=
|
|
4
|
+
unienv_data/base/common.py,sha256=A3RtD3Omqk0Qplsc-44ukAEzbQEU22_MkwUlC7l-HHM,13083
|
|
5
5
|
unienv_data/base/storage.py,sha256=afICsO_7Zbm9azV0Jxho_z9F7JM30TUDjJM1NHETDHM,5495
|
|
6
6
|
unienv_data/batches/__init__.py,sha256=Vi92f8ddgFYCqwv7xO2Pi3oJePnioJ4XrJbQVV7eIvk,234
|
|
7
|
-
unienv_data/batches/backend_compat.py,sha256=
|
|
8
|
-
unienv_data/batches/combined_batch.py,sha256=
|
|
7
|
+
unienv_data/batches/backend_compat.py,sha256=tzFG8gTq0yW-J6PLvu--lCGS0lFc0QfelicJ50p_HYc,8207
|
|
8
|
+
unienv_data/batches/combined_batch.py,sha256=pNrbLvU565BUDWO0pZLCnSMygmoGVCLxjC9OkLRKtLA,15330
|
|
9
9
|
unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
|
|
10
10
|
unienv_data/batches/slicestack_batch.py,sha256=Q3-gsJTvMjKTeZAHWNBTGRsws0HctsfMMTw0vylNxvA,16785
|
|
11
11
|
unienv_data/batches/transformations.py,sha256=b4HqX3wZ6TuRgQ2q81Jv43PmeHGmP8cwURK_ULjGNgs,5647
|
|
12
12
|
unienv_data/integrations/pytorch.py,sha256=pW5rXBXagfzwJjM_VGgg8CPXEs3e2fKgg4nY7M3dpOc,2350
|
|
13
13
|
unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
|
|
14
|
-
unienv_data/replay_buffer/replay_buffer.py,sha256=
|
|
14
|
+
unienv_data/replay_buffer/replay_buffer.py,sha256=8vPma5dL6jDGhI3Oo6IEvNcDYJG9Lb0Xlvxp45tQMEs,14498
|
|
15
15
|
unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=cqRmzdewFS8IvJcMwxxQgwZf7TvvrViym87OaCOes3Y,24009
|
|
16
16
|
unienv_data/samplers/__init__.py,sha256=e7uunWN3r-g_2fDaMsYMe8cZcF4N-okCxqBPweQnE0s,97
|
|
17
17
|
unienv_data/samplers/multiprocessing_sampler.py,sha256=FEBK8pMTnkpA0xuMkbvlv4aIdVTTubeT8BjL60BJL5o,13254
|
|
18
18
|
unienv_data/samplers/step_sampler.py,sha256=ZCcrx9WbILtaR6izhIP3DhtmFcP7KQBdaYaSZ7vWwRk,3010
|
|
19
|
+
unienv_data/storages/_episode_storage.py,sha256=OpZt4P-P6LHrBR4F-tNcCFROLskWaOKWCDfoPV7qz1I,21970
|
|
20
|
+
unienv_data/storages/_list_storage.py,sha256=pH9xZOqXCx65NBRRD-INcP8OP-NWsI-JvdzVsPj9MSg,5225
|
|
21
|
+
unienv_data/storages/backend_compat.py,sha256=BxeMJlC3FI60KLJ7QB5kF-mrGlJ6xi584Dcu4IN4Zrc,10714
|
|
19
22
|
unienv_data/storages/dict_storage.py,sha256=DSqRIgo3m1XtUcLtyjYSqqpi01mr_nJOLg5BCddwPcg,13862
|
|
20
|
-
unienv_data/storages/flattened.py,sha256
|
|
21
|
-
unienv_data/storages/hdf5.py,sha256=
|
|
22
|
-
unienv_data/storages/
|
|
23
|
+
unienv_data/storages/flattened.py,sha256=Yf1G4D6KE36sESyDMGWKXqhFjz6Idx7N1aEhihmGovA,7055
|
|
24
|
+
unienv_data/storages/hdf5.py,sha256=Jnls1rs7nlOOp9msmAfhuZp80OZd8S2Llls176EOUc4,27096
|
|
25
|
+
unienv_data/storages/image_storage.py,sha256=4J1ZiGFHbGLHmReMztImJoDcRmiB_llD2wbMB3rdvOQ,5137
|
|
26
|
+
unienv_data/storages/npz_storage.py,sha256=IP2DXbUs_ySzILne3s3hq3gwHiy9tfpWz6HcNciA8DU,4868
|
|
27
|
+
unienv_data/storages/pytorch.py,sha256=bf3ys6eBlMvjyPK4XE-itENjEWq5Vm60qNwBNqJIZqg,7345
|
|
23
28
|
unienv_data/storages/transformation.py,sha256=-9_jPZNpx6RXY_ojv_1UCSTa4Z9apI9V9jit8nG93oM,8133
|
|
24
|
-
unienv_data/
|
|
29
|
+
unienv_data/storages/video_storage.py,sha256=2vcNlghhDZWWzAdf9t0VeCMZrv-x_rYkYaCw8XV8AJA,13331
|
|
30
|
+
unienv_data/third_party/tensordict/memmap_tensor.py,sha256=J6SkFf-FDy43XuaHLgbvDsHt6v2vYfuhRyeoV02P8vw,42589
|
|
31
|
+
unienv_data/transformations/image_compress.py,sha256=f8JTY4DJEXaiu5lO77T4ROV950rh_bOZBchOF-O0tx8,13130
|
|
25
32
|
unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
|
|
26
33
|
unienv_interface/backends/__init__.py,sha256=L7CFwCChHVL-2Dpz34pTGC37WgodfJEeDQwXscyM7FM,198
|
|
27
34
|
unienv_interface/backends/base.py,sha256=1_hji1qwNAhcEtFQdAuzaNey9g5bWYj38t1sQxjnggc,132
|
|
@@ -41,7 +48,7 @@ unienv_interface/func_wrapper/transformation.py,sha256=7mdzcpjLjqtpbtXoqbkGtTMPQ
|
|
|
41
48
|
unienv_interface/space/__init__.py,sha256=6-wLoD9mKDAfz7IuQs_Rn9DMDfDwTZ0tEhQ924libpg,99
|
|
42
49
|
unienv_interface/space/space.py,sha256=mFlCcDvMgEPTXlwo_iwBlm6Eg4Bn2rrecgsfIVstdq0,4067
|
|
43
50
|
unienv_interface/space/space_utils/__init__.py,sha256=GAsPoZC8YNabx3Gw5m2o4zsnG8zmA3mcuM9_lNKhiGo,121
|
|
44
|
-
unienv_interface/space/space_utils/batch_utils.py,sha256=
|
|
51
|
+
unienv_interface/space/space_utils/batch_utils.py,sha256=hD4ItBp2WQzIQR5u0Zkw0FQQfOeg6ZPRi18Johmcc40,37150
|
|
45
52
|
unienv_interface/space/space_utils/construct_utils.py,sha256=Y4RpV9obY8XQ85O3r_NC1HrBk-Nm941ffRNXNL7nHgA,8323
|
|
46
53
|
unienv_interface/space/space_utils/flatten_utils.py,sha256=6ObJgVq4yhOq_7N5E5pQZS6WmmeKu-MyRFJ_x-gqmNg,12607
|
|
47
54
|
unienv_interface/space/space_utils/gym_utils.py,sha256=nH8EKruOKCXNrIMPUd9F4XGKCfFkhxsTmx4I1BeSgn0,15079
|
|
@@ -50,20 +57,22 @@ unienv_interface/space/spaces/__init__.py,sha256=Jap768TlwHFDDiTzHZ0qaHEFEVC1cKA
|
|
|
50
57
|
unienv_interface/space/spaces/batched.py,sha256=RA8aLUSS14zBSCTm_ud18TTa-ntbIZ074xwJ0xls1Jk,3691
|
|
51
58
|
unienv_interface/space/spaces/binary.py,sha256=0iQUbO37dhkznVpjhsJdwlD-KdWgCEx2H7KrybuZ_PM,3570
|
|
52
59
|
unienv_interface/space/spaces/box.py,sha256=NCmileEZOKz-L3WNzZ-uwydrRFsIMdNZBwTn1vWgeI0,13316
|
|
53
|
-
unienv_interface/space/spaces/dict.py,sha256=
|
|
60
|
+
unienv_interface/space/spaces/dict.py,sha256=NggllKi0smoz2bL3yrfBM5FJGBNRWZ05xXaNEqY1QKs,7234
|
|
54
61
|
unienv_interface/space/spaces/dynamic_box.py,sha256=HvMNgzfYwIVc5VVgCtq-8lQbNI1V1dZMI-w60AwYss4,19591
|
|
55
62
|
unienv_interface/space/spaces/graph.py,sha256=KocRFLtYP5VWYpwbP6HybXH5R4jTIYJdNePKb6vhnYE,15163
|
|
56
63
|
unienv_interface/space/spaces/text.py,sha256=ePGGJdiD3q-BAX6IHLO7HMe0OH4VrzF043K02eb0zXI,4443
|
|
57
64
|
unienv_interface/space/spaces/tuple.py,sha256=mmJab6kl5VtQStyn754pmk0RLPSQW06Mu15Hp3Qad80,4287
|
|
58
65
|
unienv_interface/space/spaces/union.py,sha256=Qisd-DdmPcGRmdhZFGiQw8_AOjYWqkuQ4Hwd-I8tdSI,4375
|
|
59
|
-
unienv_interface/transformations/__init__.py,sha256=
|
|
60
|
-
unienv_interface/transformations/batch_and_unbatch.py,sha256=
|
|
61
|
-
unienv_interface/transformations/chained_transform.py,sha256=
|
|
62
|
-
unienv_interface/transformations/
|
|
66
|
+
unienv_interface/transformations/__init__.py,sha256=zf8NbY-HW4EgHri9PxpuelEvBpFwUtDEcJiXXhFSDNQ,435
|
|
67
|
+
unienv_interface/transformations/batch_and_unbatch.py,sha256=LIEQ_rtAdccdw38VdmWJT_DuqdOyb7aMFcMWlyQBz2U,2164
|
|
68
|
+
unienv_interface/transformations/chained_transform.py,sha256=_6E1g_8u-WAxKd-f2sHJwKQk9HTIRnulyXwHUwJP12I,2203
|
|
69
|
+
unienv_interface/transformations/crop.py,sha256=sigcQcLklp3P6b6KQfP-Ja3OV1CWeusCLNKMvNNdACQ,3107
|
|
70
|
+
unienv_interface/transformations/dict_transform.py,sha256=GhFSN9t3mL3gvoD_GH-np68Fo4m78YnSyHbUHeyzKcw,5540
|
|
63
71
|
unienv_interface/transformations/filter_dict.py,sha256=DzR-hgHoHJObTipxwB2UrKVlTxbfIrJohaOgqdAICLY,5871
|
|
72
|
+
unienv_interface/transformations/identity.py,sha256=biW3caBis6ixlOJQk2RJ-7OzP16n0yhpIuqvd7e7Ack,549
|
|
64
73
|
unienv_interface/transformations/image_resize.py,sha256=QyPnpMvdx3IvQyW5_iRq7LMnQQuq7XpOv3x6qQHuNeI,4454
|
|
65
74
|
unienv_interface/transformations/iter_transform.py,sha256=lK7fopeiZJrO0WUXFoUmAOhmYkdHXDnChsQ9TJGV8hU,3688
|
|
66
|
-
unienv_interface/transformations/rescale.py,sha256=
|
|
75
|
+
unienv_interface/transformations/rescale.py,sha256=85PAq5ta9KelxMaL6RIJXBFxOmRbZjsGlMJiElCW9wI,5329
|
|
67
76
|
unienv_interface/transformations/transformation.py,sha256=u4_9H1tvophhgG0p0F3xfkMMsRuaKY2TQmVeGoeQsJ0,1652
|
|
68
77
|
unienv_interface/utils/control_util.py,sha256=lY_1EknglY3cNekWX9rYWt0ZUglaPMtIt4M5D9y0WfE,2351
|
|
69
78
|
unienv_interface/utils/framestack_queue.py,sha256=UZiuQDOn39DB9Heu6xinrwuzAL3X8jHlDkFoSC5Phtc,5707
|
|
@@ -80,7 +89,7 @@ unienv_interface/world/node.py,sha256=EAvHnx0u7IudmWQDbAUIRVEqB4kh2Xsm1aXdS3Celo
|
|
|
80
89
|
unienv_interface/world/world.py,sha256=Kl7wbNbs2YR3CjFrCLFhDB3DQUAWM6LjBwSADQtBTII,5740
|
|
81
90
|
unienv_interface/wrapper/__init__.py,sha256=ZNqr-WjVRqgvIxkLkeABxpYZ6tRgJNZOzmluDeJ6W_w,614
|
|
82
91
|
unienv_interface/wrapper/action_rescale.py,sha256=rTJlEHvWSuwGVX83cjfLWvszBk7B2iExX_K37vH8Wic,1231
|
|
83
|
-
unienv_interface/wrapper/backend_compat.py,sha256=
|
|
92
|
+
unienv_interface/wrapper/backend_compat.py,sha256=amLAITi1qLylQ45BkpvmwMXSkG-J9YEu1JPjCrBT5I8,7120
|
|
84
93
|
unienv_interface/wrapper/batch_and_unbatch.py,sha256=HpmnppgOKmshNlfmJYkGQYtEU7_U7q3mEdY5n4UaqEY,3457
|
|
85
94
|
unienv_interface/wrapper/control_frequency_limit.py,sha256=B0E2aUbaUr2p2yIN6wT3q4rAbPYsVroioqma2qKMoC0,2322
|
|
86
95
|
unienv_interface/wrapper/flatten.py,sha256=NWA5xne5j_L34oq_wT85wGvp6iHwdCSeGsk1DMugvRw,5837
|
|
@@ -89,7 +98,7 @@ unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKi
|
|
|
89
98
|
unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
|
|
90
99
|
unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
|
|
91
100
|
unienv_interface/wrapper/video_record.py,sha256=y_nJRYgo1SeLeO_Ymg9xbbGPKm48AbU3BxZK2wd0gzk,8679
|
|
92
|
-
unienv-0.0.
|
|
93
|
-
unienv-0.0.
|
|
94
|
-
unienv-0.0.
|
|
95
|
-
unienv-0.0.
|
|
101
|
+
unienv-0.0.1b7.dist-info/METADATA,sha256=HT6qx5dKz7d5lOf4MBzdtJwx7dixbSaeQviHKCjJYnc,3056
|
|
102
|
+
unienv-0.0.1b7.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
103
|
+
unienv-0.0.1b7.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
|
|
104
|
+
unienv-0.0.1b7.dist-info/RECORD,,
|
unienv_data/base/common.py
CHANGED
|
@@ -7,6 +7,7 @@ from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
|
7
7
|
from unienv_interface.space import Space, BoxSpace, DictSpace
|
|
8
8
|
import dataclasses
|
|
9
9
|
|
|
10
|
+
from functools import cached_property
|
|
10
11
|
from unienv_interface.space.space_utils import batch_utils as space_batch_utils, flatten_utils as space_flatten_utils
|
|
11
12
|
|
|
12
13
|
__all__ = [
|
|
@@ -46,13 +47,18 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
|
|
|
46
47
|
):
|
|
47
48
|
self.single_space = single_space
|
|
48
49
|
self.single_metadata_space = single_metadata_space
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
50
|
+
|
|
51
|
+
# For backwards compatibility
|
|
52
|
+
@cached_property
|
|
53
|
+
def _batched_space(self) -> Space[BatchT, BDeviceType, BDtypeType, BRNGType]:
|
|
54
|
+
return space_batch_utils.batch_space(self.single_space, 1)
|
|
55
|
+
|
|
56
|
+
@cached_property
|
|
57
|
+
def _batched_metadata_space(self) -> Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]]:
|
|
58
|
+
if self.single_metadata_space is not None:
|
|
59
|
+
return space_batch_utils.batch_space(self.single_metadata_space, 1)
|
|
54
60
|
else:
|
|
55
|
-
|
|
61
|
+
return None
|
|
56
62
|
|
|
57
63
|
@property
|
|
58
64
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
@@ -146,7 +152,7 @@ class BatchBase(abc.ABC, Generic[BatchT, BArrayType, BDeviceType, BDtypeType, BR
|
|
|
146
152
|
if tqdm:
|
|
147
153
|
from tqdm import tqdm
|
|
148
154
|
iterable_start = tqdm(iterable_start, desc="Extending Batch")
|
|
149
|
-
for start_idx in
|
|
155
|
+
for start_idx in iterable_start:
|
|
150
156
|
end_idx = min(start_idx + chunk_size, n_total)
|
|
151
157
|
data_chunk = other.get_at(slice(start_idx, end_idx))
|
|
152
158
|
self.extend(data_chunk)
|
|
@@ -183,15 +189,24 @@ class BatchSampler(
|
|
|
183
189
|
) -> None:
|
|
184
190
|
super().__init__(single_space=single_space, single_metadata_space=single_metadata_space)
|
|
185
191
|
self.batch_size = batch_size
|
|
186
|
-
|
|
187
|
-
self._batched_metadata_space : Optional[DictSpace[SamplerDeviceType, SamplerDtypeType, SamplerRNGType]] = space_batch_utils.batch_space(self.single_metadata_space, batch_size) if self.single_metadata_space is not None else None
|
|
188
|
-
|
|
192
|
+
|
|
189
193
|
def manual_seed(self, seed : int) -> None:
|
|
190
194
|
if self.rng is not None:
|
|
191
195
|
self.rng = self.backend.random.random_number_generator(seed, device=self.device)
|
|
192
196
|
if self.data_rng is not None:
|
|
193
197
|
self.data_rng = self.backend.random.random_number_generator(seed, device=self.data.device)
|
|
194
198
|
|
|
199
|
+
@cached_property
|
|
200
|
+
def _batched_space(self) -> Space[BatchT, BDeviceType, BDtypeType, BRNGType]:
|
|
201
|
+
return space_batch_utils.batch_space(self.single_space, self.batch_size)
|
|
202
|
+
|
|
203
|
+
@cached_property
|
|
204
|
+
def _batched_metadata_space(self) -> Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]]:
|
|
205
|
+
if self.single_metadata_space is not None:
|
|
206
|
+
return space_batch_utils.batch_space(self.single_metadata_space, self.batch_size)
|
|
207
|
+
else:
|
|
208
|
+
return None
|
|
209
|
+
|
|
195
210
|
@property
|
|
196
211
|
def sampled_space(self) -> Space[SamplerBatchT, SamplerDeviceType, SamplerDtypeType, SamplerRNGType]:
|
|
197
212
|
return self._batched_space
|
|
@@ -36,7 +36,7 @@ def data_to(
|
|
|
36
36
|
key: data_to(value, source_backend, target_backend, target_device)
|
|
37
37
|
for key, value in data.items()
|
|
38
38
|
}
|
|
39
|
-
elif isinstance(data, Sequence):
|
|
39
|
+
elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
|
|
40
40
|
data = [
|
|
41
41
|
data_to(value, source_backend, target_backend, target_device)
|
|
42
42
|
for value in data
|
|
@@ -94,7 +94,7 @@ class CombinedBatch(BatchBase[
|
|
|
94
94
|
batch_index = int(self.backend.sum(
|
|
95
95
|
idx >= self.index_caches[:, 0]
|
|
96
96
|
) - 1)
|
|
97
|
-
return batch_index, idx - self.index_caches[batch_index, 0]
|
|
97
|
+
return batch_index, idx - int(self.index_caches[batch_index, 0])
|
|
98
98
|
|
|
99
99
|
def _convert_index(self, idx : Union[IndexableType, BArrayType]) -> Tuple[
|
|
100
100
|
int,
|
|
@@ -2,6 +2,7 @@ import abc
|
|
|
2
2
|
import os
|
|
3
3
|
import dataclasses
|
|
4
4
|
import multiprocessing as mp
|
|
5
|
+
import ctypes
|
|
5
6
|
from contextlib import nullcontext
|
|
6
7
|
|
|
7
8
|
from typing import Generic, TypeVar, Optional, Any, Dict, Union, Tuple, Sequence, Callable, Type
|
|
@@ -94,6 +95,48 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
94
95
|
metadata = json.load(f)
|
|
95
96
|
return metadata.get('type', None) == __class__.__name__
|
|
96
97
|
return False
|
|
98
|
+
|
|
99
|
+
@staticmethod
|
|
100
|
+
def get_length_from_path(
|
|
101
|
+
path : Union[str, os.PathLike]
|
|
102
|
+
) -> Optional[int]:
|
|
103
|
+
if os.path.exists(os.path.join(path, "metadata.json")):
|
|
104
|
+
with open(os.path.join(path, "metadata.json"), "r") as f:
|
|
105
|
+
metadata = json.load(f)
|
|
106
|
+
if metadata.get('type', None) != __class__.__name__:
|
|
107
|
+
return None
|
|
108
|
+
return int(metadata["count"])
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
@staticmethod
|
|
112
|
+
def get_capacity_from_path(
|
|
113
|
+
path : Union[str, os.PathLike]
|
|
114
|
+
) -> Optional[int]:
|
|
115
|
+
if os.path.exists(os.path.join(path, "metadata.json")):
|
|
116
|
+
with open(os.path.join(path, "metadata.json"), "r") as f:
|
|
117
|
+
metadata = json.load(f)
|
|
118
|
+
if metadata.get('type', None) != __class__.__name__:
|
|
119
|
+
return None
|
|
120
|
+
return int(metadata["capacity"])
|
|
121
|
+
return None
|
|
122
|
+
|
|
123
|
+
@staticmethod
|
|
124
|
+
def get_space_from_path(
|
|
125
|
+
path : Union[str, os.PathLike],
|
|
126
|
+
*,
|
|
127
|
+
backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
|
|
128
|
+
device: Optional[BDeviceType] = None,
|
|
129
|
+
) -> Optional[Space[BatchT, BDeviceType, BDtypeType, BRNGType]]:
|
|
130
|
+
if os.path.exists(os.path.join(path, "metadata.json")):
|
|
131
|
+
with open(os.path.join(path, "metadata.json"), "r") as f:
|
|
132
|
+
metadata = json.load(f)
|
|
133
|
+
if metadata.get('type', None) != __class__.__name__:
|
|
134
|
+
return None
|
|
135
|
+
single_instance_space = bsu.json_to_space(
|
|
136
|
+
metadata["single_instance_space"], backend, device
|
|
137
|
+
)
|
|
138
|
+
return single_instance_space
|
|
139
|
+
return None
|
|
97
140
|
|
|
98
141
|
@staticmethod
|
|
99
142
|
def load_from(
|
|
@@ -167,11 +210,11 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
167
210
|
self._storage_path_relative = storage_path_relative
|
|
168
211
|
self._cache_path = cache_path
|
|
169
212
|
self._multiprocessing = multiprocessing
|
|
170
|
-
if multiprocessing:
|
|
213
|
+
if multiprocessing and storage.is_mutable:
|
|
171
214
|
assert storage.is_multiprocessing_safe, "Storage is not multiprocessing safe"
|
|
172
|
-
self._lock = mp.
|
|
173
|
-
self._count_value = mp.Value(
|
|
174
|
-
self._offset_value = mp.Value(
|
|
215
|
+
self._lock = mp.RLock()
|
|
216
|
+
self._count_value = mp.Value(ctypes.c_long, int(count))
|
|
217
|
+
self._offset_value = mp.Value(ctypes.c_long, int(offset))
|
|
175
218
|
else:
|
|
176
219
|
self._lock = None
|
|
177
220
|
self._count_value = int(count)
|
|
@@ -201,22 +244,22 @@ class ReplayBuffer(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGTy
|
|
|
201
244
|
|
|
202
245
|
@property
|
|
203
246
|
def count(self) -> int:
|
|
204
|
-
return self._count_value.value if self.
|
|
247
|
+
return self._count_value.value if not isinstance(self._count_value, int) else self._count_value
|
|
205
248
|
|
|
206
249
|
@count.setter
|
|
207
250
|
def count(self, value: int) -> None:
|
|
208
|
-
if self.
|
|
251
|
+
if not isinstance(self._count_value, int):
|
|
209
252
|
self._count_value.value = int(value)
|
|
210
253
|
else:
|
|
211
254
|
self._count_value = int(value)
|
|
212
255
|
|
|
213
256
|
@property
|
|
214
257
|
def offset(self) -> int:
|
|
215
|
-
return self._offset_value.value if self.
|
|
258
|
+
return self._offset_value.value if not isinstance(self._offset_value, int) else self._offset_value
|
|
216
259
|
|
|
217
260
|
@offset.setter
|
|
218
261
|
def offset(self, value: int) -> None:
|
|
219
|
-
if self.
|
|
262
|
+
if not isinstance(self._offset_value, int):
|
|
220
263
|
self._offset_value.value = int(value)
|
|
221
264
|
else:
|
|
222
265
|
self._offset_value = int(value)
|