flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff compares publicly available package versions that were released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/sources/base.py (new file)
@@ -0,0 +1,129 @@
from abc import ABC, abstractmethod
import grain.python as pygrain
from typing import Dict, Any, Callable, List, Optional
import jax.numpy as jnp
from functools import partial


class DataSource(ABC):
    """Base class for all data sources in FlaxDiff."""

    @abstractmethod
    def get_source(self, path_override: str) -> Any:
        """Return the data source object.

        Args:
            path_override: Path to the dataset, overriding the default.

        Returns:
            A data source object compatible with grain or other loaders.
        """
        pass

    @staticmethod
    def create(source_type: str, **kwargs) -> 'DataSource':
        """Factory method to create a data source of the specified type.

        Args:
            source_type: Type of the data source ("image", "video", etc.)
            **kwargs: Additional arguments for the specific data source.

        Returns:
            An instance of a DataSource subclass.
        """
        from .images import ImageTFDSSource, ImageGCSSource, CombinedImageGCSSource
        from .videos import VideoTFDSSource, VideoLocalSource

        source_map = {
            "image_tfds": ImageTFDSSource,
            "image_gcs": ImageGCSSource,
            "image_combined_gcs": CombinedImageGCSSource,
            "video_tfds": VideoTFDSSource,
            "video_local": VideoLocalSource
        }

        if source_type not in source_map:
            raise ValueError(f"Unknown source type: {source_type}")
        return source_map[source_type](**kwargs)


class DataAugmenter(ABC):
    """Base class for all data augmenters in FlaxDiff."""

    @abstractmethod
    def create_transform(self, **kwargs) -> Callable[[], pygrain.MapTransform]:
        """Create a transformation function for the data.

        Args:
            **kwargs: Additional arguments for the transformation.

        Returns:
            A callable that returns a pygrain.MapTransform instance.
        """
        pass

    @staticmethod
    def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
        """Factory method to create a data augmenter of the specified type.

        Args:
            augmenter_type: Type of the data augmenter ("image", "video", etc.)
            **kwargs: Additional arguments for the specific augmenter.

        Returns:
            An instance of a DataAugmenter subclass.
        """
        from .images import ImageTFDSAugmenter, ImageGCSAugmenter
        from .videos import VideoAugmenter

        augmenter_map = {
            "image_tfds": ImageTFDSAugmenter,
            "image_gcs": ImageGCSAugmenter,
            "video": VideoAugmenter
        }

        if augmenter_type not in augmenter_map:
            raise ValueError(f"Unknown augmenter type: {augmenter_type}")

        return augmenter_map[augmenter_type](**kwargs)


class MediaDataset:
    """A class combining a data source and an augmenter for a complete dataset."""

    def __init__(self,
                 source: DataSource,
                 augmenter: DataAugmenter,
                 media_type: str = "image"):
        """Initialize a MediaDataset.

        Args:
            source: The data source.
            augmenter: The data augmenter.
            media_type: Type of media ("image", "video", etc.)
        """
        self.source = source
        self.augmenter = augmenter
        self.media_type = media_type

    def get_source(self, path_override: str) -> Any:
        """Get the data source.

        Args:
            path_override: Path to override the default data source path.

        Returns:
            A data source object.
        """
        return self.source.get_source(path_override)

    def get_augmenter(self, **kwargs) -> Callable[[], pygrain.MapTransform]:
        """Get the augmenter transformation.

        Args:
            **kwargs: Additional arguments for the augmenter.

        Returns:
            A callable that returns a pygrain.MapTransform instance.
        """
        return self.augmenter.create_transform(**kwargs)
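
As a usage sketch (not part of the diff), the factory methods above could be combined into a MediaDataset roughly as follows; the dataset name and mount path are illustrative placeholders.

# Sketch: wiring a source and augmenter into a MediaDataset via the factories above.
from flaxdiff.data.sources.base import DataSource, DataAugmenter, MediaDataset

source = DataSource.create("image_gcs", source="arrayrecord/laion-aesthetics-12m+mscoco-2017")
augmenter = DataAugmenter.create("image_gcs")
dataset = MediaDataset(source=source, augmenter=augmenter, media_type="image")

records = dataset.get_source(path_override="/path/to/gcs_mount")  # grain ArrayRecordDataSource
transform_cls = dataset.get_augmenter(image_scale=256)            # pygrain.MapTransform subclass
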
flaxdiff/data/sources/images.py (new file)
@@ -0,0 +1,309 @@
import cv2
import jax.numpy as jnp
import grain.python as pygrain
from flaxdiff.utils import AutoTextTokenizer
from typing import Dict, Any, Callable, List, Optional
import random
import augmax
import jax
import os
import struct as st
from functools import partial
import numpy as np
from .base import DataSource, DataAugmenter


# ----------------------------------------------------------------------------------
# Utility functions
# ----------------------------------------------------------------------------------

def unpack_dict_of_byte_arrays(packed_data):
    """Unpacks a dictionary of byte arrays from a packed binary format."""
    unpacked_dict = {}
    offset = 0
    while offset < len(packed_data):
        # Unpack the key length
        key_length = st.unpack_from('I', packed_data, offset)[0]
        offset += st.calcsize('I')
        # Unpack the key bytes and convert to string
        key = packed_data[offset:offset+key_length].decode('utf-8')
        offset += key_length
        # Unpack the byte array length
        byte_array_length = st.unpack_from('I', packed_data, offset)[0]
        offset += st.calcsize('I')
        # Unpack the byte array
        byte_array = packed_data[offset:offset+byte_array_length]
        offset += byte_array_length
        unpacked_dict[key] = byte_array
    return unpacked_dict

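For reference (not part of the diff), a sketch of the inverse packer implied by the format above: each entry is a uint32 key length, the UTF-8 key bytes, a uint32 value length, and the raw value bytes.

# Hypothetical inverse of unpack_dict_of_byte_arrays, shown only to illustrate the record layout.
import struct as st

def pack_dict_of_byte_arrays(data: dict) -> bytes:
    packed = bytearray()
    for key, value in data.items():
        key_bytes = key.encode('utf-8')
        packed += st.pack('I', len(key_bytes)) + key_bytes   # key length + key
        packed += st.pack('I', len(value)) + bytes(value)    # value length + value bytes
    return bytes(packed)

# Round trip: unpack_dict_of_byte_arrays(pack_dict_of_byte_arrays({"txt": b"a cat"}))
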
# ----------------------------------------------------------------------------------
# Image augmentation utilities
# ----------------------------------------------------------------------------------

def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
    """Basic image augmentation: convert color and resize."""
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (image_scale, image_scale),
                       interpolation=method)
    return image


PROMPT_TEMPLATES = [
    "a photo of a {}",
    "a photo of a {} flower",
    "This is a photo of a {}",
    "This is a photo of a {} flower",
    "A photo of a {} flower",
]


def labelizer_oxford_flowers102(path):
    """Creates a label generator for Oxford Flowers 102 dataset."""
    with open(path, "r") as f:
        textlabels = [i.strip() for i in f.readlines()]

    def load_labels(sample):
        raw = textlabels[int(sample['label'])]
        # randomly select a prompt template
        template = random.choice(PROMPT_TEMPLATES)
        # format the template with the label
        caption = template.format(raw)
        # return the caption
        return caption
    return load_labels


# ----------------------------------------------------------------------------------
# TFDS Image Source
# ----------------------------------------------------------------------------------

class ImageTFDSSource(DataSource):
    """Data source for TensorFlow Datasets (TFDS) image datasets."""

    def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
        """Initialize a TFDS image data source.

        Args:
            name: Name of the TFDS dataset.
            use_tf: Whether to use TensorFlow for loading.
            split: Dataset split to use.
        """
        self.name = name
        self.use_tf = use_tf
        self.split = split

    def get_source(self, path_override: str) -> Any:
        """Get the TFDS data source.

        Args:
            path_override: Override path for the dataset.

        Returns:
            A TFDS dataset.
        """
        import tensorflow_datasets as tfds
        if self.use_tf:
            return tfds.load(self.name, split=self.split, shuffle_files=True)
        else:
            return tfds.data_source(self.name, split=self.split, try_gcs=False)


class ImageTFDSAugmenter(DataAugmenter):
    """Augmenter for TFDS image datasets."""

    def __init__(self, label_path: str = "/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"):
        """Initialize a TFDS image augmenter.

        Args:
            label_path: Path to the labels file for datasets like Oxford Flowers.
        """
        self.label_path = label_path

    def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
        """Create a transform for TFDS image datasets.

        Args:
            image_scale: Size to scale images to.
            method: Interpolation method for resizing.

        Returns:
            A callable that returns a pygrain.MapTransform.
        """
        labelizer = labelizer_oxford_flowers102(self.label_path)

        if image_scale > 256:
            interpolation = cv2.INTER_CUBIC
        else:
            interpolation = cv2.INTER_AREA

        from torchvision.transforms import v2
        augments = v2.Compose([
            v2.RandomHorizontalFlip(p=0.5),
            v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
        ])

        class TFDSTransform(pygrain.MapTransform):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.tokenize = AutoTextTokenizer(tensor_type="np")

            def map(self, element) -> Dict[str, jnp.array]:
                image = element['image']
                image = cv2.resize(image, (image_scale, image_scale),
                                   interpolation=interpolation)
                image = augments(image)

                caption = labelizer(element)
                results = self.tokenize(caption)
                return {
                    "image": image,
                    "text": {
                        "input_ids": results['input_ids'][0],
                        "attention_mask": results['attention_mask'][0],
                    }
                }

        return TFDSTransform


# ----------------------------------------------------------------------------------
# GCS Image Source
# ----------------------------------------------------------------------------------

class ImageGCSSource(DataSource):
    """Data source for Google Cloud Storage (GCS) image datasets."""

    def __init__(self, source: str = 'arrayrecord/laion-aesthetics-12m+mscoco-2017'):
        """Initialize a GCS image data source.

        Args:
            source: Path to the GCS dataset.
        """
        self.source = source

    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
        """Get the GCS data source.

        Args:
            path_override: Base path for GCS mounts.

        Returns:
            A grain ArrayRecordDataSource.
        """
        records_path = os.path.join(path_override, self.source)
        records = [os.path.join(records_path, i) for i in os.listdir(
            records_path) if 'array_record' in i]
        return pygrain.ArrayRecordDataSource(records)


class CombinedImageGCSSource(DataSource):
    """Data source that combines multiple GCS image datasets."""

    def __init__(self, sources: List[str] = []):
        """Initialize a combined GCS image data source.

        Args:
            sources: List of paths to GCS datasets.
        """
        self.sources = sources

    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
        """Get the combined GCS data source.

        Args:
            path_override: Base path for GCS mounts.

        Returns:
            A grain ArrayRecordDataSource.
        """
        records_paths = [os.path.join(path_override, source) for source in self.sources]
        records = []
        for records_path in records_paths:
            records += [os.path.join(records_path, i) for i in os.listdir(
                records_path) if 'array_record' in i]
        return pygrain.ArrayRecordDataSource(records)


class ImageGCSAugmenter(DataAugmenter):
    """Augmenter for GCS image datasets."""

    def __init__(self, labelizer: Callable = None):
        """Initialize a GCS image augmenter.

        Args:
            labelizer: Function to extract text labels from samples.
        """
        self.labelizer = labelizer or (lambda sample: sample['txt'])

    def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
        """Create a transform for GCS image datasets.

        Args:
            image_scale: Size to scale images to.
            method: Interpolation method for resizing.

        Returns:
            A callable that returns a pygrain.MapTransform.
        """
        labelizer = self.labelizer

        class GCSTransform(pygrain.MapTransform):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
                self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)

            def map(self, element) -> Dict[str, jnp.array]:
                element = unpack_dict_of_byte_arrays(element)
                image = np.asarray(bytearray(element['jpg']), dtype="uint8")
                image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
                image = self.image_augmenter(image)
                caption = labelizer(element).decode('utf-8')
                results = self.auto_tokenize(caption)
                return {
                    "image": image,
                    "text": {
                        "input_ids": results['input_ids'][0],
                        "attention_mask": results['attention_mask'][0],
                    }
                }

        return GCSTransform


# ----------------------------------------------------------------------------------
# Legacy compatibility functions
# ----------------------------------------------------------------------------------

# These functions maintain backward compatibility with existing code

def data_source_tfds(name, use_tf=True, split="all"):
    """Legacy function for TFDS data sources."""
    source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
    return source.get_source


def tfds_augmenters(image_scale, method):
    """Legacy function for TFDS augmenters."""
    augmenter = ImageTFDSAugmenter()
    return augmenter.create_transform(image_scale=image_scale, method=method)


def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
    """Legacy function for GCS data sources."""
    source_obj = ImageGCSSource(source=source)
    return source_obj.get_source


def data_source_combined_gcs(sources=[]):
    """Legacy function for combined GCS data sources."""
    source_obj = CombinedImageGCSSource(sources=sources)
    return source_obj.get_source


def gcs_augmenters(image_scale, method):
    """Legacy function for GCS augmenters."""
    augmenter = ImageGCSAugmenter()
    return augmenter.create_transform(image_scale=image_scale, method=method)
flaxdiff/data/sources/utils.py (new file)
@@ -0,0 +1,158 @@
import numpy as np
from decord.video_reader import VideoReader
from decord.audio_reader import AudioReader

from decord.ndarray import cpu
from decord import ndarray as _nd
from decord.bridge import bridge_out

class AVReader(object):
    """Individual audio video reader with convenient indexing function.

    Parameters
    ----------
    uri: str
        Path of file.
    ctx: decord.Context
        The context to decode the file, can be decord.cpu() or decord.gpu().
    sample_rate: int, default is -1
        Desired output sample rate of the audio, unchanged if `-1` is specified.
    mono: bool, default is True
        Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
    width : int, default is -1
        Desired output width of the video, unchanged if `-1` is specified.
    height : int, default is -1
        Desired output height of the video, unchanged if `-1` is specified.
    num_threads : int, default is 0
        Number of decoding thread, auto if `0` is specified.
    fault_tol : int, default is -1
        The threshold of corupted and recovered frames. This is to prevent silent fault
        tolerance when for example 50% frames of a video cannot be decoded and duplicate
        frames are returned. You may find the fault tolerant feature sweet in many cases,
        but not for training models. Say `N = # recovered frames`
        If `fault_tol` < 0, nothing will happen.
        If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
        If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
    """

    def __init__(
        self, uri, ctx=cpu(0), sample_rate=-1, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
    ):
        self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
        self.__audio_reader.add_padding()
        if hasattr(uri, "read"):
            uri.seek(0)
        self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
        self.__video_reader.seek(0)

    def __del__(self):
        del self.__video_reader
        del self.__audio_reader

    def __len__(self):
        """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
        we always follow what FFMPEG reports.
        Returns
        -------
        int
            The number of frames in the video file.
        """
        return len(self.__video_reader)

    def __getitem__(self, idx):
        """Get audio samples and video frame at `idx`.

        Parameters
        ----------
        idx : int or slice
            The frame index, can be negative which means it will index backwards,
            or slice of frame indices.

        Returns
        -------
        (ndarray/list of ndarray, ndarray)
            First element is samples of shape CxS or a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.

            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        if isinstance(idx, slice):
            return self.get_batch(range(*idx.indices(len(self.__video_reader))))
        if idx < 0:
            idx += len(self.__video_reader)
        if idx >= len(self.__video_reader) or idx < 0:
            raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
        audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
        audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
        audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
        results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
        self.__video_reader.seek(0)
        return results

    def get_batch(self, indices):
        """Get entire batch of audio samples and video frames.

        Parameters
        ----------
        indices : list of integers
            A list of frame indices. If negative indices detected, the indices will be indexed from backward
        Returns
        -------
        (list of ndarray, ndarray)
            First element is a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.

            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.

        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        indices = self._validate_indices(indices)
        audio_arr = []
        prev_video_idx = None
        prev_audio_end_idx = None
        for idx in list(indices):
            frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
            # timestamp and sample conversion could have some error that could cause non-continuous audio
            # we detect if retrieving continuous frame and make the audio continuous
            if prev_video_idx and idx == prev_video_idx + 1:
                audio_start_idx = prev_audio_end_idx
            else:
                audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
            audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
            audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
            prev_video_idx = idx
            prev_audio_end_idx = audio_end_idx
        results = (audio_arr, self.__video_reader.get_batch(indices))
        self.__video_reader.seek(0)
        return results

    def _get_slice(self, sl):
        audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
        for idx in list(sl):
            audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
            audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
            audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
            audio_arr = np.concatenate(
                (audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
            )
        results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
        self.__video_reader.seek(0)
        return results

    def _validate_indices(self, indices):
        """Validate int64 integers and convert negative integers to positive by backward search"""
        assert self.__video_reader is not None and self.__audio_reader is not None
        indices = np.array(indices, dtype=np.int64)
        # process negative indices
        indices[indices < 0] += len(self.__video_reader)
        if not (indices >= 0).all():
            raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
        if not (indices < len(self.__video_reader)).all():
            raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
        return indices