unienv 0.0.1b7__py3-none-any.whl → 0.0.1b9__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: unienv
- Version: 0.0.1b7
+ Version: 0.0.1b9
  Summary: Unified robot environment framework supporting multiple tensor and simulation backends
  License-Expression: MIT
  Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
@@ -12,7 +12,7 @@ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
- Requires-Dist: xbarray>=0.0.1a13
+ Requires-Dist: xbarray>=0.0.1a14
  Requires-Dist: pillow
  Requires-Dist: cloudpickle
  Requires-Dist: pyvers
@@ -1,22 +1,23 @@
- unienv-0.0.1b7.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
+ unienv-0.0.1b9.dist-info/licenses/LICENSE,sha256=nkklvEaJUR4QDBygz7tkEe1FMVKV1JSjnGzJNLhdIWM,1091
  unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
  unienv_data/base/__init__.py,sha256=w-I8A-z7YYArkHc2ZOVGrfzfThsaDBg7aD7qMFprNM8,186
  unienv_data/base/common.py,sha256=A3RtD3Omqk0Qplsc-44ukAEzbQEU22_MkwUlC7l-HHM,13083
  unienv_data/base/storage.py,sha256=afICsO_7Zbm9azV0Jxho_z9F7JM30TUDjJM1NHETDHM,5495
  unienv_data/batches/__init__.py,sha256=Vi92f8ddgFYCqwv7xO2Pi3oJePnioJ4XrJbQVV7eIvk,234
  unienv_data/batches/backend_compat.py,sha256=tzFG8gTq0yW-J6PLvu--lCGS0lFc0QfelicJ50p_HYc,8207
- unienv_data/batches/combined_batch.py,sha256=pNrbLvU565BUDWO0pZLCnSMygmoGVCLxjC9OkLRKtLA,15330
+ unienv_data/batches/combined_batch.py,sha256=iMCY7pEXf_YrIvq-G92ttaANl-SUcpzDLvH7fI-GImA,16313
  unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
  unienv_data/batches/slicestack_batch.py,sha256=Q3-gsJTvMjKTeZAHWNBTGRsws0HctsfMMTw0vylNxvA,16785
  unienv_data/batches/transformations.py,sha256=b4HqX3wZ6TuRgQ2q81Jv43PmeHGmP8cwURK_ULjGNgs,5647
- unienv_data/integrations/pytorch.py,sha256=pW5rXBXagfzwJjM_VGgg8CPXEs3e2fKgg4nY7M3dpOc,2350
+ unienv_data/integrations/huggingface.py,sha256=cqaNS1Xpv5udAYgytd0zsVI8k9lSKMKz1jdxWnDseK4,1535
+ unienv_data/integrations/pytorch.py,sha256=cyUypd4kZVH8WQFrrv4yFgYZJGT4MeuhvxzZMc4-1dM,4553
  unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
  unienv_data/replay_buffer/replay_buffer.py,sha256=8vPma5dL6jDGhI3Oo6IEvNcDYJG9Lb0Xlvxp45tQMEs,14498
  unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=cqRmzdewFS8IvJcMwxxQgwZf7TvvrViym87OaCOes3Y,24009
  unienv_data/samplers/__init__.py,sha256=e7uunWN3r-g_2fDaMsYMe8cZcF4N-okCxqBPweQnE0s,97
  unienv_data/samplers/multiprocessing_sampler.py,sha256=FEBK8pMTnkpA0xuMkbvlv4aIdVTTubeT8BjL60BJL5o,13254
  unienv_data/samplers/step_sampler.py,sha256=ZCcrx9WbILtaR6izhIP3DhtmFcP7KQBdaYaSZ7vWwRk,3010
- unienv_data/storages/_episode_storage.py,sha256=OpZt4P-P6LHrBR4F-tNcCFROLskWaOKWCDfoPV7qz1I,21970
+ unienv_data/storages/_episode_storage.py,sha256=gdrj_bIRquyT0CIm9-0du5cCNhF-4cxwvf9u9POcZVg,22603
  unienv_data/storages/_list_storage.py,sha256=pH9xZOqXCx65NBRRD-INcP8OP-NWsI-JvdzVsPj9MSg,5225
  unienv_data/storages/backend_compat.py,sha256=BxeMJlC3FI60KLJ7QB5kF-mrGlJ6xi584Dcu4IN4Zrc,10714
  unienv_data/storages/dict_storage.py,sha256=DSqRIgo3m1XtUcLtyjYSqqpi01mr_nJOLg5BCddwPcg,13862
@@ -24,9 +25,9 @@ unienv_data/storages/flattened.py,sha256=Yf1G4D6KE36sESyDMGWKXqhFjz6Idx7N1aEhihm
  unienv_data/storages/hdf5.py,sha256=Jnls1rs7nlOOp9msmAfhuZp80OZd8S2Llls176EOUc4,27096
  unienv_data/storages/image_storage.py,sha256=4J1ZiGFHbGLHmReMztImJoDcRmiB_llD2wbMB3rdvOQ,5137
  unienv_data/storages/npz_storage.py,sha256=IP2DXbUs_ySzILne3s3hq3gwHiy9tfpWz6HcNciA8DU,4868
- unienv_data/storages/pytorch.py,sha256=bf3ys6eBlMvjyPK4XE-itENjEWq5Vm60qNwBNqJIZqg,7345
+ unienv_data/storages/pytorch.py,sha256=MltiTcBuvNWQqE5-RAfjBz3gwdfATkVMTmSBR2escIE,7363
  unienv_data/storages/transformation.py,sha256=-9_jPZNpx6RXY_ojv_1UCSTa4Z9apI9V9jit8nG93oM,8133
- unienv_data/storages/video_storage.py,sha256=2vcNlghhDZWWzAdf9t0VeCMZrv-x_rYkYaCw8XV8AJA,13331
+ unienv_data/storages/video_storage.py,sha256=0d3FP-BwXiq7q_s9UJE0G7v76Jfgz7Q07vX159UGVnE,22572
  unienv_data/third_party/tensordict/memmap_tensor.py,sha256=J6SkFf-FDy43XuaHLgbvDsHt6v2vYfuhRyeoV02P8vw,42589
  unienv_data/transformations/image_compress.py,sha256=f8JTY4DJEXaiu5lO77T4ROV950rh_bOZBchOF-O0tx8,13130
  unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
@@ -34,7 +35,7 @@ unienv_interface/backends/__init__.py,sha256=L7CFwCChHVL-2Dpz34pTGC37WgodfJEeDQw
  unienv_interface/backends/base.py,sha256=1_hji1qwNAhcEtFQdAuzaNey9g5bWYj38t1sQxjnggc,132
  unienv_interface/backends/jax.py,sha256=26Wu5OQ4EEjolyZoELhlWMPNSZ7LsVoKEGpd09L80Ck,533
  unienv_interface/backends/numpy.py,sha256=6dMB2Vq7mrWukobyyGvuccluZUgjVkxr7x0hrUc_pe8,542
- unienv_interface/backends/pytorch.py,sha256=BddHmZAngsaedFlvj1mKdXpNe6AWvNwEXq_eTEUoFWA,592
+ unienv_interface/backends/pytorch.py,sha256=tsxgaSJK0Uux4cjHqEI6RDQtlOnmEWqwTxIb_JqphOw,592
  unienv_interface/backends/serialization.py,sha256=0TZlpfbP1DRB4FkM8ysDVQmn6RlYtIPisyeHjvHr7bE,2289
  unienv_interface/env_base/__init__.py,sha256=JuaVgWlg313LZpflt4LSErY94nUrfvUp0LbIPUle0MA,226
  unienv_interface/env_base/env.py,sha256=PV-AEmKwSjnFDjZFYtBW-At9w4fpm_I5C7GhfxPPrs4,4833
@@ -98,7 +99,7 @@ unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKi
  unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
  unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
  unienv_interface/wrapper/video_record.py,sha256=y_nJRYgo1SeLeO_Ymg9xbbGPKm48AbU3BxZK2wd0gzk,8679
- unienv-0.0.1b7.dist-info/METADATA,sha256=HT6qx5dKz7d5lOf4MBzdtJwx7dixbSaeQviHKCjJYnc,3056
- unienv-0.0.1b7.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
- unienv-0.0.1b7.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
- unienv-0.0.1b7.dist-info/RECORD,,
+ unienv-0.0.1b9.dist-info/METADATA,sha256=dK0fOZRjsgWzvnwEmASIt14oGnzFhU9FAqQSfETirB0,3056
+ unienv-0.0.1b9.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ unienv-0.0.1b9.dist-info/top_level.txt,sha256=wfcJ5_DruUtOEUZjEyfadaKn7B90hWqz2aw-eM3wX5g,29
+ unienv-0.0.1b9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.10.1)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -12,6 +12,102 @@ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils

  from ..base.common import BatchBase, IndexableType, BatchT

+ __all__ = [
+     "convert_single_index_to_batch",
+     "convert_index_to_batch",
+     "CombinedBatch",
+ ]
+
+ def convert_single_index_to_batch(
+     backend : ComputeBackend,
+     index_starts : BArrayType, # (num_batches, )
+     total_length : int, # Total length of all batches
+     idx : int,
+     device : Optional[BDeviceType] = None,
+ ) -> Tuple[int, int]:
+     """
+     Convert a single index for this batch to a tuple of:
+     - The index of the batch
+     - The index to index into the batch
+     """
+     assert -total_length <= idx < total_length, f"Index {idx} out of bounds for batch of size {total_length}"
+     if idx < 0:
+         idx += total_length
+     batch_index = int(backend.sum(
+         idx >= index_starts
+     ) - 1)
+     return batch_index, idx - int(index_starts[batch_index])
+
+ def convert_index_to_batch(
+     backend : ComputeBackend,
+     index_starts : BArrayType, # (num_batches, )
+     total_length : int, # Total length of all batches
+     idx : Union[IndexableType, BArrayType],
+     device : Optional[BDeviceType] = None,
+ ) -> Tuple[
+     int,
+     Union[List[
+         Tuple[int, BArrayType, BArrayType]
+     ], int]
+ ]:
+     """
+     Convert an index for this batch to a tuple of:
+     - The length of the resulting array
+     - List of tuples, each containing:
+         - The index of the batch
+         - The index to index into the batch
+         - The bool mask to index into the resulting array
+     """
+     if isinstance(idx, slice):
+         idx_array = backend.arange(
+             *idx.indices(total_length),
+             dtype=backend.default_index_dtype,
+             device=device
+         )
+     elif idx is Ellipsis:
+         idx_array = backend.arange(
+             total_length,
+             dtype=backend.default_index_dtype,
+             device=device
+         )
+     elif backend.is_backendarray(idx):
+         assert len(idx.shape) == 1, "Index must be 1D"
+         assert backend.dtype_is_real_integer(idx.dtype) or backend.dtype_is_boolean(idx.dtype), \
+             f"Index must be of integer or boolean type, got {idx.dtype}"
+         if backend.dtype_is_boolean(idx.dtype):
+             assert idx.shape[0] == total_length, f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {total_length}"
+             idx_array = backend.nonzero(idx)[0]
+         else:
+             assert backend.all(backend.logical_and(-total_length <= idx, idx < total_length)), \
+                 f"Index array contains out of bounds indices for batch of size {total_length}"
+             idx_array = (idx + total_length) % total_length
+     else:
+         raise ValueError(f"Invalid index type: {type(idx)}")
+
+     assert bool(backend.all(
+         backend.logical_and(
+             -total_length <= idx_array,
+             idx_array < total_length
+         )
+     )), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {total_length}"
+
+     # Convert negative indices to positive indices
+     idx_array = backend.at(idx_array)[idx_array < 0].add(total_length)
+     idx_array_bigger = idx_array[:, None] >= index_starts[None, :] # (idx_array_shape, len(self.batches))
+     idx_array_batch_idx = backend.sum(
+         idx_array_bigger,
+         axis=-1
+     ) - 1 # (idx_array_shape, )
+
+     result_batch_list = []
+     batch_indexes = backend.unique_values(idx_array_batch_idx)
+     for i in range(batch_indexes.shape[0]):
+         batch_index = int(batch_indexes[i])
+         result_mask = idx_array_batch_idx == batch_index
+         index_into_batch = idx_array[result_mask] - index_starts[batch_index]
+         result_batch_list.append((batch_index, index_into_batch, result_mask))
+     return idx_array.shape[0], result_batch_list
+
  class CombinedBatch(BatchBase[
      BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
  ]):
@@ -83,18 +179,13 @@ class CombinedBatch(BatchBase[
          return self.index_caches[-1, 1]

      def _convert_single_index(self, idx : int) -> Tuple[int, int]:
-         """
-         Convert a single index to a tuple containing
-         - the batch index
-         - the index within the batch
-         """
-         assert -len(self) <= idx < len(self), f"Index {idx} out of bounds for batch of size {len(self)}"
-         if idx < 0:
-             idx += len(self)
-         batch_index = int(self.backend.sum(
-             idx >= self.index_caches[:, 0]
-         ) - 1)
-         return batch_index, idx - int(self.index_caches[batch_index, 0])
+         return convert_single_index_to_batch(
+             self.backend,
+             self.index_caches[:, 0],
+             len(self),
+             idx,
+             device=self.device,
+         )

      def _convert_index(self, idx : Union[IndexableType, BArrayType]) -> Tuple[
          int,
@@ -102,61 +193,13 @@ class CombinedBatch(BatchBase[
              Tuple[int, BArrayType, BArrayType]
          ]
      ]:
-         """
-         Convert an index for this batch to a tuple of:
-         - The length of the resulting array
-         - List of tuples, each containing:
-             - The index of the batch
-             - The index to index into the batch
-             - The bool mask to index into the resulting array
-         """
-         if isinstance(idx, slice):
-             idx_array = self.backend.arange(
-                 *idx.indices(len(self)),
-                 dtype=self.backend.default_integer_dtype,
-                 device=self.device
-             )
-         elif idx is Ellipsis:
-             idx_array = self.backend.arange(
-                 len(self),
-                 dtype=self.backend.default_integer_dtype,
-                 device=self.device
-             )
-         elif self.backend.is_backendarray(idx):
-             assert len(idx.shape) == 1, "Index must be 1D"
-             assert self.backend.dtype_is_real_integer(idx.dtype) or self.backend.dtype_is_boolean(idx.dtype), \
-                 f"Index must be of integer or boolean type, got {idx.dtype}"
-             if self.backend.dtype_is_boolean(idx.dtype):
-                 assert idx.shape[0] == len(self), f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {len(self)}"
-                 idx_array = self.backend.nonzero(idx)[0]
-             else:
-                 idx_array = idx
-         else:
-             raise ValueError(f"Invalid index type: {type(idx)}")
-
-         assert bool(self.backend.all(
-             self.backend.logical_and(
-                 -len(self) <= idx_array,
-                 idx_array < len(self)
-             )
-         )), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {len(self)}"
-
-         # Convert negative indices to positive indices
-         idx_array = self.backend.at(idx_array)[idx_array < 0].add(len(self))
-         idx_array_bigger = idx_array[:, None] >= self.index_caches[None, :, 0] # (idx_array_shape, len(self.batches))
-         idx_array_batch_idx = self.backend.sum(
-             idx_array_bigger,
-             axis=-1
-         ) - 1 # (idx_array_shape, )
-
-         result_batch_list = []
-         batch_indexes = self.backend.unique_values(idx_array_batch_idx)
-         for i in range(batch_indexes.shape[0]):
-             batch_index = int(batch_indexes[i])
-             result_mask = idx_array_batch_idx == batch_index
-             index_into_batch = idx_array[result_mask] - self.index_caches[batch_index, 0]
-             result_batch_list.append((batch_index, index_into_batch, result_mask))
-         return idx_array.shape[0], result_batch_list
+         return convert_index_to_batch(
+             self.backend,
+             self.index_caches[:, 0],
+             len(self),
+             idx,
+             device=self.device,
+         )

      def get_flattened_at(self, idx : Union[IndexableType, BaseException]) -> BArrayType:
          if isinstance(idx, int):
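
Note (not part of the diff): the refactor above moves the index arithmetic into module-level helpers. A minimal NumPy sketch of what convert_single_index_to_batch computes, with index_starts holding the cumulative start offset of each sub-batch (the sizes below are invented for illustration):

    import numpy as np

    index_starts = np.array([0, 4, 10])  # three sub-batches starting at flat offsets 0, 4, 10
    total_length = 13                    # combined length of all sub-batches
    idx = 11                             # flat index into the combined batch

    if idx < 0:                          # wrap negative indices, as the helper does
        idx += total_length
    batch_index = int(np.sum(idx >= index_starts) - 1)   # -> 2 (the third sub-batch)
    within_batch = idx - int(index_starts[batch_index])  # -> 1 (second element of that sub-batch)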
@@ -0,0 +1,47 @@
+ from typing import Optional
+ from datasets import Dataset as HFDataset
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
+ from unienv_data.base import BatchBase, BatchT
+
+ __all__ = [
+     'HFAsUniEnvDataset'
+ ]
+
+ class HFAsUniEnvDataset(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
+     BACKEND_TO_FORMAT_MAP = {
+         "numpy": "numpy",
+         "pytorch": "torch",
+         "jax": "jax",
+     }
+
+     is_mutable = False
+
+     def __init__(
+         self,
+         hf_dataset: HFDataset,
+         backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         device : Optional[BDeviceType] = None,
+     ) -> None:
+         kwargs = {}
+         if backend.simplified_name != 'numpy':
+             kwargs['device'] = device
+         self.hf_dataset = hf_dataset.with_format(
+             self.BACKEND_TO_FORMAT_MAP[backend.simplified_name],
+             **kwargs
+         )
+         first_data = self.hf_dataset[0]
+         super().__init__(
+             scu.construct_space_from_data(first_data, backend)
+         )
+
+     def __len__(self) -> int:
+         return len(self.hf_dataset)
+
+     def get_at_with_metadata(self, idx):
+         if self.backend.is_backendarray(idx) and self.backend.dtype_is_boolean(idx):
+             idx = self.backend.nonzero(idx)[0]
+         return self.hf_dataset[idx], {}
+
+     def get_at(self, idx):
+         return self.get_at_with_metadata(idx)[0]
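
Note (not part of the diff): unienv_data/integrations/huggingface.py is a new file that exposes a Hugging Face datasets.Dataset through the UniEnv batch interface. A minimal usage sketch; the numpy backend import below is an assumption (unienv_interface/backends/numpy.py exists per RECORD, but its exported name is not shown in this diff), and the dataset contents are invented:

    from datasets import Dataset
    from unienv_data.integrations.huggingface import HFAsUniEnvDataset
    from unienv_interface.backends.numpy import NumpyComputeBackend  # assumed export name

    hf_ds = Dataset.from_dict({"obs": [[0.0, 1.0], [2.0, 3.0]], "reward": [0.5, 1.5]})
    batch = HFAsUniEnvDataset(hf_ds, backend=NumpyComputeBackend)

    print(len(batch))         # 2, delegated to len(hf_ds)
    sample = batch.get_at(0)  # a dict of numpy values, since with_format("numpy") was applied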
@@ -3,12 +3,19 @@ from typing import Optional, Union, Tuple, Dict, Any
  from torch.utils.data import Dataset

  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
+ from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
  from unienv_data.base import BatchBase, BatchT

- class UniEnvPyTorchDataset(Dataset):
+ __all__ = [
+     "UniEnvAsPyTorchDataset",
+     "PyTorchAsUniEnvDataset",
+ ]
+
+ class UniEnvAsPyTorchDataset(Dataset):
      def __init__(
          self,
-         batch : BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType],
+         batch : BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
          include_metadata : bool = False,
      ):
          """
@@ -60,4 +67,56 @@ class UniEnvPyTorchDataset(Dataset):
          indices = self.batch.backend.asarray(indices, dtype=self.batch.backend.default_integer_dtype, device=self.batch.device)
          if self.include_metadata:
              return self.batch.get_at_with_metadata(indices)
-         return self.batch.get_at(indices)
+         return self.batch.get_at(indices)
+
+ class PyTorchAsUniEnvDataset(BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType]):
+     is_mutable = False
+     def __init__(
+         self,
+         dataset: Dataset,
+     ):
+         """
+         A UniEnvPy BatchBase wrapper for PyTorch Datasets.
+
+         Args:
+             dataset (Dataset): The PyTorch Dataset to wrap.
+         """
+
+         assert len(dataset) >= 0, "The provided PyTorch Dataset must have a defined length."
+         self.dataset = dataset
+         tmp_data = dataset[0]
+         single_space = scu.construct_space_from_data(tmp_data, PyTorchComputeBackend)
+         super().__init__(
+             single_space,
+             None
+         )
+
+     def __len__(self) -> int:
+         return len(self.dataset)
+
+     def get_at_with_metadata(self, idx):
+         if isinstance(idx, int):
+             data = self.dataset[idx]
+             return data, {}
+         elif idx is Ellipsis:
+             idx = self.backend.arange(0, len(self))
+         elif isinstance(idx, slice):
+             idx = self.backend.arange(*idx.indices(len(self)))
+         elif self.backend.is_backendarray(idx):
+             if self.backend.dtype_is_boolean(idx.dtype):
+                 idx = self.backend.nonzero(idx)[0]
+             else:
+                 assert self.backend.dtype_is_real_integer(idx.dtype), "Index array must be of integer or boolean type."
+                 idx = (idx + len(self)) % len(self)
+
+         idx_list = self.backend.to_numpy(idx).tolist()
+         if hasattr(self.dataset, "__getitems__"):
+             data_list = self.dataset.__getitems__(idx_list)
+         else:
+             data_list = [self.dataset[i] for i in idx_list]
+         aggregated_data = sbu.concatenate(self._batched_space, data_list)
+         return aggregated_data, {}
+
+     def get_at(self, idx):
+         data, _ = self.get_at_with_metadata(idx)
+         return data
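
Note (not part of the diff): a minimal sketch of wrapping a plain torch Dataset with the new PyTorchAsUniEnvDataset. Integer indices are fetched directly; slices, Ellipsis, and index arrays are resolved to an index list and batched through __getitems__ when the dataset defines it, otherwise through a Python loop. The toy dataset is invented for illustration:

    import torch
    from torch.utils.data import Dataset
    from unienv_data.integrations.pytorch import PyTorchAsUniEnvDataset

    class ToyDataset(Dataset):
        # Invented example: every item is a dict of tensors.
        def __len__(self):
            return 5
        def __getitem__(self, i):
            return {"obs": torch.full((2,), float(i))}

    batch = PyTorchAsUniEnvDataset(ToyDataset())
    print(len(batch))                   # 5
    item = batch.get_at(0)              # single item, fetched as dataset[0]
    window = batch.get_at(slice(1, 4))  # items 1..3, concatenated by the space utilities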
@@ -188,8 +188,12 @@ class EpisodeStorageBase(SpaceStorage[
                  index = self.backend.arange(0, self.capacity, device=self.device)
              else:
                  index = self.backend.arange(0, self.length, device=self.device)
-         elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
-             index = self.backend.nonzero(index)[0]
+         elif self.backend.is_backendarray(index):
+             if self.backend.dtype_is_boolean(index.dtype):
+                 index = self.backend.nonzero(index)[0]
+             else:
+                 assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
+                 index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)

          batch_size = index.shape[0] if self.backend.is_backendarray(index) else 1

@@ -368,8 +372,13 @@ class EpisodeStorageBase(SpaceStorage[
                  index = self.backend.arange(0, self.capacity, device=self.device)
              else:
                  index = self.backend.arange(0, self.length, device=self.device)
-         elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
-             index = self.backend.nonzero(index)[0]
+         elif self.backend.is_backendarray(index):
+             if self.backend.dtype_is_boolean(index.dtype):
+                 index = self.backend.nonzero(index)[0]
+             else:
+                 assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
+                 index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)
+
          assert self.backend.is_backendarray(index) and self.backend.dtype_is_real_integer(index.dtype) and len(index.shape) == 1, "Index must be a 1D array of integers"
          sorted_indexes_arg = self.backend.argsort(index)
          sorted_indexes = index[sorted_indexes_arg]
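
Note (not part of the diff): the new else-branch accepts signed integer index arrays and wraps them modulo the storage size (capacity when set, otherwise length), so -1 addresses the last slot. A NumPy sketch of that wrap with an invented size:

    import numpy as np

    size = 8                          # self.capacity, or self.length when capacity is None
    index = np.array([0, -1, 3, -8])
    wrapped = (index + size) % size   # -> array([0, 7, 3, 0])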
@@ -72,6 +72,7 @@ class PytorchTensorStorage(SpaceStorage[
          capacity : Optional[int] = None,
          read_only : bool = True,
          multiprocessing : bool = False,
+         **kwargs,
      ) -> "PytorchTensorStorage":
          assert single_instance_space.backend is PyTorchComputeBackend, "PytorchTensorStorage only supports PyTorch backend"
          assert capacity is not None, "Capacity must be specified when creating a new tensor"
@@ -1,23 +1,239 @@
  from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type, Literal, cast
  from fractions import Fraction
- from unienv_interface.space import Space, BoxSpace
- from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+
+ from unienv_interface.space import BoxSpace
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.backends.pytorch import PyTorchComputeBackend
  from unienv_interface.utils.symbol_util import *

- from unienv_data.base import SpaceStorage
  from ._episode_storage import IndexableType, EpisodeStorageBase

  import numpy as np
  import os
  import json
- import shutil
+ import logging
+ import importlib
+ import importlib.util

  import av
- import imageio.v3 as iio
- from imageio.plugins.pyav import PyAVPlugin
- from av.codec.hwaccel import HWAccel, hwdevices_available
+ from av.codec.hwaccel import HWAccel as PyAvHWAccel
  from av.codec import codecs_available
+ import av.error
+ from av.video.reformatter import VideoReformatter
+
+ # TorchCodec backend for video encoding / decoding
+ # import av
+ # from av.codec import codecs_available
+ # from torchcodec.decoders import VideoDecoder, set_cuda_backend
+
+ try:
+     import torch
+ except ImportError:
+     torch = None
+
+ LOGGER = logging.getLogger(__name__)
+
+ class PyAvVideoReader:
+     HWAccel = PyAvHWAccel
+
+     @staticmethod
+     def available_hwdevices() -> List[str]:
+         from av.codec.hwaccel import hwdevices_available
+         return hwdevices_available()
+
+     @staticmethod
+     def get_auto_hwaccel() -> Optional[HWAccel]:
+         available_hwdevices = __class__.available_hwdevices()
+         target_hwaccel = None
+         if "d3d11va" in available_hwdevices:
+             target_hwaccel = PyAvHWAccel(device_type="d3d11va", allow_software_fallback=True)
+         elif "cuda" in available_hwdevices:
+             target_hwaccel = PyAvHWAccel(device_type="cuda", allow_software_fallback=True)
+         elif "vaapi" in available_hwdevices:
+             target_hwaccel = PyAvHWAccel(device_type="vaapi", allow_software_fallback=True)
+         elif "videotoolbox" in available_hwdevices:
+             target_hwaccel = PyAvHWAccel(device_type="videotoolbox", allow_software_fallback=True)
+         return target_hwaccel
+
+     def __init__(
+         self,
+         backend : ComputeBackend,
+         filename: str,
+         buffer_pixel_format : Optional[str] = None,
+         hwaccel: Optional[Union[HWAccel, Literal['auto']]] = None,
+         seek_mode: Literal['exact', 'approximate'] = 'exact',
+         device : Optional[BDeviceType] = None,
+     ):
+         if seek_mode != 'exact':
+             LOGGER.warning("PyAvVideoReader only supports 'exact' seek mode. Falling back to 'exact'.")
+
+         if hwaccel == 'auto':
+             hwaccel = __class__.get_auto_hwaccel()
+         self.container = av.open(filename, mode='r', hwaccel=hwaccel)
+         self.video_stream = self.container.streams.video[0]
+         if buffer_pixel_format is not None:
+             self.video_reformatter = VideoReformatter()
+         else:
+             self.video_reformatter = None
+         self.buffer_pixel_format = buffer_pixel_format
+         self.total_frames = self.video_stream.frames # Sometimes this reads 0 for some containers (as they don't explicitly store it)
+         self.frame_iterator = self.container.decode(self.video_stream)
+         self.backend = backend
+         self.device = device
+
+     def seek(self, frame_index: int):
+         self.container.seek(frame_index, any_frame=True, backward=True, stream=self.video_stream)
+         self.frame_iterator = self.container.decode(self.video_stream)
+
+     def __next__(self):
+         try:
+             frame = next(self.frame_iterator)
+             if self.video_reformatter is not None:
+                 frame = self.video_reformatter.reformat(
+                     frame,
+                     format=self.buffer_pixel_format,
+                 )
+             frame = frame.to_ndarray(channel_last=True)
+         except av.error.EOFError:
+             raise StopIteration
+         if np.prod(frame.shape) == 0:
+             raise StopIteration
+
+         return frame
+
+     def __iter__(self):
+         return self
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args, **kwargs):
+         self.container.close()
+
+     def read(self, index : Union[IndexableType, BArrayType], total_length : int) -> BArrayType:
+         if isinstance(index, int):
+             self.seek(index)
+             frame_np = next(self)
+             frame = self.backend.from_numpy(frame_np)
+             if self.device is not None:
+                 frame = self.backend.to_device(frame, self.device)
+             return frame
+         else:
+             if index is Ellipsis:
+                 index = np.arange(total_length)
+             elif isinstance(index, slice):
+                 index = np.arange(*index.indices(total_length))
+             elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+                 index = self.backend.nonzero(index)[0]
+             if self.backend.is_backendarray(index):
+                 index = self.backend.to_numpy(index)
+
+             argsorted_indices = np.argsort(index)
+             sorted_index = index[argsorted_indices]
+             reserve_index = np.argsort(argsorted_indices)
+
+             if len(index) < total_length // 2:
+                 all_frames_np = []
+                 past_frame_np = None
+                 for frame_i in sorted_index:
+                     self.seek(frame_i)
+
+                     try:
+                         frame_np = next(self)
+                     except StopIteration:
+                         frame_np = past_frame_np
+                     past_frame_np = frame_np
+                     all_frames_np.append(frame_np)
+
+                 all_frames_np = np.stack(all_frames_np, axis=0)
+                 # Reorder from sorted order back to original order
+                 all_frames_np = all_frames_np[reserve_index]
+             else:
+                 # Create a set for O(1) lookup and a mapping from frame index to position in sorted_index
+                 sorted_index_set = set(sorted_index)
+                 frame_to_sorted_pos = {int(frame_idx): pos for pos, frame_idx in enumerate(sorted_index)}
+
+                 # Pre-allocate array to store frames in sorted order
+                 all_frames_list = [None] * len(sorted_index)
+                 past_frame_np = None
+                 self.seek(0)
+                 for frame_i in range(total_length):
+                     try:
+                         frame_np = next(self)
+                     except StopIteration:
+                         frame_np = past_frame_np
+                     if frame_i in sorted_index_set:
+                         # Store at the position corresponding to sorted_index order
+                         all_frames_list[frame_to_sorted_pos[frame_i]] = frame_np
+                     past_frame_np = frame_np
+                 all_frames_np = np.stack(all_frames_list, axis=0)
+                 # Reorder from sorted order back to original order
+                 all_frames_np = all_frames_np[reserve_index]
+             all_frames = self.backend.from_numpy(all_frames_np)
+             if self.device is not None:
+                 all_frames = self.backend.to_device(all_frames, self.device)
+
+             return all_frames
+
+ class TorchCodecVideoReader:
+     HWAccel = Literal['beta', 'ffmpeg']
+     def __init__(
+         self,
+         backend : ComputeBackend,
+         filename: str,
+         buffer_pixel_format : Optional[str] = None,
+         hwaccel: Optional[Union[HWAccel, Literal['auto']]] = None,
+         seek_mode: Literal['exact', 'approximate'] = 'exact',
+         device : Optional[BDeviceType] = None,
+     ):
+         from torchcodec.decoders import VideoDecoder, set_cuda_backend
+         assert torch is not None, "TorchCodecVideoReader requires PyTorch and TorchCodec to be installed."
+         assert buffer_pixel_format == 'rgb24', "TorchCodecVideoReader currently only supports 'rgb24' buffer pixel format."
+         if hwaccel == 'auto':
+             hwaccel = __class__.get_auto_hwaccel()
+         if hwaccel is not None:
+             with set_cuda_backend(hwaccel):
+                 self.decoder = VideoDecoder(filename, device='cuda', seek_mode=seek_mode)
+         else:
+             self.decoder = VideoDecoder(filename, seek_mode=seek_mode)
+         self.backend = backend
+         self.device = device
+         self.buffer_pixel_format = buffer_pixel_format
+
+     @staticmethod
+     def get_auto_hwaccel() -> Optional[HWAccel]:
+         assert torch is not None, "TorchCodecVideoReader requires PyTorch to be installed."
+         if torch.cuda.is_available():
+             return 'beta'
+         else:
+             return None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args, **kwargs):
+         pass
+
+     def read(self, index : Union[IndexableType, np.ndarray], total_length : int) -> np.ndarray:
+         if isinstance(index, int):
+             ret = self.decoder.get_frame_at(index).data.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
+         elif index is Ellipsis:
+             ret = self.decoder.get_frames_in_range(0, len(self.decoder)).data.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+         elif isinstance(index, slice):
+             start, stop, step = index.indices(len(self))
+             ret = self.decoder.get_frames_in_range(start, stop, step=step).data.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+         else:
+             if self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+                 index = self.backend.nonzero(index)[0]
+             if self.backend.simplified_name != 'pytorch':
+                 index = PyTorchComputeBackend.from_other_backend(self.backend, index).to('cpu', torch.int64)
+             ret = self.decoder.get_frames_at(index).data.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+
+         if self.backend.simplified_name != 'pytorch':
+             ret = self.backend.from_other_backend(PyTorchComputeBackend, ret)
+         if self.device is not None:
+             ret = self.backend.to_device(ret, self.device)
+         return ret

  class VideoStorage(EpisodeStorageBase[
      BArrayType,
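
Note (not part of the diff): PyAvVideoReader.read decodes requested frames in ascending order (seeking forward is the cheap direction), choosing between per-frame seeks and a single sequential scan depending on how many frames are requested, and then restores the caller's order by inverting the argsort permutation. A NumPy sketch of that reordering trick with invented indices:

    import numpy as np

    index = np.array([7, 2, 9, 2])     # frames requested, in caller order
    order = np.argsort(index)          # decode order: [1, 3, 0, 2]
    sorted_index = index[order]        # frames are decoded as [2, 2, 7, 9]
    inverse = np.argsort(order)        # maps decode positions back to caller positions
    frames_sorted = sorted_index * 10  # stand-in for the decoded frames
    frames = frames_sorted[inverse]    # -> [70, 20, 90, 20], back in caller order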
@@ -30,7 +246,7 @@ class VideoStorage(EpisodeStorageBase[
      A storage for RGB or depth video data using video files
      If encoding RGB video
      - Set `buffer_pixel_format` to `rgb24`
-     - Set `file_pixel_format` to `None`
+     - Set `file_pixel_format` to `None` (especially when running with nvenc codec)
      - Set `file_ext` to anything you like (e.g., "mp4", "avi", "mkv", etc.)
      If encoding depth video
      - Set `buffer_pixel_format` to `gray16le` (You can use rescale transform inside a `TransformedStorage` to convert depth values to this format, where `dtype` should be `np.uint16`) - if in meters, set min to 0 and max to 65.535 as the multiplication factor is 1000 (i.e., depth in mm)
@@ -40,12 +256,15 @@ class VideoStorage(EpisodeStorageBase[
      """

      # ========== Class Attributes ==========
+     PyAV_LOG_LEVEL = av.logging.WARNING
+
      @classmethod
      def create(
          cls,
          single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
          *args,
-         hardware_acceleration : Optional[Union[HWAccel, Literal['auto']]] = 'auto',
+         seek_mode : Literal['exact', 'approximate'] = 'exact',
+         hardware_acceleration : Optional[Union[Any, Literal['auto']]] = 'auto',
          codec : Union[str, Literal['auto']] = 'auto',
          file_ext : str = "mp4",
          file_pixel_format : Optional[str] = None,
@@ -64,6 +283,7 @@ class VideoStorage(EpisodeStorageBase[
          return VideoStorage(
              single_instance_space,
              cache_filename=cache_path,
+             seek_mode=seek_mode,
              hardware_acceleration=hardware_acceleration,
              codec=codec,
              file_ext=file_ext,
@@ -79,7 +299,8 @@ class VideoStorage(EpisodeStorageBase[
          path : Union[str, os.PathLike],
          single_instance_space : BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
          *,
-         hardware_acceleration : Optional[Union[HWAccel, Literal['auto']]] = 'auto',
+         seek_mode : Literal['exact', 'approximate'] = 'exact',
+         hardware_acceleration : Optional[Union[Any, Literal['auto']]] = 'auto',
          codec : Union[str, Literal['auto']] = 'auto',
          capacity : Optional[int] = None,
          read_only : bool = True,
@@ -113,6 +334,7 @@ class VideoStorage(EpisodeStorageBase[
          return VideoStorage(
              single_instance_space,
              cache_filename=path,
+             seek_mode=seek_mode,
              hardware_acceleration=hardware_acceleration,
              codec=codec,
              file_ext=file_ext,
@@ -127,31 +349,20 @@ class VideoStorage(EpisodeStorageBase[
      # ========== Instance Implementations ==========
      single_file_ext = None

-     @staticmethod
-     def get_auto_hwaccel() -> Optional[HWAccel]:
-         if hasattr(__class__, "_auto_hwaccel"):
-             return __class__._auto_hwaccel
-         available_hwdevices = hwdevices_available()
-         target_hwaccel = None
-         if "d3d11va" in available_hwdevices:
-             target_hwaccel = HWAccel(device_type="d3d11va", allow_software_fallback=True)
-         elif "cuda" in available_hwdevices:
-             target_hwaccel = HWAccel(device_type="cuda", allow_software_fallback=True)
-         elif "vaapi" in available_hwdevices:
-             target_hwaccel = HWAccel(device_type="vaapi", allow_software_fallback=True)
-         elif "videotoolbox" in available_hwdevices:
-             target_hwaccel = HWAccel(device_type="videotoolbox", allow_software_fallback=True)
-         __class__._auto_hwaccel = target_hwaccel
-         return target_hwaccel
-
      @staticmethod
      def get_auto_codec(
          base : Optional[str] = None
      ) -> str:
          if hasattr(__class__, "_auto_codec"):
              return __class__._auto_codec
+
+         # --- PyAV Implementation ---
          preferred_codecs = ["av1", "hevc", "h264", "mpeg4", "vp9", "vp8"] if base is None else [base]
-         preferred_suffixes = ["_nvenc", "_amf", "_qsv"]
+         preferred_suffixes = ["_nvenc", "_vaapi", "_amf", "_qsv", "_videotoolbox"]
+
+         # --- TorchCodec Implementation ---
+         # preferred_codecs = ["hevc", "av1", "h264", "mpeg4", "vp9", "vp8"] if base is None else [base]
+         # preferred_suffixes = ["_nvenc"]

          target_codec = None
          for codec in preferred_codecs:
@@ -177,8 +388,10 @@ class VideoStorage(EpisodeStorageBase[
          self,
          single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
          cache_filename : Union[str, os.PathLike],
-         hardware_acceleration : Optional[Union[HWAccel, Literal['auto']]] = 'auto',
+         seek_mode : Literal['exact', 'approximate'] = 'exact',
+         hardware_acceleration : Optional[Union[Any, Literal['auto']]] = 'auto',
          codec : Union[str, Literal['auto']] = 'auto',
+         decode_backend : Literal['torchcodec', 'pyav', 'auto'] = 'auto',
          file_ext : str = "mp4",
          file_pixel_format : Optional[str] = None,
          buffer_pixel_format : str = "rgb24",
@@ -195,87 +408,112 @@ class VideoStorage(EpisodeStorageBase[
              capacity=capacity,
              length=length,
          )
-         self.hwaccel = None if hardware_acceleration is None else (
-             self.get_auto_hwaccel() if hardware_acceleration == 'auto' else hardware_acceleration
-         )
+         self.seek_mode = seek_mode
+         self.hwaccel = hardware_acceleration
          self.codec = self.get_auto_codec() if codec == 'auto' else codec
+
+         if decode_backend == 'auto':
+             if importlib.util.find_spec("torchcodec"):
+                 decode_backend = 'torchcodec'
+             else:
+                 decode_backend = 'pyav'
+         self.decode_backend = decode_backend
+
          self.fps = fps
          self.file_pixel_format = file_pixel_format
          self.buffer_pixel_format = buffer_pixel_format
-
+
      def get_from_file(self, filename : str, index : Union[IndexableType, BArrayType], total_length : int) -> BArrayType:
-         with iio.imopen(filename, 'r', plugin='pyav', hwaccel=self.hwaccel) as video:
-             video = cast(PyAVPlugin, video)
-             if isinstance(index, int):
-                 frame_np = video.read(index=index, format=self.buffer_pixel_format)
-                 frame = self.backend.from_numpy(frame_np)
-                 if self.device is not None:
-                     frame = self.backend.to_device(frame, self.device)
-                 return frame
-             else:
-                 if index is Ellipsis:
-                     index = np.arange(total_length)
-                 elif isinstance(index, slice):
-                     index = np.arange(*index.indices(total_length))
-                 elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
-                     index = self.backend.nonzero(index)[0]
-                 if self.backend.is_backendarray(index):
-                     index = self.backend.to_numpy(index)
-
-                 argsorted_indices = np.argsort(index)
-                 sorted_index = index[argsorted_indices]
-                 reserve_index = np.argsort(argsorted_indices)
-
-                 if len(index) < total_length // 2:
-                     all_frames_np = []
-                     for frame_i in sorted_index:
-                         frame_np = video.read(index=frame_i, format=self.buffer_pixel_format)
-                         all_frames_np.append(frame_np)
-                     all_frames_np = np.stack(all_frames_np, axis=0)
-                     # Reorder from sorted order back to original order
-                     all_frames_np = all_frames_np[reserve_index]
-                 else:
-                     # Create a set for O(1) lookup and a mapping from frame index to position in sorted_index
-                     sorted_index_set = set(sorted_index)
-                     frame_to_sorted_pos = {int(frame_idx): pos for pos, frame_idx in enumerate(sorted_index)}
-
-                     # Pre-allocate array to store frames in sorted order
-                     all_frames_list = [None] * len(sorted_index)
-                     past_frame_np = None
-                     video_iter = video.iter(format=self.buffer_pixel_format)
-                     for frame_i in range(total_length):
-                         try:
-                             frame_np = next(video_iter)
-                         except StopIteration:
-                             frame_np = past_frame_np
-                         if frame_i in sorted_index_set:
-                             # Store at the position corresponding to sorted_index order
-                             all_frames_list[frame_to_sorted_pos[frame_i]] = frame_np
-                         past_frame_np = frame_np
-                     all_frames_np = np.stack(all_frames_list, axis=0)
-                     # Reorder from sorted order back to original order
-                     all_frames_np = all_frames_np[reserve_index]
-                 all_frames = self.backend.from_numpy(all_frames_np)
-                 if self.device is not None:
-                     all_frames = self.backend.to_device(all_frames, self.device)
-                 return all_frames
+         if self.decode_backend == 'pyav':
+             reader_cls = PyAvVideoReader
+         elif self.decode_backend == 'torchcodec':
+             reader_cls = TorchCodecVideoReader
+         else:
+             raise ValueError(f"Unknown decode_backend {self.decode_backend}")
+         with reader_cls(
+             backend=self.backend,
+             filename=filename,
+             buffer_pixel_format=self.buffer_pixel_format,
+             hwaccel=self.hwaccel,
+             seek_mode=self.seek_mode,
+             device=self.device,
+         ) as video_reader:
+             return video_reader.read(index, total_length)

+     # PyAV Implementation (Commented out due to bugs with hevc codec)
      def set_to_file(self, filename : str, value : BArrayType):
-         value_np = self.backend.to_numpy(value)
-         with iio.imopen(
-             filename,
-             'w',
-             plugin='pyav',
-         ) as video:
-             video = cast(PyAVPlugin, video)
-             video.init_video_stream(self.codec, fps=self.fps, pixel_format=self.file_pixel_format)
+         # ImageIO Implementation
+         # with iio.imopen(
+         # filename,
+         # 'w',
+         # plugin='pyav',
+         # ) as video:
+         # video = cast(PyAVPlugin, video)
+         # video.init_video_stream(self.codec, fps=self.fps, pixel_format=self.file_pixel_format)
+
+         # # Fix codec time base if not set:
+         # if video._video_stream.codec_context.time_base is None:
+         # video._video_stream.codec_context.time_base = Fraction(1 / self.fps).limit_denominator(int(2**16 - 1))
+
+         # for i, frame in enumerate(value_np):
+         # video.write_frame(frame, pixel_format=self.buffer_pixel_format)
+
+         # PyAV Implementation
+         logging.getLogger("libav").setLevel(self.PyAV_LOG_LEVEL)
+         with av.open(filename, mode='w') as container:
+             output_stream = container.add_stream(self.codec, rate=self.fps)
+             if len(self.single_instance_space.shape) == 3: # (H, W, C)
+                 output_stream.width = self.single_instance_space.shape[1]
+                 output_stream.height = self.single_instance_space.shape[0]
+             else: # (H, W)
+                 output_stream.width = self.single_instance_space.shape[1]
+                 output_stream.height = self.single_instance_space.shape[0]
+             if self.file_pixel_format is not None:
+                 output_stream.pix_fmt = self.file_pixel_format
+             # output_stream.time_base = Fraction(1, self.fps).limit_denominator(int(2**16 - 1))
+             value_np = self.backend.to_numpy(value)
+             for i, frame_np in enumerate(value_np):
+                 frame = av.VideoFrame.from_ndarray(frame_np, format=self.buffer_pixel_format, channel_last=True)
+                 packets = output_stream.encode(frame)
+                 if packets:
+                     container.mux(packets)
+             # Flush stream
+             packets = output_stream.encode()
+             if packets:
+                 container.mux(packets)

-             # Fix codec time base if not set:
-             if video._video_stream.codec_context.time_base is None:
-                 video._video_stream.codec_context.time_base = Fraction(1 / self.fps).limit_denominator(int(2**16 - 1))
+             # Close container
+             if hasattr(output_stream, 'close'):
+                 output_stream.close()
+             # container.close()
+
+             # Restore logging level
+             av.logging.restore_default_callback()

-             for i, frame in enumerate(value_np):
-                 video.write_frame(frame, pixel_format=self.buffer_pixel_format)
+     # TorchCodec Implementation (extremely slow, and seems to have double-write issues)
+     # def set_to_file(self, filename : str, value : BArrayType):
+     # if self.backend.simplified_name != 'pytorch':
+     # value_pt = PyTorchComputeBackend.from_other_backend(self.backend, value)
+     # else:
+     # value_pt = value
+     # if self.codec.endswith('_nvenc'):
+     # value_pt = value_pt.to("cuda")
+     # else:
+     # value_pt = value_pt.to("cpu")
+
+     # if len(self.single_instance_space.shape) < 3:
+     # # Add channel dimension for grayscale images
+     # value_pt = value_pt[..., None] # (N, H, W) -> (N, H, W, 1)
+
+     # encoder = VideoEncoder(
+     # value_pt.permute(0, 3, 1, 2), # (N, H, W, C) -> (N, C, H, W)
+     # frame_rate=self.fps,
+     # )
+     # encoder.to_file(
+     # filename,
+     # codec=self.codec,
+     # pixel_format=self.file_pixel_format,
+     # )

      def dumps(self, path):
          assert os.path.samefile(path, self.cache_filename), \
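
Note (not part of the diff): VideoStorage now takes a decode_backend argument and, when left at 'auto', prefers torchcodec if it is importable and falls back to PyAV. A minimal sketch of that selection pattern; pick_decode_backend is a hypothetical helper that mirrors the importlib.util.find_spec check in __init__:

    import importlib.util

    def pick_decode_backend(requested: str = 'auto') -> str:
        # Mirrors the 'auto' branch in VideoStorage.__init__ (illustrative only).
        if requested != 'auto':
            return requested
        return 'torchcodec' if importlib.util.find_spec("torchcodec") else 'pyav'

    print(pick_decode_backend())  # 'torchcodec' if installed, else 'pyav'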
@@ -1,7 +1,7 @@
  try:
-     from xbarray.backends.pytorch import PytorchComputeBackend as XBPytorchBackend
+     from xbarray.backends.pytorch import PyTorchComputeBackend as XBPytorchBackend
  except ImportError:
-     from xbarray.pytorch import PytorchComputeBackend as XBPytorchBackend
+     from xbarray.pytorch import PyTorchComputeBackend as XBPytorchBackend
  from xbarray import ComputeBackend

  from typing import Union