flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/sources/base.py (new file)
@@ -0,0 +1,129 @@
+ from abc import ABC, abstractmethod
+ import grain.python as pygrain
+ from typing import Dict, Any, Callable, List, Optional
+ import jax.numpy as jnp
+ from functools import partial
+
+
+ class DataSource(ABC):
+     """Base class for all data sources in FlaxDiff."""
+
+     @abstractmethod
+     def get_source(self, path_override: str) -> Any:
+         """Return the data source object.
+
+         Args:
+             path_override: Path to the dataset, overriding the default.
+
+         Returns:
+             A data source object compatible with grain or other loaders.
+         """
+         pass
+
+     @staticmethod
+     def create(source_type: str, **kwargs) -> 'DataSource':
+         """Factory method to create a data source of the specified type.
+
+         Args:
+             source_type: Type of the data source ("image", "video", etc.)
+             **kwargs: Additional arguments for the specific data source.
+
+         Returns:
+             An instance of a DataSource subclass.
+         """
+         from .images import ImageTFDSSource, ImageGCSSource, CombinedImageGCSSource
+         from .videos import VideoTFDSSource, VideoLocalSource
+
+         source_map = {
+             "image_tfds": ImageTFDSSource,
+             "image_gcs": ImageGCSSource,
+             "image_combined_gcs": CombinedImageGCSSource,
+             "video_tfds": VideoTFDSSource,
+             "video_local": VideoLocalSource
+         }
+
+         if source_type not in source_map:
+             raise ValueError(f"Unknown source type: {source_type}")
+         return source_map[source_type](**kwargs)
+
+
+ class DataAugmenter(ABC):
+     """Base class for all data augmenters in FlaxDiff."""
+
+     @abstractmethod
+     def create_transform(self, **kwargs) -> Callable[[], pygrain.MapTransform]:
+         """Create a transformation function for the data.
+
+         Args:
+             **kwargs: Additional arguments for the transformation.
+
+         Returns:
+             A callable that returns a pygrain.MapTransform instance.
+         """
+         pass
+
+     @staticmethod
+     def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
+         """Factory method to create a data augmenter of the specified type.
+
+         Args:
+             augmenter_type: Type of the data augmenter ("image", "video", etc.)
+             **kwargs: Additional arguments for the specific augmenter.
+
+         Returns:
+             An instance of a DataAugmenter subclass.
+         """
+         from .images import ImageTFDSAugmenter, ImageGCSAugmenter
+         from .videos import VideoAugmenter
+
+         augmenter_map = {
+             "image_tfds": ImageTFDSAugmenter,
+             "image_gcs": ImageGCSAugmenter,
+             "video": VideoAugmenter
+         }
+
+         if augmenter_type not in augmenter_map:
+             raise ValueError(f"Unknown augmenter type: {augmenter_type}")
+
+         return augmenter_map[augmenter_type](**kwargs)
+
+
+ class MediaDataset:
+     """A class combining a data source and an augmenter for a complete dataset."""
+
+     def __init__(self,
+                  source: DataSource,
+                  augmenter: DataAugmenter,
+                  media_type: str = "image"):
+         """Initialize a MediaDataset.
+
+         Args:
+             source: The data source.
+             augmenter: The data augmenter.
+             media_type: Type of media ("image", "video", etc.)
+         """
+         self.source = source
+         self.augmenter = augmenter
+         self.media_type = media_type
+
+     def get_source(self, path_override: str) -> Any:
+         """Get the data source.
+
+         Args:
+             path_override: Path to override the default data source path.
+
+         Returns:
+             A data source object.
+         """
+         return self.source.get_source(path_override)
+
+     def get_augmenter(self, **kwargs) -> Callable[[], pygrain.MapTransform]:
+         """Get the augmenter transformation.
+
+         Args:
+             **kwargs: Additional arguments for the augmenter.
+
+         Returns:
+             A callable that returns a pygrain.MapTransform instance.
+         """
+         return self.augmenter.create_transform(**kwargs)
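The new base module boils down to a small factory pattern: DataSource.create and DataAugmenter.create map the string keys registered above to concrete classes, and MediaDataset pairs one source with one augmenter. A minimal usage sketch follows (not part of the diff; the dataset name and override path are illustrative):

# Sketch only: wiring a source and an augmenter through the new factories.
# "oxford_flowers102" and the override path are illustrative values.
from flaxdiff.data.sources.base import DataSource, DataAugmenter, MediaDataset

source = DataSource.create("image_tfds", name="oxford_flowers102", use_tf=False)
augmenter = DataAugmenter.create("image_tfds")
dataset = MediaDataset(source=source, augmenter=augmenter, media_type="image")

grain_source = dataset.get_source(path_override="/data/tfds")  # underlying TFDS data source
TransformCls = dataset.get_augmenter(image_scale=256)          # a pygrain.MapTransform subclass
transform = TransformCls()                                      # instantiate for use in a grain pipeline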
flaxdiff/data/sources/images.py (new file)
@@ -0,0 +1,309 @@
+ import cv2
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from flaxdiff.utils import AutoTextTokenizer
+ from typing import Dict, Any, Callable, List, Optional
+ import random
+ import augmax
+ import jax
+ import os
+ import struct as st
+ from functools import partial
+ import numpy as np
+ from .base import DataSource, DataAugmenter
+
+
+ # ----------------------------------------------------------------------------------
+ # Utility functions
+ # ----------------------------------------------------------------------------------
+
+ def unpack_dict_of_byte_arrays(packed_data):
+     """Unpacks a dictionary of byte arrays from a packed binary format."""
+     unpacked_dict = {}
+     offset = 0
+     while offset < len(packed_data):
+         # Unpack the key length
+         key_length = st.unpack_from('I', packed_data, offset)[0]
+         offset += st.calcsize('I')
+         # Unpack the key bytes and convert to string
+         key = packed_data[offset:offset+key_length].decode('utf-8')
+         offset += key_length
+         # Unpack the byte array length
+         byte_array_length = st.unpack_from('I', packed_data, offset)[0]
+         offset += st.calcsize('I')
+         # Unpack the byte array
+         byte_array = packed_data[offset:offset+byte_array_length]
+         offset += byte_array_length
+         unpacked_dict[key] = byte_array
+     return unpacked_dict
+
+
+ # ----------------------------------------------------------------------------------
+ # Image augmentation utilities
+ # ----------------------------------------------------------------------------------
+
+ def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
+     """Basic image augmentation: convert color and resize."""
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = cv2.resize(image, (image_scale, image_scale),
+                        interpolation=method)
+     return image
+
+
+ PROMPT_TEMPLATES = [
+     "a photo of a {}",
+     "a photo of a {} flower",
+     "This is a photo of a {}",
+     "This is a photo of a {} flower",
+     "A photo of a {} flower",
+ ]
+
+
+ def labelizer_oxford_flowers102(path):
+     """Creates a label generator for Oxford Flowers 102 dataset."""
+     with open(path, "r") as f:
+         textlabels = [i.strip() for i in f.readlines()]
+
+     def load_labels(sample):
+         raw = textlabels[int(sample['label'])]
+         # randomly select a prompt template
+         template = random.choice(PROMPT_TEMPLATES)
+         # format the template with the label
+         caption = template.format(raw)
+         # return the caption
+         return caption
+     return load_labels
+
+
+ # ----------------------------------------------------------------------------------
+ # TFDS Image Source
+ # ----------------------------------------------------------------------------------
+
+ class ImageTFDSSource(DataSource):
+     """Data source for TensorFlow Datasets (TFDS) image datasets."""
+
+     def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
+         """Initialize a TFDS image data source.
+
+         Args:
+             name: Name of the TFDS dataset.
+             use_tf: Whether to use TensorFlow for loading.
+             split: Dataset split to use.
+         """
+         self.name = name
+         self.use_tf = use_tf
+         self.split = split
+
+     def get_source(self, path_override: str) -> Any:
+         """Get the TFDS data source.
+
+         Args:
+             path_override: Override path for the dataset.
+
+         Returns:
+             A TFDS dataset.
+         """
+         import tensorflow_datasets as tfds
+         if self.use_tf:
+             return tfds.load(self.name, split=self.split, shuffle_files=True)
+         else:
+             return tfds.data_source(self.name, split=self.split, try_gcs=False)
+
+
+ class ImageTFDSAugmenter(DataAugmenter):
+     """Augmenter for TFDS image datasets."""
+
+     def __init__(self, label_path: str = "/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"):
+         """Initialize a TFDS image augmenter.
+
+         Args:
+             label_path: Path to the labels file for datasets like Oxford Flowers.
+         """
+         self.label_path = label_path
+
+     def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
+         """Create a transform for TFDS image datasets.
+
+         Args:
+             image_scale: Size to scale images to.
+             method: Interpolation method for resizing.
+
+         Returns:
+             A callable that returns a pygrain.MapTransform.
+         """
+         labelizer = labelizer_oxford_flowers102(self.label_path)
+
+         if image_scale > 256:
+             interpolation = cv2.INTER_CUBIC
+         else:
+             interpolation = cv2.INTER_AREA
+
+         from torchvision.transforms import v2
+         augments = v2.Compose([
+             v2.RandomHorizontalFlip(p=0.5),
+             v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
+         ])
+
+         class TFDSTransform(pygrain.MapTransform):
+             def __init__(self, *args, **kwargs):
+                 super().__init__(*args, **kwargs)
+                 self.tokenize = AutoTextTokenizer(tensor_type="np")
+
+             def map(self, element) -> Dict[str, jnp.array]:
+                 image = element['image']
+                 image = cv2.resize(image, (image_scale, image_scale),
+                                    interpolation=interpolation)
+                 image = augments(image)
+
+                 caption = labelizer(element)
+                 results = self.tokenize(caption)
+                 return {
+                     "image": image,
+                     "text": {
+                         "input_ids": results['input_ids'][0],
+                         "attention_mask": results['attention_mask'][0],
+                     }
+                 }
+
+         return TFDSTransform
+
+
+ # ----------------------------------------------------------------------------------
+ # GCS Image Source
+ # ----------------------------------------------------------------------------------
+
+ class ImageGCSSource(DataSource):
+     """Data source for Google Cloud Storage (GCS) image datasets."""
+
+     def __init__(self, source: str = 'arrayrecord/laion-aesthetics-12m+mscoco-2017'):
+         """Initialize a GCS image data source.
+
+         Args:
+             source: Path to the GCS dataset.
+         """
+         self.source = source
+
+     def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
+         """Get the GCS data source.
+
+         Args:
+             path_override: Base path for GCS mounts.
+
+         Returns:
+             A grain ArrayRecordDataSource.
+         """
+         records_path = os.path.join(path_override, self.source)
+         records = [os.path.join(records_path, i) for i in os.listdir(
+             records_path) if 'array_record' in i]
+         return pygrain.ArrayRecordDataSource(records)
+
+
+ class CombinedImageGCSSource(DataSource):
+     """Data source that combines multiple GCS image datasets."""
+
+     def __init__(self, sources: List[str] = []):
+         """Initialize a combined GCS image data source.
+
+         Args:
+             sources: List of paths to GCS datasets.
+         """
+         self.sources = sources
+
+     def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
+         """Get the combined GCS data source.
+
+         Args:
+             path_override: Base path for GCS mounts.
+
+         Returns:
+             A grain ArrayRecordDataSource.
+         """
+         records_paths = [os.path.join(path_override, source) for source in self.sources]
+         records = []
+         for records_path in records_paths:
+             records += [os.path.join(records_path, i) for i in os.listdir(
+                 records_path) if 'array_record' in i]
+         return pygrain.ArrayRecordDataSource(records)
+
+
+ class ImageGCSAugmenter(DataAugmenter):
+     """Augmenter for GCS image datasets."""
+
+     def __init__(self, labelizer: Callable = None):
+         """Initialize a GCS image augmenter.
+
+         Args:
+             labelizer: Function to extract text labels from samples.
+         """
+         self.labelizer = labelizer or (lambda sample: sample['txt'])
+
+     def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
+         """Create a transform for GCS image datasets.
+
+         Args:
+             image_scale: Size to scale images to.
+             method: Interpolation method for resizing.
+
+         Returns:
+             A callable that returns a pygrain.MapTransform.
+         """
+         labelizer = self.labelizer
+
+         class GCSTransform(pygrain.MapTransform):
+             def __init__(self, *args, **kwargs):
+                 super().__init__(*args, **kwargs)
+                 self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
+                 self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)
+
+             def map(self, element) -> Dict[str, jnp.array]:
+                 element = unpack_dict_of_byte_arrays(element)
+                 image = np.asarray(bytearray(element['jpg']), dtype="uint8")
+                 image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
+                 image = self.image_augmenter(image)
+                 caption = labelizer(element).decode('utf-8')
+                 results = self.auto_tokenize(caption)
+                 return {
+                     "image": image,
+                     "text": {
+                         "input_ids": results['input_ids'][0],
+                         "attention_mask": results['attention_mask'][0],
+                     }
+                 }
+
+         return GCSTransform
+
+
+ # ----------------------------------------------------------------------------------
+ # Legacy compatibility functions
+ # ----------------------------------------------------------------------------------
+
+ # These functions maintain backward compatibility with existing code
+
+ def data_source_tfds(name, use_tf=True, split="all"):
+     """Legacy function for TFDS data sources."""
+     source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
+     return source.get_source
+
+
+ def tfds_augmenters(image_scale, method):
+     """Legacy function for TFDS augmenters."""
+     augmenter = ImageTFDSAugmenter()
+     return augmenter.create_transform(image_scale=image_scale, method=method)
+
+
+ def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
+     """Legacy function for GCS data sources."""
+     source_obj = ImageGCSSource(source=source)
+     return source_obj.get_source
+
+
+ def data_source_combined_gcs(sources=[]):
+     """Legacy function for combined GCS data sources."""
+     source_obj = CombinedImageGCSSource(sources=sources)
+     return source_obj.get_source
+
+
+ def gcs_augmenters(image_scale, method):
+     """Legacy function for GCS augmenters."""
+     augmenter = ImageGCSAugmenter()
+     return augmenter.create_transform(image_scale=image_scale, method=method)
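The GCS pieces are designed to slot straight into a grain pipeline: ImageGCSSource.get_source yields a pygrain.ArrayRecordDataSource, and ImageGCSAugmenter.create_transform yields a MapTransform class that unpacks each record, decodes and resizes the JPEG, and tokenizes the caption. A rough sketch of that wiring follows (not part of the diff; the mount path and batch size are illustrative, and the IndexSampler/Batch calls assume grain.python's standard loader API):

# Sketch only: building a grain DataLoader from the GCS source and augmenter.
import cv2
import grain.python as pygrain
from flaxdiff.data.sources.images import ImageGCSSource, ImageGCSAugmenter

source = ImageGCSSource().get_source(path_override="/mnt/gcs")  # illustrative mount point
TransformCls = ImageGCSAugmenter().create_transform(image_scale=256, method=cv2.INTER_AREA)

sampler = pygrain.IndexSampler(
    num_records=len(source), shuffle=True, seed=0,
    num_epochs=1, shard_options=pygrain.NoSharding())
loader = pygrain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[TransformCls(), pygrain.Batch(batch_size=16, drop_remainder=True)],
    worker_count=4)

for batch in loader:
    images = batch["image"]                 # batch of resized RGB images
    input_ids = batch["text"]["input_ids"]  # tokenized captions
    break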
flaxdiff/data/sources/utils.py (new file)
@@ -0,0 +1,158 @@
+
+ import numpy as np
+ from decord.video_reader import VideoReader
+ from decord.audio_reader import AudioReader
+
+ from decord.ndarray import cpu
+ from decord import ndarray as _nd
+ from decord.bridge import bridge_out
+
+ class AVReader(object):
+     """Individual audio video reader with convenient indexing function.
+
+     Parameters
+     ----------
+     uri: str
+         Path of file.
+     ctx: decord.Context
+         The context to decode the file, can be decord.cpu() or decord.gpu().
+     sample_rate: int, default is -1
+         Desired output sample rate of the audio, unchanged if `-1` is specified.
+     mono: bool, default is True
+         Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
+     width : int, default is -1
+         Desired output width of the video, unchanged if `-1` is specified.
+     height : int, default is -1
+         Desired output height of the video, unchanged if `-1` is specified.
+     num_threads : int, default is 0
+         Number of decoding threads, auto if `0` is specified.
+     fault_tol : int, default is -1
+         The threshold of corrupted and recovered frames. This is to prevent silent fault
+         tolerance when, for example, 50% of the frames of a video cannot be decoded and duplicate
+         frames are returned. You may find the fault tolerant feature sweet in many cases,
+         but not for training models. Say `N = # recovered frames`
+         If `fault_tol` < 0, nothing will happen.
+         If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
+         If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
+     """
+
+     def __init__(
+         self, uri, ctx=cpu(0), sample_rate=-1, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
+     ):
+         self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
+         self.__audio_reader.add_padding()
+         if hasattr(uri, "read"):
+             uri.seek(0)
+         self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
+         self.__video_reader.seek(0)
+
+     def __del__(self):
+         del self.__video_reader
+         del self.__audio_reader
+
+     def __len__(self):
+         """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
+         we always follow what FFMPEG reports.
+         Returns
+         -------
+         int
+             The number of frames in the video file.
+         """
+         return len(self.__video_reader)
+
+     def __getitem__(self, idx):
+         """Get audio samples and video frame at `idx`.
+
+         Parameters
+         ----------
+         idx : int or slice
+             The frame index, can be negative which means it will index backwards,
+             or slice of frame indices.
+
+         Returns
+         -------
+         (ndarray/list of ndarray, ndarray)
+             First element is samples of shape CxS or a list of length N containing samples of shape CxS,
+             where N is the number of frames, C is the number of channels,
+             S is the number of samples of the corresponding frame.
+
+             Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
+             where N is the length of the slice.
+         """
+         assert self.__video_reader is not None and self.__audio_reader is not None
+         if isinstance(idx, slice):
+             return self.get_batch(range(*idx.indices(len(self.__video_reader))))
+         if idx < 0:
+             idx += len(self.__video_reader)
+         if idx >= len(self.__video_reader) or idx < 0:
+             raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
+         audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
+         audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
+         audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
+         results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
+         self.__video_reader.seek(0)
+         return results
+
+     def get_batch(self, indices):
+         """Get entire batch of audio samples and video frames.
+
+         Parameters
+         ----------
+         indices : list of integers
+             A list of frame indices. If negative indices detected, the indices will be indexed from backward
+         Returns
+         -------
+         (list of ndarray, ndarray)
+             First element is a list of length N containing samples of shape CxS,
+             where N is the number of frames, C is the number of channels,
+             S is the number of samples of the corresponding frame.
+
+             Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
+             where N is the length of the slice.
+
+         """
+         assert self.__video_reader is not None and self.__audio_reader is not None
+         indices = self._validate_indices(indices)
+         audio_arr = []
+         prev_video_idx = None
+         prev_audio_end_idx = None
+         for idx in list(indices):
+             frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
+             # timestamp and sample conversion could have some error that could cause non-continuous audio
+             # we detect if retrieving continuous frame and make the audio continuous
+             if prev_video_idx and idx == prev_video_idx + 1:
+                 audio_start_idx = prev_audio_end_idx
+             else:
+                 audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
+             audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
+             audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
+             prev_video_idx = idx
+             prev_audio_end_idx = audio_end_idx
+         results = (audio_arr, self.__video_reader.get_batch(indices))
+         self.__video_reader.seek(0)
+         return results
+
+     def _get_slice(self, sl):
+         audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
+         for idx in list(sl):
+             audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
+             audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
+             audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
+             audio_arr = np.concatenate(
+                 (audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
+             )
+         results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
+         self.__video_reader.seek(0)
+         return results
+
+     def _validate_indices(self, indices):
+         """Validate int64 integers and convert negative integers to positive by backward search"""
+         assert self.__video_reader is not None and self.__audio_reader is not None
+         indices = np.array(indices, dtype=np.int64)
+         # process negative indices
+         indices[indices < 0] += len(self.__video_reader)
+         if not (indices >= 0).all():
+             raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
+         if not (indices < len(self.__video_reader)).all():
+             raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
+         return indices
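The AVReader added here pairs each video frame with the audio samples spanning that frame's timestamps, and get_batch keeps the audio continuous across consecutive frame indices. A small usage sketch (not part of the diff; the clip path and decode parameters are illustrative):

# Sketch only: reading aligned audio/video with the AVReader above.
from decord import cpu
from flaxdiff.data.sources.utils import AVReader

av = AVReader("clip.mp4", ctx=cpu(0), sample_rate=16000, mono=True,
              width=256, height=256)
print(len(av))                                     # number of video frames

audio, frame = av[0]                               # audio: C x S samples for frame 0; frame: H x W x 3
audio_list, frames = av.get_batch([0, 1, 2, 3])    # list of C x S arrays; frames: 4 x H x W x 3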