torchcodec 0.10.0__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. torchcodec/__init__.py +27 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +130 -0
  7. torchcodec/_core/AVIOTensorContext.h +44 -0
  8. torchcodec/_core/BetaCudaDeviceInterface.cpp +849 -0
  9. torchcodec/_core/BetaCudaDeviceInterface.h +196 -0
  10. torchcodec/_core/CMakeLists.txt +295 -0
  11. torchcodec/_core/CUDACommon.cpp +330 -0
  12. torchcodec/_core/CUDACommon.h +51 -0
  13. torchcodec/_core/Cache.h +124 -0
  14. torchcodec/_core/CpuDeviceInterface.cpp +509 -0
  15. torchcodec/_core/CpuDeviceInterface.h +141 -0
  16. torchcodec/_core/CudaDeviceInterface.cpp +602 -0
  17. torchcodec/_core/CudaDeviceInterface.h +79 -0
  18. torchcodec/_core/DeviceInterface.cpp +117 -0
  19. torchcodec/_core/DeviceInterface.h +191 -0
  20. torchcodec/_core/Encoder.cpp +1054 -0
  21. torchcodec/_core/Encoder.h +192 -0
  22. torchcodec/_core/FFMPEGCommon.cpp +684 -0
  23. torchcodec/_core/FFMPEGCommon.h +314 -0
  24. torchcodec/_core/FilterGraph.cpp +159 -0
  25. torchcodec/_core/FilterGraph.h +59 -0
  26. torchcodec/_core/Frame.cpp +47 -0
  27. torchcodec/_core/Frame.h +72 -0
  28. torchcodec/_core/Metadata.cpp +124 -0
  29. torchcodec/_core/Metadata.h +92 -0
  30. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  31. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  32. torchcodec/_core/NVDECCache.cpp +60 -0
  33. torchcodec/_core/NVDECCache.h +102 -0
  34. torchcodec/_core/SingleStreamDecoder.cpp +1586 -0
  35. torchcodec/_core/SingleStreamDecoder.h +391 -0
  36. torchcodec/_core/StreamOptions.h +70 -0
  37. torchcodec/_core/Transform.cpp +128 -0
  38. torchcodec/_core/Transform.h +86 -0
  39. torchcodec/_core/ValidationUtils.cpp +35 -0
  40. torchcodec/_core/ValidationUtils.h +21 -0
  41. torchcodec/_core/__init__.py +46 -0
  42. torchcodec/_core/_metadata.py +262 -0
  43. torchcodec/_core/custom_ops.cpp +1090 -0
  44. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +169 -0
  45. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  46. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  47. torchcodec/_core/ops.py +605 -0
  48. torchcodec/_core/pybind_ops.cpp +50 -0
  49. torchcodec/_frame.py +146 -0
  50. torchcodec/_internally_replaced_utils.py +68 -0
  51. torchcodec/_samplers/__init__.py +7 -0
  52. torchcodec/_samplers/video_clip_sampler.py +419 -0
  53. torchcodec/decoders/__init__.py +12 -0
  54. torchcodec/decoders/_audio_decoder.py +185 -0
  55. torchcodec/decoders/_decoder_utils.py +113 -0
  56. torchcodec/decoders/_video_decoder.py +601 -0
  57. torchcodec/encoders/__init__.py +2 -0
  58. torchcodec/encoders/_audio_encoder.py +149 -0
  59. torchcodec/encoders/_video_encoder.py +196 -0
  60. torchcodec/libtorchcodec_core4.so +0 -0
  61. torchcodec/libtorchcodec_core5.so +0 -0
  62. torchcodec/libtorchcodec_core6.so +0 -0
  63. torchcodec/libtorchcodec_core7.so +0 -0
  64. torchcodec/libtorchcodec_core8.so +0 -0
  65. torchcodec/libtorchcodec_custom_ops4.so +0 -0
  66. torchcodec/libtorchcodec_custom_ops5.so +0 -0
  67. torchcodec/libtorchcodec_custom_ops6.so +0 -0
  68. torchcodec/libtorchcodec_custom_ops7.so +0 -0
  69. torchcodec/libtorchcodec_custom_ops8.so +0 -0
  70. torchcodec/libtorchcodec_pybind_ops4.so +0 -0
  71. torchcodec/libtorchcodec_pybind_ops5.so +0 -0
  72. torchcodec/libtorchcodec_pybind_ops6.so +0 -0
  73. torchcodec/libtorchcodec_pybind_ops7.so +0 -0
  74. torchcodec/libtorchcodec_pybind_ops8.so +0 -0
  75. torchcodec/samplers/__init__.py +2 -0
  76. torchcodec/samplers/_common.py +84 -0
  77. torchcodec/samplers/_index_based.py +287 -0
  78. torchcodec/samplers/_time_based.py +358 -0
  79. torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake +76 -0
  80. torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake +122 -0
  81. torchcodec/transforms/__init__.py +12 -0
  82. torchcodec/transforms/_decoder_transforms.py +375 -0
  83. torchcodec/version.py +2 -0
  84. torchcodec-0.10.0.dist-info/METADATA +286 -0
  85. torchcodec-0.10.0.dist-info/RECORD +88 -0
  86. torchcodec-0.10.0.dist-info/WHEEL +5 -0
  87. torchcodec-0.10.0.dist-info/licenses/LICENSE +28 -0
  88. torchcodec-0.10.0.dist-info/top_level.txt +2 -0
torchcodec/_frame.py ADDED
@@ -0,0 +1,146 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import dataclasses
+ from collections.abc import Iterable, Iterator
+ from dataclasses import dataclass
+
+ from torch import Tensor
+
+
+ def _frame_repr(self):
+     # Utility to replace __repr__ method of dataclasses below. This prints the
+     # shape of the .data tensor rather than printing the (potentially very long)
+     # data tensor itself.
+     s = self.__class__.__name__ + ":\n"
+     spaces = " "
+     for field in dataclasses.fields(self):
+         field_name = field.name
+         field_val = getattr(self, field_name)
+         if field_name == "data":
+             field_name = "data (shape)"
+             field_val = field_val.shape
+         s += f"{spaces}{field_name}: {field_val}\n"
+     return s
+
+
+ @dataclass
+ class Frame(Iterable):
+     """A single video frame with associated metadata."""
+
+     data: Tensor
+     """The frame data (3-D ``torch.Tensor``)."""
+     pts_seconds: float
+     """The :term:`pts` of the frame, in seconds (float)."""
+     duration_seconds: float
+     """The duration of the frame, in seconds (float)."""
+
+     def __post_init__(self):
+         # This is called after __init__() when a Frame is created. We can run
+         # input validation checks here.
+         if not self.data.ndim == 3:
+             raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
+         self.pts_seconds = float(self.pts_seconds)
+         self.duration_seconds = float(self.duration_seconds)
+
+     def __iter__(self) -> Iterator[Tensor | float]:
+         for field in dataclasses.fields(self):
+             yield getattr(self, field.name)
+
+     def __repr__(self):
+         return _frame_repr(self)
+
+
+ @dataclass
+ class FrameBatch(Iterable):
+     """Multiple video frames with associated metadata.
+
+     The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW),
+     or 5D for sequences of clips, as returned by the :ref:`samplers
+     <sphx_glr_generated_examples_decoding_sampling.py>`. When ``data`` is 4D (resp. 5D)
+     the ``pts_seconds`` and ``duration_seconds`` tensors are 1D (resp. 2D).
+
+     .. note::
+         The ``pts_seconds`` and ``duration_seconds`` Tensors are always returned
+         on CPU, even if ``data`` is on GPU.
+     """
+
+     data: Tensor
+     """The frames data (``torch.Tensor`` of uint8)."""
+     pts_seconds: Tensor
+     """The :term:`pts` of each frame, in seconds (``torch.Tensor`` of floats)."""
+     duration_seconds: Tensor
+     """The duration of each frame, in seconds (``torch.Tensor`` of floats)."""
+
+     def __post_init__(self):
+         # This is called after __init__() when a FrameBatch is created. We can
+         # run input validation checks here.
+         if self.data.ndim < 3:
+             raise ValueError(
+                 f"data must be at least 3-dimensional, got {self.data.shape = }"
+             )
+
+         leading_dims = self.data.shape[:-3]
+         if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape):
+             raise ValueError(
+                 "Tried to create a FrameBatch but the leading dimensions of the inputs do not match. "
+                 f"Got {self.data.shape = } so we expected the shape of pts_seconds and "
+                 f"duration_seconds to be {leading_dims = }, but got "
+                 f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
+             )
+
+     def __iter__(self) -> Iterator["FrameBatch"]:
+         for data, pts_seconds, duration_seconds in zip(
+             self.data, self.pts_seconds, self.duration_seconds
+         ):
+             yield FrameBatch(
+                 data=data,
+                 pts_seconds=pts_seconds,
+                 duration_seconds=duration_seconds,
+             )
+
+     def __getitem__(self, key) -> "FrameBatch":
+         return FrameBatch(
+             data=self.data[key],
+             pts_seconds=self.pts_seconds[key],
+             duration_seconds=self.duration_seconds[key],
+         )
+
+     def __len__(self):
+         return len(self.data)
+
+     def __repr__(self):
+         return _frame_repr(self)
+
+
+ @dataclass
+ class AudioSamples(Iterable):
+     """Audio samples with associated metadata."""
+
+     data: Tensor
+     """The sample data (``torch.Tensor`` of float in [-1, 1], shape is ``(num_channels, num_samples)``)."""
+     pts_seconds: float
+     """The :term:`pts` of the first sample, in seconds."""
+     duration_seconds: float
+     """The duration of the samples, in seconds."""
+     sample_rate: int
+     """The sample rate of the samples, in Hz."""
+
+     def __post_init__(self):
+         # This is called after __init__() when an AudioSamples instance is
+         # created. We can run input validation checks here.
+         if not self.data.ndim == 2:
+             raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }")
+         self.pts_seconds = float(self.pts_seconds)
+         self.sample_rate = int(self.sample_rate)
+
+     def __iter__(self) -> Iterator[Tensor | float]:
+         for field in dataclasses.fields(self):
+             yield getattr(self, field.name)
+
+     def __repr__(self):
+         return _frame_repr(self)
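For orientation, here is a minimal sketch of how these dataclasses behave. It is not part of the package, and the tensor contents are made up for illustration:

import torch

from torchcodec._frame import Frame, FrameBatch

# A Frame is iterable, so it unpacks into (data, pts_seconds, duration_seconds).
frame = Frame(
    data=torch.zeros(3, 270, 480, dtype=torch.uint8),
    pts_seconds=0.0,
    duration_seconds=1 / 30,
)
data, pts, duration = frame

# A FrameBatch with 4D data carries 1D pts/duration tensors whose shapes must
# match the leading dimension of data; indexing and iterating both return
# FrameBatch objects rather than raw tensors.
batch = FrameBatch(
    data=torch.zeros(5, 3, 270, 480, dtype=torch.uint8),
    pts_seconds=torch.arange(5) / 30,
    duration_seconds=torch.full((5,), 1 / 30),
)
first_two = batch[:2]  # still a FrameBatch
for single in batch:   # each element is a FrameBatch with 3D data
    print(single)      # __repr__ prints the shape of .data, not its contents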
torchcodec/_internally_replaced_utils.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import importlib
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ # Copy-pasted from torchvision:
+ # https://github.com/pytorch/vision/blob/947ae1dc71867f28021d5bc0ff3a19c249236e2a/torchvision/_internally_replaced_utils.py#L25
+ def _get_extension_path(lib_name: str) -> str:
+     extension_suffixes = []
+     if sys.platform == "linux":
+         extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES
+     elif sys.platform == "darwin":
+         extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES + [".dylib"]
+     elif sys.platform in ("win32", "cygwin"):
+         extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES + [".dll", ".pyd"]
+     else:
+         raise NotImplementedError(f"{sys.platform = } is not supported")
+     loader_details = (
+         importlib.machinery.ExtensionFileLoader,
+         extension_suffixes,
+     )
+
+     extfinder = importlib.machinery.FileFinder(
+         str(Path(__file__).parent), loader_details
+     )
+     ext_specs = extfinder.find_spec(lib_name)
+     if ext_specs is None:
+         raise ImportError(f"No spec found for {lib_name}")
+
+     if ext_specs.origin is None:
+         raise ImportError(f"Existing spec found for {lib_name} does not have an origin")
+
+     return ext_specs.origin
+
+
+ def _load_pybind11_module(module_name: str, library_path: str) -> ModuleType:
+     spec = importlib.util.spec_from_file_location(
+         module_name,
+         library_path,
+     )
+     if spec is None or spec.loader is None:
+         raise ImportError(
+             f"Unable to load spec or spec.loader for module {module_name} from path {library_path}"
+         )
+
+     mod = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(mod)
+
+     return mod
+
+
+ # Note that the return value from this function must match the value used as
+ # PYBIND_OPS_MODULE_NAME when we compile _core/pybind_ops.cpp. If the values
+ # do not match, we will not be able to import the C++ shared library as a
+ # Python module at runtime.
+ #
+ # The parameter ffmpeg_major_version is unused externally, but used
+ # internally.
+ def _get_pybind_ops_module_name(ffmpeg_major_version: int) -> str:
+     return "core_pybind_ops"
torchcodec/_samplers/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .video_clip_sampler import *  # noqa
torchcodec/_samplers/video_clip_sampler.py ADDED
@@ -0,0 +1,419 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import abc
+ import json
+ import sys
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ import torch
+ from torch import nn, Tensor
+
+ from torchcodec._core import (
+     add_video_stream,
+     create_from_tensor,
+     get_frames_at_indices,
+     get_json_metadata,
+     get_next_frame,
+     scan_all_streams_to_update_metadata,
+     seek_to_pts,
+ )
+
+
+ class VideoTooShortException(Exception):
+     pass
+
+
+ @dataclass
+ class DecoderArgs:
+     num_threads: int = 0
+
+
+ @dataclass
+ class VideoArgs:
+     """
+     VideoArgs contains video-related information. Video width/height cannot be
+     used together with video min/max dimension.
+     Args:
+         desired_width (`int`): Target width of the video
+         desired_height (`int`): Target height of the video
+         desired_max_dimension (`int`): Target maximum dimension of the video
+         desired_min_dimension (`int`): Target minimum dimension of the video
+     """
+
+     desired_width: int = 0
+     desired_height: int = 0
+     desired_max_dimension: int = 0
+     desired_min_dimension: int = 0
+
+
+ @dataclass
+ class SamplerArgs(abc.ABC):
+     """
+     Abstract class of sampler args, extended by TimeBasedSamplerArgs and IndexBasedSamplerArgs.
+     A frame refers to a video/audio frame, and a clip is a list of frames which may be non-consecutive.
+     Args:
+         sampler_type (`str`): Sampler type; can be random, uniform, periodic, or target
+         clips_per_video (`int`): Number of clips per video; this applies to random and uniform sampling
+         frames_per_clip (`int`): Number of frames per clip
+     """
+
+     sampler_type: str
+     clips_per_video: int
+     frames_per_clip: int
+
+
+ @dataclass
+ class TimeBasedSamplerArgs(SamplerArgs):
+     """
+     TimeBasedSamplerArgs inherits from SamplerArgs and describes the time-based sampling behavior.
+     Args:
+         video_frame_dilation (`int`): Frame dilation of the video; if frame dilation is 2, we sample every other frame within a clip.
+         sample_start_second (`float`): Start second of the sampler range, applies to all sampler types
+         sample_end_second (`float`): End second of the sampler range, applies to all sampler types
+         sample_per_second (`float`): Samples per second of the sampler range, applies to periodic sampling
+         target_sample_start_second (`list[float]`): Start seconds of the target sampling range, applies to target sampling
+     """
+
+     video_frame_dilation: int = 1
+     sample_start_second: float = 0.0
+     sample_end_second: float = float("inf")
+     sample_per_second: float = 0.0
+     target_sample_start_second: list[float] = field(default_factory=lambda: [])
+
+
+ @dataclass
+ class IndexBasedSamplerArgs(SamplerArgs):
+     """
+     IndexBasedSamplerArgs inherits from SamplerArgs and describes the index-based sampling behavior.
+     sample_start_index and sample_end_index together decide the range of the sampling.
+     sample_step decides the step between each clip.
+     video_frame_dilation decides the step between each frame within a clip.
+     Args:
+         video_frame_dilation (`int`): Frame dilation of the video; if frame dilation is 2, we sample every other frame within a clip, applies to all sampler types
+         sample_start_index (`int`): Start index of the sampler range, applies to all sampler types
+         sample_end_index (`int`): End index of the sampler range; this is the last possible frame you want to sample, applies to all sampler types
+         sample_step (`int`): Step of the sampler range; if the step is 10, the interval between the start frames of each clip will be 10, applies to periodic sampling only.
+     """
+
+     video_frame_dilation: int = 1
+     sample_start_index: int = 0
+     sample_end_index: int = sys.maxsize
+     sample_step: int = 1
+
+
+ class DEPRECATED_VideoClipSampler(nn.Module):
+     """
+     DEPRECATED: Do not use. The supported samplers are in `torchcodec.samplers`. See:
+
+     * https://docs.pytorch.org/torchcodec/stable/api_ref_torchcodec.html
+     * https://docs.pytorch.org/torchcodec/stable/generated_examples/decoding/sampling.html
+     """
+
+     def __init__(
+         self,
+         video_args: VideoArgs,
+         sampler_args: SamplerArgs,
+         decoder_args: DecoderArgs | None = None,
+     ) -> None:
+         super().__init__()
+         self.video_args = video_args
+         self.sampler_args = sampler_args
+         self.decoder_args = DecoderArgs() if decoder_args is None else decoder_args
+
+     def forward(self, video_data: Tensor) -> list[Any]:
+         """Sample video clips from the video data.
+
+         Args:
+             video_data (`Tensor`): The video data
+
+         Returns:
+             clips (`list[list[Tensor]]`): List of clips, where each clip is a list of Tensors and each tensor represents a frame image.
+         """
+
+         video_decoder = create_from_tensor(video_data)
+         scan_all_streams_to_update_metadata(video_decoder)
+         add_video_stream(video_decoder)
+         metadata_json = json.loads(get_json_metadata(video_decoder))
+         target_width, target_height = self._compute_frame_width_height(
+             metadata_json["width"], metadata_json["height"]
+         )
+
+         video_decoder = create_from_tensor(video_data)
+         scan_all_streams_to_update_metadata(video_decoder)
+         add_video_stream(
+             video_decoder,
+             transform_specs=f"resize, {target_height}, {target_width}",
+             num_threads=self.decoder_args.num_threads,
+         )
+
+         clips: list[Any] = []
+         # Cast sampler args to be time based or index based
+         if isinstance(self.sampler_args, TimeBasedSamplerArgs):
+             time_based_sampler_args = self.sampler_args
+             clip_starts_in_seconds = self._get_start_seconds(
+                 metadata_json, time_based_sampler_args
+             )
+             for start_ts in clip_starts_in_seconds:
+                 clip = self._get_clip_with_start_second(
+                     start_ts,
+                     video_decoder,
+                     time_based_sampler_args.video_frame_dilation,
+                 )
+                 clips.append(clip)
+         elif isinstance(self.sampler_args, IndexBasedSamplerArgs):
+             index_based_sampler_args = self.sampler_args
+             clips = self._get_clips_for_index_based_sampling(
+                 video_decoder,
+                 index_based_sampler_args,
+                 metadata_json,
+             )
+
+         return clips
+
+     def _get_clips_for_index_based_sampling(
+         self,
+         video_decoder: Tensor,
+         index_based_sampler_args: IndexBasedSamplerArgs,
+         metadata_json: dict[str, Any],
+     ) -> list[Tensor]:
+         """Get clips for index-based sampling. The sampling is done in 3 steps:
+         1. Compute clip_start_idxs based on the sampler type and the sampler args;
+         2. For each clip, given clip_start_idx, video_frame_dilation, and frames_per_clip, get the indexes for all frames;
+         3. With the given indexes, fetch the frames and group them into clips.
+
+         Args:
+             video_decoder (`Tensor`): The video decoder
+             index_based_sampler_args (`IndexBasedSamplerArgs`): The index-based sampler args
+             metadata_json (`dict[str, Any]`): The metadata of the video in json format
+
+         Returns:
+             clips (`list[Tensor]`): List of clips, where each clip is a Tensor representing a list of frames; the Tensor shape default is NCHW.
+         """
+
+         sample_start_index = max(0, index_based_sampler_args.sample_start_index)
+         sample_end_index = (
+             min(
+                 index_based_sampler_args.sample_end_index + 1,
+                 metadata_json["numFramesFromHeader"],
+             )
+             - index_based_sampler_args.video_frame_dilation
+             * index_based_sampler_args.frames_per_clip
+         )
+         sampler_type = index_based_sampler_args.sampler_type
+
+         if sampler_type == "random":
+             clip_start_idxs = torch.randint(
+                 sample_start_index,
+                 sample_end_index,
+                 (index_based_sampler_args.clips_per_video,),
+             )
+         elif sampler_type == "uniform":
+             clip_start_idxs = torch.linspace(
+                 sample_start_index,
+                 sample_end_index,
+                 index_based_sampler_args.clips_per_video,
+                 dtype=torch.int32,
+             )
+
+         clips = []
+         for clip_start_idx in clip_start_idxs:
+             batch_indexes = [
+                 clip_start_idx + i * index_based_sampler_args.video_frame_dilation
+                 for i in range(index_based_sampler_args.frames_per_clip)
+             ]
+             # Need torch.stack to convert list[Tensor[int]] into 1D Tensor[int]
+             batch_indexes = torch.stack(batch_indexes)
+             frames, *_ = get_frames_at_indices(
+                 video_decoder,
+                 frame_indices=batch_indexes,
+             )
+             clips.append(frames)
+
+         return clips
+
+     def _get_start_seconds(
+         self,
+         metadata_json: dict[str, Any],
+         time_based_sampler_args: TimeBasedSamplerArgs,
+     ) -> list[float]:
+         """Get start seconds for each clip.
+         Given different sampler types, the API returns different clip start seconds.
+
+         Args:
+             metadata_json (`dict[str, Any]`): The metadata of the video in json format
+             time_based_sampler_args (`TimeBasedSamplerArgs`): The time-based sampler args
+
+         Returns:
+             (`list[float]`): List of the sampled clip start positions in seconds
+         """
+         video_duration_in_seconds = metadata_json["durationSecondsFromHeader"]
+
+         clip_duration_in_seconds = (
+             time_based_sampler_args.frames_per_clip
+             * time_based_sampler_args.video_frame_dilation
+             + 1
+         ) / metadata_json["averageFpsFromHeader"]
+
+         beginStreamSecondsFromContent = (
+             metadata_json["beginStreamSecondsFromContent"]
+             if metadata_json["beginStreamSecondsFromContent"]
+             else 0
+         )
+         endStreamSecondsFromContent = (
+             metadata_json["endStreamSecondsFromContent"]
+             if metadata_json["endStreamSecondsFromContent"] > 0
+             else video_duration_in_seconds
+         )
+         last_possible_clip_start_in_seconds = (
+             endStreamSecondsFromContent - clip_duration_in_seconds
+         )
+         if last_possible_clip_start_in_seconds < 0:
+             raise VideoTooShortException(
+                 "Cannot get clips because video duration is shorter than the clip duration!"
+             )
+         sampler_type = time_based_sampler_args.sampler_type
+         clip_starts_in_seconds: list[float] = []
+         sample_start_second = max(
+             time_based_sampler_args.sample_start_second,
+             beginStreamSecondsFromContent,
+         )
+         sample_end_second = min(
+             last_possible_clip_start_in_seconds,
+             time_based_sampler_args.sample_end_second,
+         )
+         if sampler_type == "random":
+             clip_starts_in_seconds = (
+                 torch.rand(time_based_sampler_args.clips_per_video)
+                 * (sample_end_second - sample_start_second)
+                 + sample_start_second
+             ).tolist()
+             clip_starts_in_seconds.sort()
+         elif sampler_type == "uniform":
+             clip_starts_in_seconds = torch.linspace(
+                 sample_start_second,
+                 sample_end_second,
+                 time_based_sampler_args.clips_per_video,
+             ).tolist()
+         else:
+             raise NotImplementedError
+
+         return clip_starts_in_seconds
+
+     def _get_clip_with_start_second(
+         self, start_second: float, video_decoder: Tensor, video_frame_dilation: int
+     ) -> list[Tensor]:
+         """Get a clip with a given start second.
+
+         Args:
+             start_second (`float`): The start second of the clip
+             video_decoder (`Tensor`): The video decoder
+             video_frame_dilation (`int`): The video frame dilation; by default it's 1.
+
+         Returns:
+             clip (`list[Tensor]`): A clip is a list of frame tensors. The dimension of each frame tensor is user specified; by default it's HWC.
+         """
+         seek_to_pts(video_decoder, start_second)
+         frames_needed_per_clip = (
+             self.sampler_args.frames_per_clip - 1
+         ) * video_frame_dilation + 1
+         clip = []
+         for _ in range(frames_needed_per_clip):
+             frame, _, _ = get_next_frame(video_decoder)
+             clip.append(frame)
+
+         # Slice the list of tensors with frame_dilation and stack to tensor
+         clip = clip[::video_frame_dilation]
+         return clip
+
+     def _compute_frame_width_height(
+         self, ori_width: int, ori_height: int
+     ) -> tuple[int, int]:
+         """Compute output frame width and height.
+         desired_width, desired_height, desired_min_dimension, desired_max_dimension (`int`): Together decide the size of the decoded video clips. (Default: `0`.)
+         Note that the desired_width/desired_height parameters are mutually exclusive with the desired_min_dimension/desired_max_dimension parameters.
+         - When desired_width = 0, desired_height = 0, desired_min_dimension = 0,
+           and desired_max_dimension = 0, keep the original frame resolution
+         - When desired_width = 0, desired_height != 0, desired_min_dimension = 0,
+           and desired_max_dimension = 0, keep the aspect ratio and resize
+           the frame so that frame target_height is $desired_height
+         - When desired_width != 0, desired_height == 0, desired_min_dimension = 0,
+           and desired_max_dimension = 0, keep the aspect ratio and resize
+           the frame so that frame target_width is $desired_width
+         - When desired_width != 0, desired_height != 0, desired_min_dimension = 0,
+           and desired_max_dimension = 0, resize the frame so that frame
+           target_width and target_height are set to $desired_width and
+           $desired_height, respectively
+         - When desired_width = 0, desired_height = 0, desired_min_dimension != 0,
+           and desired_max_dimension = 0, keep the aspect ratio and resize the
+           frame so that the shorter edge size is desired_min_dimension
+         - When desired_width = 0, desired_height = 0, desired_min_dimension = 0,
+           and desired_max_dimension != 0, keep the aspect ratio and resize
+           the frame so that the longer edge size is desired_max_dimension
+         - When desired_width = 0, desired_height = 0, desired_min_dimension != 0,
+           and desired_max_dimension != 0, resize the frame so that the shorter
+           edge size is desired_min_dimension, and the longer edge size is
+           desired_max_dimension. The aspect ratio may not be preserved
+
+         Args:
+             ori_width (`int`): Original width of the video
+             ori_height (`int`): Original height of the video
+
+         Returns:
+             (`tuple[int, int]`): Output frame width and height
+         """
+         width_height_ratio = ori_width / ori_height
+         height_width_ratio = ori_height / ori_width
+
+         target_width, target_height = ori_width, ori_height
+
+         # desired_height and/or desired_width is non-zero
+         if self.video_args.desired_width == 0 and self.video_args.desired_height != 0:
+             target_height = self.video_args.desired_height
+             target_width = int(width_height_ratio * target_height)
+         elif self.video_args.desired_width != 0 and self.video_args.desired_height == 0:
+             target_width = self.video_args.desired_width
+             target_height = int(height_width_ratio * target_width)
+         elif self.video_args.desired_width != 0 and self.video_args.desired_height != 0:
+             target_width, target_height = (
+                 self.video_args.desired_width,
+                 self.video_args.desired_height,
+             )
+         # desired_min_dimension and/or desired_max_dimension is non-zero
+         elif (
+             self.video_args.desired_min_dimension != 0
+             and self.video_args.desired_max_dimension == 0
+         ):
+             if ori_width > ori_height:
+                 target_height = self.video_args.desired_min_dimension
+                 target_width = int(width_height_ratio * target_height)
+             else:
+                 target_width = self.video_args.desired_min_dimension
+                 target_height = int(height_width_ratio * target_width)
+         elif (
+             self.video_args.desired_min_dimension == 0
+             and self.video_args.desired_max_dimension != 0
+         ):
+             if ori_width > ori_height:
+                 target_width = self.video_args.desired_max_dimension
+                 target_height = int(height_width_ratio * target_width)
+             else:
+                 target_height = self.video_args.desired_max_dimension
+                 target_width = int(width_height_ratio * target_height)
+         elif (
+             self.video_args.desired_min_dimension != 0
+             and self.video_args.desired_max_dimension != 0
+         ):
+             if ori_width > ori_height:
+                 target_width = self.video_args.desired_max_dimension
+                 target_height = self.video_args.desired_min_dimension
+             else:
+                 target_height = self.video_args.desired_max_dimension
+                 target_width = self.video_args.desired_min_dimension
+
+         return target_width, target_height
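A minimal usage sketch of this deprecated class, assuming video_bytes holds the raw encoded contents of a video file (new code should use torchcodec.samplers instead):

import torch

from torchcodec._samplers.video_clip_sampler import (
    DEPRECATED_VideoClipSampler,
    TimeBasedSamplerArgs,
    VideoArgs,
)

video_data = torch.frombuffer(video_bytes, dtype=torch.uint8)  # video_bytes is assumed
sampler = DEPRECATED_VideoClipSampler(
    video_args=VideoArgs(desired_width=224, desired_height=224),
    sampler_args=TimeBasedSamplerArgs(
        sampler_type="random", clips_per_video=2, frames_per_clip=4
    ),
)
clips = sampler(video_data)  # list of clips; each clip is a list of frame tensors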
torchcodec/decoders/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .._core import AudioStreamMetadata, VideoStreamMetadata
+ from ._audio_decoder import AudioDecoder  # noqa
+ from ._decoder_utils import set_cuda_backend  # noqa
+ from ._video_decoder import CpuFallbackStatus, VideoDecoder  # noqa
+
+ SimpleVideoDecoder = VideoDecoder
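The names re-exported here form the package's public decoding surface, with SimpleVideoDecoder kept as a backward-compatible alias for VideoDecoder. A brief illustration, assuming a local video.mp4 exists:

from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("video.mp4")
first_frame = decoder[0]  # frames are addressable by index and returned as uint8 tensors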