unienv 0.0.1b6__py3-none-any.whl → 0.0.1b8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/METADATA +1 -1
- {unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/RECORD +12 -11
- {unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/WHEEL +1 -1
- unienv_data/batches/combined_batch.py +110 -67
- unienv_data/integrations/huggingface.py +47 -0
- unienv_data/integrations/pytorch.py +62 -3
- unienv_data/storages/_episode_storage.py +13 -4
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/video_storage.py +339 -101
- unienv_interface/transformations/batch_and_unbatch.py +1 -0
- {unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/top_level.txt +0 -0
{unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/RECORD
CHANGED
@@ -1,22 +1,23 @@
-unienv-0.0.
+unienv-0.0.1b8.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
 unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
 unienv_data/base/__init__.py,sha256=w-I8A-z7YYArkHc2ZOVGrfzfThsaDBg7aD7qMFprNM8,186
 unienv_data/base/common.py,sha256=A3RtD3Omqk0Qplsc-44ukAEzbQEU22_MkwUlC7l-HHM,13083
 unienv_data/base/storage.py,sha256=afICsO_7Zbm9azV0Jxho_z9F7JM30TUDjJM1NHETDHM,5495
 unienv_data/batches/__init__.py,sha256=Vi92f8ddgFYCqwv7xO2Pi3oJePnioJ4XrJbQVV7eIvk,234
 unienv_data/batches/backend_compat.py,sha256=tzFG8gTq0yW-J6PLvu--lCGS0lFc0QfelicJ50p_HYc,8207
-unienv_data/batches/combined_batch.py,sha256=
+unienv_data/batches/combined_batch.py,sha256=iMCY7pEXf_YrIvq-G92ttaANl-SUcpzDLvH7fI-GImA,16313
 unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
 unienv_data/batches/slicestack_batch.py,sha256=Q3-gsJTvMjKTeZAHWNBTGRsws0HctsfMMTw0vylNxvA,16785
 unienv_data/batches/transformations.py,sha256=b4HqX3wZ6TuRgQ2q81Jv43PmeHGmP8cwURK_ULjGNgs,5647
-unienv_data/integrations/
+unienv_data/integrations/huggingface.py,sha256=cqaNS1Xpv5udAYgytd0zsVI8k9lSKMKz1jdxWnDseK4,1535
+unienv_data/integrations/pytorch.py,sha256=cyUypd4kZVH8WQFrrv4yFgYZJGT4MeuhvxzZMc4-1dM,4553
 unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
 unienv_data/replay_buffer/replay_buffer.py,sha256=8vPma5dL6jDGhI3Oo6IEvNcDYJG9Lb0Xlvxp45tQMEs,14498
 unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=cqRmzdewFS8IvJcMwxxQgwZf7TvvrViym87OaCOes3Y,24009
 unienv_data/samplers/__init__.py,sha256=e7uunWN3r-g_2fDaMsYMe8cZcF4N-okCxqBPweQnE0s,97
 unienv_data/samplers/multiprocessing_sampler.py,sha256=FEBK8pMTnkpA0xuMkbvlv4aIdVTTubeT8BjL60BJL5o,13254
 unienv_data/samplers/step_sampler.py,sha256=ZCcrx9WbILtaR6izhIP3DhtmFcP7KQBdaYaSZ7vWwRk,3010
-unienv_data/storages/_episode_storage.py,sha256=
+unienv_data/storages/_episode_storage.py,sha256=gdrj_bIRquyT0CIm9-0du5cCNhF-4cxwvf9u9POcZVg,22603
 unienv_data/storages/_list_storage.py,sha256=pH9xZOqXCx65NBRRD-INcP8OP-NWsI-JvdzVsPj9MSg,5225
 unienv_data/storages/backend_compat.py,sha256=BxeMJlC3FI60KLJ7QB5kF-mrGlJ6xi584Dcu4IN4Zrc,10714
 unienv_data/storages/dict_storage.py,sha256=DSqRIgo3m1XtUcLtyjYSqqpi01mr_nJOLg5BCddwPcg,13862
@@ -24,9 +25,9 @@ unienv_data/storages/flattened.py,sha256=Yf1G4D6KE36sESyDMGWKXqhFjz6Idx7N1aEhihm
 unienv_data/storages/hdf5.py,sha256=Jnls1rs7nlOOp9msmAfhuZp80OZd8S2Llls176EOUc4,27096
 unienv_data/storages/image_storage.py,sha256=4J1ZiGFHbGLHmReMztImJoDcRmiB_llD2wbMB3rdvOQ,5137
 unienv_data/storages/npz_storage.py,sha256=IP2DXbUs_ySzILne3s3hq3gwHiy9tfpWz6HcNciA8DU,4868
-unienv_data/storages/pytorch.py,sha256=
+unienv_data/storages/pytorch.py,sha256=MltiTcBuvNWQqE5-RAfjBz3gwdfATkVMTmSBR2escIE,7363
 unienv_data/storages/transformation.py,sha256=-9_jPZNpx6RXY_ojv_1UCSTa4Z9apI9V9jit8nG93oM,8133
-unienv_data/storages/video_storage.py,sha256=
+unienv_data/storages/video_storage.py,sha256=0d3FP-BwXiq7q_s9UJE0G7v76Jfgz7Q07vX159UGVnE,22572
 unienv_data/third_party/tensordict/memmap_tensor.py,sha256=J6SkFf-FDy43XuaHLgbvDsHt6v2vYfuhRyeoV02P8vw,42589
 unienv_data/transformations/image_compress.py,sha256=f8JTY4DJEXaiu5lO77T4ROV950rh_bOZBchOF-O0tx8,13130
 unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
@@ -64,7 +65,7 @@ unienv_interface/space/spaces/text.py,sha256=ePGGJdiD3q-BAX6IHLO7HMe0OH4VrzF043K
 unienv_interface/space/spaces/tuple.py,sha256=mmJab6kl5VtQStyn754pmk0RLPSQW06Mu15Hp3Qad80,4287
 unienv_interface/space/spaces/union.py,sha256=Qisd-DdmPcGRmdhZFGiQw8_AOjYWqkuQ4Hwd-I8tdSI,4375
 unienv_interface/transformations/__init__.py,sha256=zf8NbY-HW4EgHri9PxpuelEvBpFwUtDEcJiXXhFSDNQ,435
-unienv_interface/transformations/batch_and_unbatch.py,sha256=
+unienv_interface/transformations/batch_and_unbatch.py,sha256=LIEQ_rtAdccdw38VdmWJT_DuqdOyb7aMFcMWlyQBz2U,2164
 unienv_interface/transformations/chained_transform.py,sha256=_6E1g_8u-WAxKd-f2sHJwKQk9HTIRnulyXwHUwJP12I,2203
 unienv_interface/transformations/crop.py,sha256=sigcQcLklp3P6b6KQfP-Ja3OV1CWeusCLNKMvNNdACQ,3107
 unienv_interface/transformations/dict_transform.py,sha256=GhFSN9t3mL3gvoD_GH-np68Fo4m78YnSyHbUHeyzKcw,5540
@@ -98,7 +99,7 @@ unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKi
 unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
 unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
 unienv_interface/wrapper/video_record.py,sha256=y_nJRYgo1SeLeO_Ymg9xbbGPKm48AbU3BxZK2wd0gzk,8679
-unienv-0.0.
-unienv-0.0.
-unienv-0.0.
-unienv-0.0.
+unienv-0.0.1b8.dist-info/METADATA,sha256=eCut1OMadsIJC5XAqCZHVr9UsfdXqqrumSkH4uuNcTI,3056
+unienv-0.0.1b8.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+unienv-0.0.1b8.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
+unienv-0.0.1b8.dist-info/RECORD,,
unienv_data/batches/combined_batch.py
CHANGED
@@ -12,6 +12,102 @@ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils

 from ..base.common import BatchBase, IndexableType, BatchT

+__all__ = [
+    "convert_single_index_to_batch",
+    "convert_index_to_batch",
+    "CombinedBatch",
+]
+
+def convert_single_index_to_batch(
+    backend : ComputeBackend,
+    index_starts : BArrayType, # (num_batches, )
+    total_length : int, # Total length of all batches
+    idx : int,
+    device : Optional[BDeviceType] = None,
+) -> Tuple[int, int]:
+    """
+    Convert a single index for this batch to a tuple of:
+    - The index of the batch
+    - The index to index into the batch
+    """
+    assert -total_length <= idx < total_length, f"Index {idx} out of bounds for batch of size {total_length}"
+    if idx < 0:
+        idx += total_length
+    batch_index = int(backend.sum(
+        idx >= index_starts
+    ) - 1)
+    return batch_index, idx - int(index_starts[batch_index])
+
+def convert_index_to_batch(
+    backend : ComputeBackend,
+    index_starts : BArrayType, # (num_batches, )
+    total_length : int, # Total length of all batches
+    idx : Union[IndexableType, BArrayType],
+    device : Optional[BDeviceType] = None,
+) -> Tuple[
+    int,
+    Union[List[
+        Tuple[int, BArrayType, BArrayType]
+    ], int]
+]:
+    """
+    Convert an index for this batch to a tuple of:
+    - The length of the resulting array
+    - List of tuples, each containing:
+        - The index of the batch
+        - The index to index into the batch
+        - The bool mask to index into the resulting array
+    """
+    if isinstance(idx, slice):
+        idx_array = backend.arange(
+            *idx.indices(total_length),
+            dtype=backend.default_index_dtype,
+            device=device
+        )
+    elif idx is Ellipsis:
+        idx_array = backend.arange(
+            total_length,
+            dtype=backend.default_index_dtype,
+            device=device
+        )
+    elif backend.is_backendarray(idx):
+        assert len(idx.shape) == 1, "Index must be 1D"
+        assert backend.dtype_is_real_integer(idx.dtype) or backend.dtype_is_boolean(idx.dtype), \
+            f"Index must be of integer or boolean type, got {idx.dtype}"
+        if backend.dtype_is_boolean(idx.dtype):
+            assert idx.shape[0] == total_length, f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {total_length}"
+            idx_array = backend.nonzero(idx)[0]
+        else:
+            assert backend.all(backend.logical_and(-total_length <= idx, idx < total_length)), \
+                f"Index array contains out of bounds indices for batch of size {total_length}"
+            idx_array = (idx + total_length) % total_length
+    else:
+        raise ValueError(f"Invalid index type: {type(idx)}")
+
+    assert bool(backend.all(
+        backend.logical_and(
+            -total_length <= idx_array,
+            idx_array < total_length
+        )
+    )), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {total_length}"
+
+    # Convert negative indices to positive indices
+    idx_array = backend.at(idx_array)[idx_array < 0].add(total_length)
+    idx_array_bigger = idx_array[:, None] >= index_starts[None, :] # (idx_array_shape, len(self.batches))
+    idx_array_batch_idx = backend.sum(
+        idx_array_bigger,
+        axis=-1
+    ) - 1 # (idx_array_shape, )
+
+    result_batch_list = []
+    batch_indexes = backend.unique_values(idx_array_batch_idx)
+    for i in range(batch_indexes.shape[0]):
+        batch_index = int(batch_indexes[i])
+        result_mask = idx_array_batch_idx == batch_index
+        index_into_batch = idx_array[result_mask] - index_starts[batch_index]
+        result_batch_list.append((batch_index, index_into_batch, result_mask))
+    return idx_array.shape[0], result_batch_list
+
 class CombinedBatch(BatchBase[
     BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
 ]):
@@ -83,18 +179,13 @@ class CombinedBatch(BatchBase[
         return self.index_caches[-1, 1]

     def _convert_single_index(self, idx : int) -> Tuple[int, int]:
-
-
-
-
-
-
-
-        idx += len(self)
-        batch_index = int(self.backend.sum(
-            idx >= self.index_caches[:, 0]
-        ) - 1)
-        return batch_index, idx - int(self.index_caches[batch_index, 0])
+        return convert_single_index_to_batch(
+            self.backend,
+            self.index_caches[:, 0],
+            len(self),
+            idx,
+            device=self.device,
+        )

     def _convert_index(self, idx : Union[IndexableType, BArrayType]) -> Tuple[
         int,
@@ -102,61 +193,13 @@ class CombinedBatch(BatchBase[
            Tuple[int, BArrayType, BArrayType]
        ]
    ]:
-
-
-
-
-
-
-
-        """
-        if isinstance(idx, slice):
-            idx_array = self.backend.arange(
-                *idx.indices(len(self)),
-                dtype=self.backend.default_integer_dtype,
-                device=self.device
-            )
-        elif idx is Ellipsis:
-            idx_array = self.backend.arange(
-                len(self),
-                dtype=self.backend.default_integer_dtype,
-                device=self.device
-            )
-        elif self.backend.is_backendarray(idx):
-            assert len(idx.shape) == 1, "Index must be 1D"
-            assert self.backend.dtype_is_real_integer(idx.dtype) or self.backend.dtype_is_boolean(idx.dtype), \
-                f"Index must be of integer or boolean type, got {idx.dtype}"
-            if self.backend.dtype_is_boolean(idx.dtype):
-                assert idx.shape[0] == len(self), f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {len(self)}"
-                idx_array = self.backend.nonzero(idx)[0]
-            else:
-                idx_array = idx
-        else:
-            raise ValueError(f"Invalid index type: {type(idx)}")
-
-        assert bool(self.backend.all(
-            self.backend.logical_and(
-                -len(self) <= idx_array,
-                idx_array < len(self)
-            )
-        )), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {len(self)}"
-
-        # Convert negative indices to positive indices
-        idx_array = self.backend.at(idx_array)[idx_array < 0].add(len(self))
-        idx_array_bigger = idx_array[:, None] >= self.index_caches[None, :, 0] # (idx_array_shape, len(self.batches))
-        idx_array_batch_idx = self.backend.sum(
-            idx_array_bigger,
-            axis=-1
-        ) - 1 # (idx_array_shape, )
-
-        result_batch_list = []
-        batch_indexes = self.backend.unique_values(idx_array_batch_idx)
-        for i in range(batch_indexes.shape[0]):
-            batch_index = int(batch_indexes[i])
-            result_mask = idx_array_batch_idx == batch_index
-            index_into_batch = idx_array[result_mask] - self.index_caches[batch_index, 0]
-            result_batch_list.append((batch_index, index_into_batch, result_mask))
-        return idx_array.shape[0], result_batch_list
+        return convert_index_to_batch(
+            self.backend,
+            self.index_caches[:, 0],
+            len(self),
+            idx,
+            device=self.device,
+        )

     def get_flattened_at(self, idx : Union[IndexableType, BaseException]) -> BArrayType:
         if isinstance(idx, int):
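The two helper functions above reduce CombinedBatch indexing to arithmetic over the cumulative start offsets of its member batches. As a minimal, self-contained NumPy sketch of the same arithmetic (the batch lengths and the locate() helper below are made up for illustration; the package itself goes through the ComputeBackend API shown in the diff):

import numpy as np

# Hypothetical example: three member batches of lengths 4, 2 and 5.
lengths = np.array([4, 2, 5])
index_starts = np.concatenate(([0], np.cumsum(lengths)[:-1]))  # [0, 4, 6]
total_length = int(lengths.sum())                               # 11

def locate(idx: int):
    # Same trick as convert_single_index_to_batch: count how many start
    # offsets the (wrapped) index has passed, then subtract that start.
    assert -total_length <= idx < total_length
    if idx < 0:
        idx += total_length
    batch_index = int(np.sum(idx >= index_starts) - 1)
    return batch_index, idx - int(index_starts[batch_index])

print(locate(5))   # (1, 1): second batch, local offset 1
print(locate(-1))  # (2, 4): last element of the last batch

convert_index_to_batch generalizes this to slices, Ellipsis, boolean masks and index arrays, and additionally returns, per touched batch, the local indices plus a mask marking where those rows land in the combined result.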
unienv_data/integrations/huggingface.py
ADDED
@@ -0,0 +1,47 @@
+from typing import Optional
+from datasets import Dataset as HFDataset
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
+from unienv_data.base import BatchBase, BatchT
+
+__all__ = [
+    'HFAsUniEnvDataset'
+]
+
+class HFAsUniEnvDataset(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
+    BACKEND_TO_FORMAT_MAP = {
+        "numpy": "numpy",
+        "pytorch": "torch",
+        "jax": "jax",
+    }
+
+    is_mutable = False
+
+    def __init__(
+        self,
+        hf_dataset: HFDataset,
+        backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        device : Optional[BDeviceType] = None,
+    ) -> None:
+        kwargs = {}
+        if backend.simplified_name != 'numpy':
+            kwargs['device'] = device
+        self.hf_dataset = hf_dataset.with_format(
+            self.BACKEND_TO_FORMAT_MAP[backend.simplified_name],
+            **kwargs
+        )
+        first_data = self.hf_dataset[0]
+        super().__init__(
+            scu.construct_space_from_data(first_data, backend)
+        )
+
+    def __len__(self) -> int:
+        return len(self.hf_dataset)
+
+    def get_at_with_metadata(self, idx):
+        if self.backend.is_backendarray(idx) and self.backend.dtype_is_boolean(idx):
+            idx = self.backend.nonzero(idx)[0]
+        return self.hf_dataset[idx], {}
+
+    def get_at(self, idx):
+        return self.get_at_with_metadata(idx)[0]
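HFAsUniEnvDataset simply puts the Hugging Face dataset into the backend's array format via with_format() and infers the space from the first row. A rough usage sketch, assuming the PyTorch backend class from the diff is passed directly as the backend argument (that calling convention is an assumption, not documented usage):

from datasets import Dataset as HFDataset
from unienv_interface.backends.pytorch import PyTorchComputeBackend  # import path taken from the diff
from unienv_data.integrations.huggingface import HFAsUniEnvDataset

# Tiny in-memory HF dataset standing in for a real one.
hf_ds = HFDataset.from_dict({"obs": [[0.0, 1.0], [2.0, 3.0]], "reward": [0.5, 1.0]})

batch = HFAsUniEnvDataset(hf_ds, backend=PyTorchComputeBackend)
print(len(batch))       # 2
print(batch.get_at(0))  # first row, returned in the backend's (torch) format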
unienv_data/integrations/pytorch.py
CHANGED
@@ -3,12 +3,19 @@ from typing import Optional, Union, Tuple, Dict, Any
 from torch.utils.data import Dataset

 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
+from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
 from unienv_data.base import BatchBase, BatchT

-
+__all__ = [
+    "UniEnvAsPyTorchDataset",
+    "PyTorchAsUniEnvDataset",
+]
+
+class UniEnvAsPyTorchDataset(Dataset):
     def __init__(
         self,
-        batch : BatchBase[BatchT,
+        batch : BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
         include_metadata : bool = False,
     ):
         """
@@ -60,4 +67,56 @@ class UniEnvPyTorchDataset(Dataset):
         indices = self.batch.backend.asarray(indices, dtype=self.batch.backend.default_integer_dtype, device=self.batch.device)
         if self.include_metadata:
             return self.batch.get_at_with_metadata(indices)
-        return self.batch.get_at(indices)
+        return self.batch.get_at(indices)
+
+class PyTorchAsUniEnvDataset(BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType]):
+    is_mutable = False
+    def __init__(
+        self,
+        dataset: Dataset,
+    ):
+        """
+        A UniEnvPy BatchBase wrapper for PyTorch Datasets.
+
+        Args:
+            dataset (Dataset): The PyTorch Dataset to wrap.
+        """
+
+        assert len(dataset) >= 0, "The provided PyTorch Dataset must have a defined length."
+        self.dataset = dataset
+        tmp_data = dataset[0]
+        single_space = scu.construct_space_from_data(tmp_data, PyTorchComputeBackend)
+        super().__init__(
+            single_space,
+            None
+        )
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def get_at_with_metadata(self, idx):
+        if isinstance(idx, int):
+            data = self.dataset[idx]
+            return data, {}
+        elif idx is Ellipsis:
+            idx = self.backend.arange(0, len(self))
+        elif isinstance(idx, slice):
+            idx = self.backend.arange(*idx.indices(len(self)))
+        elif self.backend.is_backendarray(idx):
+            if self.backend.dtype_is_boolean(idx.dtype):
+                idx = self.backend.nonzero(idx)[0]
+            else:
+                assert self.backend.dtype_is_real_integer(idx.dtype), "Index array must be of integer or boolean type."
+                idx = (idx + len(self)) % len(self)
+
+        idx_list = self.backend.to_numpy(idx).tolist()
+        if hasattr(self.dataset, "__getitems__"):
+            data_list = self.dataset.__getitems__(idx_list)
+        else:
+            data_list = [self.dataset[i] for i in idx_list]
+        aggregated_data = sbu.concatenate(self._batched_space, data_list)
+        return aggregated_data, {}
+
+    def get_at(self, idx):
+        data, _ = self.get_at_with_metadata(idx)
+        return data
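The module now adapts in both directions: UniEnvAsPyTorchDataset exposes a unienv batch as a map-style torch Dataset, and PyTorchAsUniEnvDataset wraps an existing torch Dataset as a read-only BatchBase (using __getitems__ for batched fetches when available). A hedged sketch of round-tripping a toy dict-valued dataset (the ToyDataset and the DataLoader wiring are illustrative assumptions, not documented unienv usage):

import torch
from torch.utils.data import DataLoader, Dataset

from unienv_data.integrations.pytorch import UniEnvAsPyTorchDataset, PyTorchAsUniEnvDataset

class ToyDataset(Dataset):
    # Returns a dict of tensors per sample, which the wrapper uses to infer the space.
    def __len__(self):
        return 5
    def __getitem__(self, i):
        return {"obs": torch.full((2,), float(i)), "reward": torch.tensor(float(i))}

unienv_batch = PyTorchAsUniEnvDataset(ToyDataset())
print(len(unienv_batch))       # 5
print(unienv_batch.get_at(0))  # first sample as a dict of tensors

# Expose the unienv batch back to torch so it can feed a standard DataLoader.
loader = DataLoader(UniEnvAsPyTorchDataset(unienv_batch), batch_size=2)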
unienv_data/storages/_episode_storage.py
CHANGED
@@ -188,8 +188,12 @@ class EpisodeStorageBase(SpaceStorage[
                 index = self.backend.arange(0, self.capacity, device=self.device)
             else:
                 index = self.backend.arange(0, self.length, device=self.device)
-        elif self.backend.is_backendarray(index)
-
+        elif self.backend.is_backendarray(index):
+            if self.backend.dtype_is_boolean(index.dtype):
+                index = self.backend.nonzero(index)[0]
+            else:
+                assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
+                index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)

         batch_size = index.shape[0] if self.backend.is_backendarray(index) else 1

@@ -368,8 +372,13 @@ class EpisodeStorageBase(SpaceStorage[
                 index = self.backend.arange(0, self.capacity, device=self.device)
             else:
                 index = self.backend.arange(0, self.length, device=self.device)
-        elif self.backend.is_backendarray(index)
-
+        elif self.backend.is_backendarray(index):
+            if self.backend.dtype_is_boolean(index.dtype):
+                index = self.backend.nonzero(index)[0]
+            else:
+                assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
+                index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)
+
         assert self.backend.is_backendarray(index) and self.backend.dtype_is_real_integer(index.dtype) and len(index.shape) == 1, "Index must be a 1D array of integers"
         sorted_indexes_arg = self.backend.argsort(index)
         sorted_indexes = index[sorted_indexes_arg]
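Both hunks replace a bare backend-array branch with explicit normalization: boolean masks become the positions of their True entries, and integer indices are wrapped modulo the storage capacity (or current length) so negative indices behave like Python's. A standalone NumPy sketch of that normalization, with NumPy standing in for the ComputeBackend:

import numpy as np

def normalize_index(index: np.ndarray, length: int) -> np.ndarray:
    if index.dtype == np.bool_:
        # Boolean masks turn into the indices of their True entries.
        return np.nonzero(index)[0]
    assert np.issubdtype(index.dtype, np.integer), "Index array must be of integer or boolean type."
    # Negative indices wrap around, e.g. -1 -> length - 1.
    return (index + length) % length

print(normalize_index(np.array([True, False, True, False]), 4))  # [0 2]
print(normalize_index(np.array([-1, 0, 2]), 4))                  # [3 0 2]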
unienv_data/storages/pytorch.py
CHANGED
@@ -72,6 +72,7 @@ class PytorchTensorStorage(SpaceStorage[
         capacity : Optional[int] = None,
         read_only : bool = True,
         multiprocessing : bool = False,
+        **kwargs,
     ) -> "PytorchTensorStorage":
         assert single_instance_space.backend is PyTorchComputeBackend, "PytorchTensorStorage only supports PyTorch backend"
         assert capacity is not None, "Capacity must be specified when creating a new tensor"
unienv_data/storages/video_storage.py
CHANGED
@@ -1,23 +1,239 @@
 from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type, Literal, cast
 from fractions import Fraction
-
-from unienv_interface.space
+
+from unienv_interface.space import BoxSpace
 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.backends.pytorch import PyTorchComputeBackend
 from unienv_interface.utils.symbol_util import *

-from unienv_data.base import SpaceStorage
 from ._episode_storage import IndexableType, EpisodeStorageBase

 import numpy as np
 import os
 import json
-import
+import logging
+import importlib
+import importlib.util

 import av
-
-from imageio.plugins.pyav import PyAVPlugin
-from av.codec.hwaccel import HWAccel, hwdevices_available
+from av.codec.hwaccel import HWAccel as PyAvHWAccel
 from av.codec import codecs_available
+import av.error
+from av.video.reformatter import VideoReformatter
+
+# TorchCodec backend for video encoding / decoding
+# import av
+# from av.codec import codecs_available
+# from torchcodec.decoders import VideoDecoder, set_cuda_backend
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+LOGGER = logging.getLogger(__name__)
+
+class PyAvVideoReader:
+    HWAccel = PyAvHWAccel
+
+    @staticmethod
+    def available_hwdevices() -> List[str]:
+        from av.codec.hwaccel import hwdevices_available
+        return hwdevices_available()
+
+    @staticmethod
+    def get_auto_hwaccel() -> Optional[HWAccel]:
+        available_hwdevices = __class__.available_hwdevices()
+        target_hwaccel = None
+        if "d3d11va" in available_hwdevices:
+            target_hwaccel = PyAvHWAccel(device_type="d3d11va", allow_software_fallback=True)
+        elif "cuda" in available_hwdevices:
+            target_hwaccel = PyAvHWAccel(device_type="cuda", allow_software_fallback=True)
+        elif "vaapi" in available_hwdevices:
+            target_hwaccel = PyAvHWAccel(device_type="vaapi", allow_software_fallback=True)
+        elif "videotoolbox" in available_hwdevices:
+            target_hwaccel = PyAvHWAccel(device_type="videotoolbox", allow_software_fallback=True)
+        return target_hwaccel
+
+    def __init__(
+        self,
+        backend : ComputeBackend,
+        filename: str,
+        buffer_pixel_format : Optional[str] = None,
+        hwaccel: Optional[Union[HWAccel, Literal['auto']]] = None,
+        seek_mode: Literal['exact', 'approximate'] = 'exact',
+        device : Optional[BDeviceType] = None,
+    ):
+        if seek_mode != 'exact':
+            LOGGER.warning("PyAvVideoReader only supports 'exact' seek mode. Falling back to 'exact'.")
+
+        if hwaccel == 'auto':
+            hwaccel = __class__.get_auto_hwaccel()
+        self.container = av.open(filename, mode='r', hwaccel=hwaccel)
+        self.video_stream = self.container.streams.video[0]
+        if buffer_pixel_format is not None:
+            self.video_reformatter = VideoReformatter()
+        else:
+            self.video_reformatter = None
+        self.buffer_pixel_format = buffer_pixel_format
+        self.total_frames = self.video_stream.frames # Sometimes this reads 0 for some containers (as they don't explicitly store it)
+        self.frame_iterator = self.container.decode(self.video_stream)
+        self.backend = backend
+        self.device = device
+
+    def seek(self, frame_index: int):
+        self.container.seek(frame_index, any_frame=True, backward=True, stream=self.video_stream)
+        self.frame_iterator = self.container.decode(self.video_stream)
+
+    def __next__(self):
+        try:
+            frame = next(self.frame_iterator)
+            if self.video_reformatter is not None:
+                frame = self.video_reformatter.reformat(
+                    frame,
+                    format=self.buffer_pixel_format,
+                )
+            frame = frame.to_ndarray(channel_last=True)
+        except av.error.EOFError:
+            raise StopIteration
+        if np.prod(frame.shape) == 0:
+            raise StopIteration
+
+        return frame
+
+    def __iter__(self):
+        return self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.container.close()
+
+    def read(self, index : Union[IndexableType, BArrayType], total_length : int) -> BArrayType:
+        if isinstance(index, int):
+            self.seek(index)
+            frame_np = next(self)
+            frame = self.backend.from_numpy(frame_np)
+            if self.device is not None:
+                frame = self.backend.to_device(frame, self.device)
+            return frame
+        else:
+            if index is Ellipsis:
+                index = np.arange(total_length)
+            elif isinstance(index, slice):
+                index = np.arange(*index.indices(total_length))
+            elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+                index = self.backend.nonzero(index)[0]
+            if self.backend.is_backendarray(index):
+                index = self.backend.to_numpy(index)
+
+            argsorted_indices = np.argsort(index)
+            sorted_index = index[argsorted_indices]
+            reserve_index = np.argsort(argsorted_indices)
+
+            if len(index) < total_length // 2:
+                all_frames_np = []
+                past_frame_np = None
+                for frame_i in sorted_index:
+                    self.seek(frame_i)
+
+                    try:
+                        frame_np = next(self)
+                    except StopIteration:
+                        frame_np = past_frame_np
+                    past_frame_np = frame_np
+                    all_frames_np.append(frame_np)
+
+                all_frames_np = np.stack(all_frames_np, axis=0)
+                # Reorder from sorted order back to original order
+                all_frames_np = all_frames_np[reserve_index]
+            else:
+                # Create a set for O(1) lookup and a mapping from frame index to position in sorted_index
+                sorted_index_set = set(sorted_index)
+                frame_to_sorted_pos = {int(frame_idx): pos for pos, frame_idx in enumerate(sorted_index)}
+
+                # Pre-allocate array to store frames in sorted order
+                all_frames_list = [None] * len(sorted_index)
+                past_frame_np = None
+                self.seek(0)
+                for frame_i in range(total_length):
+                    try:
+                        frame_np = next(self)
+                    except StopIteration:
+                        frame_np = past_frame_np
+                    if frame_i in sorted_index_set:
+                        # Store at the position corresponding to sorted_index order
+                        all_frames_list[frame_to_sorted_pos[frame_i]] = frame_np
+                    past_frame_np = frame_np
+                all_frames_np = np.stack(all_frames_list, axis=0)
+                # Reorder from sorted order back to original order
+                all_frames_np = all_frames_np[reserve_index]
+            all_frames = self.backend.from_numpy(all_frames_np)
+            if self.device is not None:
+                all_frames = self.backend.to_device(all_frames, self.device)
+
+            return all_frames
+
+class TorchCodecVideoReader:
+    HWAccel = Literal['beta', 'ffmpeg']
+    def __init__(
+        self,
+        backend : ComputeBackend,
+        filename: str,
+        buffer_pixel_format : Optional[str] = None,
+        hwaccel: Optional[Union[HWAccel, Literal['auto']]] = None,
+        seek_mode: Literal['exact', 'approximate'] = 'exact',
+        device : Optional[BDeviceType] = None,
+    ):
+        from torchcodec.decoders import VideoDecoder, set_cuda_backend
+        assert torch is not None, "TorchCodecVideoReader requires PyTorch and TorchCodec to be installed."
+        assert buffer_pixel_format == 'rgb24', "TorchCodecVideoReader currently only supports 'rgb24' buffer pixel format."
+        if hwaccel == 'auto':
+            hwaccel = __class__.get_auto_hwaccel()
+        if hwaccel is not None:
+            with set_cuda_backend(hwaccel):
+                self.decoder = VideoDecoder(filename, device='cuda', seek_mode=seek_mode)
+        else:
+            self.decoder = VideoDecoder(filename, seek_mode=seek_mode)
+        self.backend = backend
+        self.device = device
+        self.buffer_pixel_format = buffer_pixel_format
+
+    @staticmethod
+    def get_auto_hwaccel() -> Optional[HWAccel]:
+        assert torch is not None, "TorchCodecVideoReader requires PyTorch to be installed."
+        if torch.cuda.is_available():
+            return 'beta'
+        else:
+            return None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+    def read(self, index : Union[IndexableType, np.ndarray], total_length : int) -> np.ndarray:
+        if isinstance(index, int):
+            ret = self.decoder.get_frame_at(index).data.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
+        elif index is Ellipsis:
+            ret = self.decoder.get_frames_in_range(0, len(self.decoder)).data.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+        elif isinstance(index, slice):
+            start, stop, step = index.indices(len(self))
+            ret = self.decoder.get_frames_in_range(start, stop, step=step).data.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+        else:
+            if self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+                index = self.backend.nonzero(index)[0]
+            if self.backend.simplified_name != 'pytorch':
+                index = PyTorchComputeBackend.from_other_backend(self.backend, index).to('cpu', torch.int64)
+            ret = self.decoder.get_frames_at(index).data.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+
+        if self.backend.simplified_name != 'pytorch':
+            ret = self.backend.from_other_backend(PyTorchComputeBackend, ret)
+        if self.device is not None:
+            ret = self.backend.to_device(ret, self.device)
+        return ret

 class VideoStorage(EpisodeStorageBase[
     BArrayType,
@@ -30,7 +246,7 @@ class VideoStorage(EpisodeStorageBase[
     A storage for RGB or depth video data using video files
     If encoding RGB video
         - Set `buffer_pixel_format` to `rgb24`
-        - Set `file_pixel_format` to `None`
+        - Set `file_pixel_format` to `None` (especially when running with nvenc codec)
         - Set `file_ext` to anything you like (e.g., "mp4", "avi", "mkv", etc.)
     If encoding depth video
         - Set `buffer_pixel_format` to `gray16le` (You can use rescale transform inside a `TransformedStorage` to convert depth values to this format, where `dtype` should be `np.uint16`) - if in meters, set min to 0 and max to 65.535 as the multiplication factor is 1000 (i.e., depth in mm)
@@ -40,12 +256,15 @@ class VideoStorage(EpisodeStorageBase[
     """

     # ========== Class Attributes ==========
+    PyAV_LOG_LEVEL = av.logging.WARNING
+
     @classmethod
     def create(
         cls,
         single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
         *args,
-
+        seek_mode : Literal['exact', 'approximate'] = 'exact',
+        hardware_acceleration : Optional[Union[Any, Literal['auto']]] = 'auto',
         codec : Union[str, Literal['auto']] = 'auto',
         file_ext : str = "mp4",
         file_pixel_format : Optional[str] = None,
@@ -64,6 +283,7 @@ class VideoStorage(EpisodeStorageBase[
         return VideoStorage(
             single_instance_space,
             cache_filename=cache_path,
+            seek_mode=seek_mode,
             hardware_acceleration=hardware_acceleration,
             codec=codec,
             file_ext=file_ext,
@@ -79,7 +299,8 @@ class VideoStorage(EpisodeStorageBase[
         path : Union[str, os.PathLike],
         single_instance_space : BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
         *,
-
+        seek_mode : Literal['exact', 'approximate'] = 'exact',
+        hardware_acceleration : Optional[Union[Any, Literal['auto']]] = 'auto',
         codec : Union[str, Literal['auto']] = 'auto',
         capacity : Optional[int] = None,
         read_only : bool = True,
@@ -113,6 +334,7 @@ class VideoStorage(EpisodeStorageBase[
         return VideoStorage(
             single_instance_space,
             cache_filename=path,
+            seek_mode=seek_mode,
             hardware_acceleration=hardware_acceleration,
             codec=codec,
             file_ext=file_ext,
@@ -127,31 +349,20 @@ class VideoStorage(EpisodeStorageBase[
     # ========== Instance Implementations ==========
     single_file_ext = None

-    @staticmethod
-    def get_auto_hwaccel() -> Optional[HWAccel]:
-        if hasattr(__class__, "_auto_hwaccel"):
-            return __class__._auto_hwaccel
-        available_hwdevices = hwdevices_available()
-        target_hwaccel = None
-        if "d3d11va" in available_hwdevices:
-            target_hwaccel = HWAccel(device_type="d3d11va", allow_software_fallback=True)
-        elif "cuda" in available_hwdevices:
-            target_hwaccel = HWAccel(device_type="cuda", allow_software_fallback=True)
-        elif "vaapi" in available_hwdevices:
-            target_hwaccel = HWAccel(device_type="vaapi", allow_software_fallback=True)
-        elif "videotoolbox" in available_hwdevices:
-            target_hwaccel = HWAccel(device_type="videotoolbox", allow_software_fallback=True)
-        __class__._auto_hwaccel = target_hwaccel
-        return target_hwaccel
-
     @staticmethod
     def get_auto_codec(
         base : Optional[str] = None
     ) -> str:
         if hasattr(__class__, "_auto_codec"):
             return __class__._auto_codec
+
+        # --- PyAV Implementation ---
         preferred_codecs = ["av1", "hevc", "h264", "mpeg4", "vp9", "vp8"] if base is None else [base]
-        preferred_suffixes = ["_nvenc", "_amf", "_qsv"]
+        preferred_suffixes = ["_nvenc", "_vaapi", "_amf", "_qsv", "_videotoolbox"]
+
+        # --- TorchCodec Implementation ---
+        # preferred_codecs = ["hevc", "av1", "h264", "mpeg4", "vp9", "vp8"] if base is None else [base]
+        # preferred_suffixes = ["_nvenc"]

         target_codec = None
         for codec in preferred_codecs:
@@ -177,8 +388,10 @@ class VideoStorage(EpisodeStorageBase[
         self,
         single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
         cache_filename : Union[str, os.PathLike],
-
+        seek_mode : Literal['exact', 'approximate'] = 'exact',
+        hardware_acceleration : Optional[Union[Any, Literal['auto']]] = 'auto',
         codec : Union[str, Literal['auto']] = 'auto',
+        decode_backend : Literal['torchcodec', 'pyav', 'auto'] = 'auto',
         file_ext : str = "mp4",
         file_pixel_format : Optional[str] = None,
         buffer_pixel_format : str = "rgb24",
@@ -195,87 +408,112 @@ class VideoStorage(EpisodeStorageBase[
             capacity=capacity,
             length=length,
         )
-        self.
-
-        )
+        self.seek_mode = seek_mode
+        self.hwaccel = hardware_acceleration
         self.codec = self.get_auto_codec() if codec == 'auto' else codec
+
+        if decode_backend == 'auto':
+            if importlib.util.find_spec("torchcodec"):
+                decode_backend = 'torchcodec'
+            else:
+                decode_backend = 'pyav'
+        self.decode_backend = decode_backend
+
         self.fps = fps
         self.file_pixel_format = file_pixel_format
         self.buffer_pixel_format = buffer_pixel_format
-
+
     def get_from_file(self, filename : str, index : Union[IndexableType, BArrayType], total_length : int) -> BArrayType:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if self.backend.is_backendarray(index):
-            index = self.backend.to_numpy(index)
-
-        argsorted_indices = np.argsort(index)
-        sorted_index = index[argsorted_indices]
-        reserve_index = np.argsort(argsorted_indices)
-
-        if len(index) < total_length // 2:
-            all_frames_np = []
-            for frame_i in sorted_index:
-                frame_np = video.read(index=frame_i, format=self.buffer_pixel_format)
-                all_frames_np.append(frame_np)
-            all_frames_np = np.stack(all_frames_np, axis=0)
-            # Reorder from sorted order back to original order
-            all_frames_np = all_frames_np[reserve_index]
-        else:
-            # Create a set for O(1) lookup and a mapping from frame index to position in sorted_index
-            sorted_index_set = set(sorted_index)
-            frame_to_sorted_pos = {int(frame_idx): pos for pos, frame_idx in enumerate(sorted_index)}
-
-            # Pre-allocate array to store frames in sorted order
-            all_frames_list = [None] * len(sorted_index)
-            past_frame_np = None
-            video_iter = video.iter(format=self.buffer_pixel_format)
-            for frame_i in range(total_length):
-                try:
-                    frame_np = next(video_iter)
-                except StopIteration:
-                    frame_np = past_frame_np
-                if frame_i in sorted_index_set:
-                    # Store at the position corresponding to sorted_index order
-                    all_frames_list[frame_to_sorted_pos[frame_i]] = frame_np
-                past_frame_np = frame_np
-            all_frames_np = np.stack(all_frames_list, axis=0)
-            # Reorder from sorted order back to original order
-            all_frames_np = all_frames_np[reserve_index]
-        all_frames = self.backend.from_numpy(all_frames_np)
-        if self.device is not None:
-            all_frames = self.backend.to_device(all_frames, self.device)
-        return all_frames
+        if self.decode_backend == 'pyav':
+            reader_cls = PyAvVideoReader
+        elif self.decode_backend == 'torchcodec':
+            reader_cls = TorchCodecVideoReader
+        else:
+            raise ValueError(f"Unknown decode_backend {self.decode_backend}")
+        with reader_cls(
+            backend=self.backend,
+            filename=filename,
+            buffer_pixel_format=self.buffer_pixel_format,
+            hwaccel=self.hwaccel,
+            seek_mode=self.seek_mode,
+            device=self.device,
+        ) as video_reader:
+            return video_reader.read(index, total_length)

+    # PyAV Implementation (Commented out due to bugs with hevc codec)
     def set_to_file(self, filename : str, value : BArrayType):
-
-        with iio.imopen(
-
-
-
-        ) as video:
-
+        # ImageIO Implementation
+        # with iio.imopen(
+        #     filename,
+        #     'w',
+        #     plugin='pyav',
+        # ) as video:
+        #     video = cast(PyAVPlugin, video)
+        #     video.init_video_stream(self.codec, fps=self.fps, pixel_format=self.file_pixel_format)
+
+        #     # Fix codec time base if not set:
+        #     if video._video_stream.codec_context.time_base is None:
+        #         video._video_stream.codec_context.time_base = Fraction(1 / self.fps).limit_denominator(int(2**16 - 1))
+
+        #     for i, frame in enumerate(value_np):
+        #         video.write_frame(frame, pixel_format=self.buffer_pixel_format)
+
+        # PyAV Implementation
+        logging.getLogger("libav").setLevel(self.PyAV_LOG_LEVEL)
+        with av.open(filename, mode='w') as container:
+            output_stream = container.add_stream(self.codec, rate=self.fps)
+            if len(self.single_instance_space.shape) == 3: # (H, W, C)
+                output_stream.width = self.single_instance_space.shape[1]
+                output_stream.height = self.single_instance_space.shape[0]
+            else: # (H, W)
+                output_stream.width = self.single_instance_space.shape[1]
+                output_stream.height = self.single_instance_space.shape[0]
+            if self.file_pixel_format is not None:
+                output_stream.pix_fmt = self.file_pixel_format
+            # output_stream.time_base = Fraction(1, self.fps).limit_denominator(int(2**16 - 1))
+            value_np = self.backend.to_numpy(value)
+            for i, frame_np in enumerate(value_np):
+                frame = av.VideoFrame.from_ndarray(frame_np, format=self.buffer_pixel_format, channel_last=True)
+                packets = output_stream.encode(frame)
+                if packets:
+                    container.mux(packets)
+            # Flush stream
+            packets = output_stream.encode()
+            if packets:
+                container.mux(packets)

-            #
-            if
-
+            # Close container
+            if hasattr(output_stream, 'close'):
+                output_stream.close()
+            # container.close()
+
+        # Restore logging level
+        av.logging.restore_default_callback()

-
-
+
+    # TorchCodec Implementation (extremely slow, and seems to have double-write issues)
+    # def set_to_file(self, filename : str, value : BArrayType):
+    #     if self.backend.simplified_name != 'pytorch':
+    #         value_pt = PyTorchComputeBackend.from_other_backend(self.backend, value)
+    #     else:
+    #         value_pt = value
+    #     if self.codec.endswith('_nvenc'):
+    #         value_pt = value_pt.to("cuda")
+    #     else:
+    #         value_pt = value_pt.to("cpu")
+
+    #     if len(self.single_instance_space.shape) < 3:
+    #         # Add channel dimension for grayscale images
+    #         value_pt = value_pt[..., None] # (N, H, W) -> (N, H, W, 1)
+
+    #     encoder = VideoEncoder(
+    #         value_pt.permute(0, 3, 1, 2), # (N, H, W, C) -> (N, C, H, W)
+    #         frame_rate=self.fps,
+    #     )
+    #     encoder.to_file(
+    #         filename,
+    #         codec=self.codec,
+    #         pixel_format=self.file_pixel_format,
+    #     )

     def dumps(self, path):
         assert os.path.samefile(path, self.cache_filename), \
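Frame decoding now goes through one of two reader classes chosen by decode_backend: the PyAV reader added above (exact seeks, optional hardware acceleration, optional pixel-format reformatting) or the TorchCodec reader when torchcodec is importable. A minimal decode sketch using PyAvVideoReader directly; the filename is a placeholder, the backend argument reuses the PyTorch backend class from the diff, and real code would normally go through VideoStorage rather than the reader itself:

import numpy as np
from unienv_interface.backends.pytorch import PyTorchComputeBackend  # import path taken from the diff
from unienv_data.storages.video_storage import PyAvVideoReader

with PyAvVideoReader(
    backend=PyTorchComputeBackend,
    filename="episode_000.mp4",      # placeholder path
    buffer_pixel_format="rgb24",
    hwaccel="auto",                  # picks d3d11va / cuda / vaapi / videotoolbox when available
) as reader:
    # total_frames can read 0 for containers that do not store the frame count,
    # in which case the caller has to supply total_length itself.
    frames = reader.read(index=np.array([0, 5, 3]), total_length=reader.total_frames)
    print(frames.shape)              # (3, H, W, 3) in the backend's array type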
{unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/licenses/LICENSE: File without changes
{unienv-0.0.1b6.dist-info → unienv-0.0.1b8.dist-info}/top_level.txt: File without changes