torchcodec 0.10.0__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. torchcodec/__init__.py +27 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +130 -0
  7. torchcodec/_core/AVIOTensorContext.h +44 -0
  8. torchcodec/_core/BetaCudaDeviceInterface.cpp +849 -0
  9. torchcodec/_core/BetaCudaDeviceInterface.h +196 -0
  10. torchcodec/_core/CMakeLists.txt +295 -0
  11. torchcodec/_core/CUDACommon.cpp +330 -0
  12. torchcodec/_core/CUDACommon.h +51 -0
  13. torchcodec/_core/Cache.h +124 -0
  14. torchcodec/_core/CpuDeviceInterface.cpp +509 -0
  15. torchcodec/_core/CpuDeviceInterface.h +141 -0
  16. torchcodec/_core/CudaDeviceInterface.cpp +602 -0
  17. torchcodec/_core/CudaDeviceInterface.h +79 -0
  18. torchcodec/_core/DeviceInterface.cpp +117 -0
  19. torchcodec/_core/DeviceInterface.h +191 -0
  20. torchcodec/_core/Encoder.cpp +1054 -0
  21. torchcodec/_core/Encoder.h +192 -0
  22. torchcodec/_core/FFMPEGCommon.cpp +684 -0
  23. torchcodec/_core/FFMPEGCommon.h +314 -0
  24. torchcodec/_core/FilterGraph.cpp +159 -0
  25. torchcodec/_core/FilterGraph.h +59 -0
  26. torchcodec/_core/Frame.cpp +47 -0
  27. torchcodec/_core/Frame.h +72 -0
  28. torchcodec/_core/Metadata.cpp +124 -0
  29. torchcodec/_core/Metadata.h +92 -0
  30. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  31. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  32. torchcodec/_core/NVDECCache.cpp +60 -0
  33. torchcodec/_core/NVDECCache.h +102 -0
  34. torchcodec/_core/SingleStreamDecoder.cpp +1586 -0
  35. torchcodec/_core/SingleStreamDecoder.h +391 -0
  36. torchcodec/_core/StreamOptions.h +70 -0
  37. torchcodec/_core/Transform.cpp +128 -0
  38. torchcodec/_core/Transform.h +86 -0
  39. torchcodec/_core/ValidationUtils.cpp +35 -0
  40. torchcodec/_core/ValidationUtils.h +21 -0
  41. torchcodec/_core/__init__.py +46 -0
  42. torchcodec/_core/_metadata.py +262 -0
  43. torchcodec/_core/custom_ops.cpp +1090 -0
  44. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +169 -0
  45. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  46. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  47. torchcodec/_core/ops.py +605 -0
  48. torchcodec/_core/pybind_ops.cpp +50 -0
  49. torchcodec/_frame.py +146 -0
  50. torchcodec/_internally_replaced_utils.py +68 -0
  51. torchcodec/_samplers/__init__.py +7 -0
  52. torchcodec/_samplers/video_clip_sampler.py +419 -0
  53. torchcodec/decoders/__init__.py +12 -0
  54. torchcodec/decoders/_audio_decoder.py +185 -0
  55. torchcodec/decoders/_decoder_utils.py +113 -0
  56. torchcodec/decoders/_video_decoder.py +601 -0
  57. torchcodec/encoders/__init__.py +2 -0
  58. torchcodec/encoders/_audio_encoder.py +149 -0
  59. torchcodec/encoders/_video_encoder.py +196 -0
  60. torchcodec/libtorchcodec_core4.so +0 -0
  61. torchcodec/libtorchcodec_core5.so +0 -0
  62. torchcodec/libtorchcodec_core6.so +0 -0
  63. torchcodec/libtorchcodec_core7.so +0 -0
  64. torchcodec/libtorchcodec_core8.so +0 -0
  65. torchcodec/libtorchcodec_custom_ops4.so +0 -0
  66. torchcodec/libtorchcodec_custom_ops5.so +0 -0
  67. torchcodec/libtorchcodec_custom_ops6.so +0 -0
  68. torchcodec/libtorchcodec_custom_ops7.so +0 -0
  69. torchcodec/libtorchcodec_custom_ops8.so +0 -0
  70. torchcodec/libtorchcodec_pybind_ops4.so +0 -0
  71. torchcodec/libtorchcodec_pybind_ops5.so +0 -0
  72. torchcodec/libtorchcodec_pybind_ops6.so +0 -0
  73. torchcodec/libtorchcodec_pybind_ops7.so +0 -0
  74. torchcodec/libtorchcodec_pybind_ops8.so +0 -0
  75. torchcodec/samplers/__init__.py +2 -0
  76. torchcodec/samplers/_common.py +84 -0
  77. torchcodec/samplers/_index_based.py +287 -0
  78. torchcodec/samplers/_time_based.py +358 -0
  79. torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake +76 -0
  80. torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake +122 -0
  81. torchcodec/transforms/__init__.py +12 -0
  82. torchcodec/transforms/_decoder_transforms.py +375 -0
  83. torchcodec/version.py +2 -0
  84. torchcodec-0.10.0.dist-info/METADATA +286 -0
  85. torchcodec-0.10.0.dist-info/RECORD +88 -0
  86. torchcodec-0.10.0.dist-info/WHEEL +5 -0
  87. torchcodec-0.10.0.dist-info/licenses/LICENSE +28 -0
  88. torchcodec-0.10.0.dist-info/top_level.txt +2 -0
torchcodec/_core/CpuDeviceInterface.cpp
@@ -0,0 +1,509 @@
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
+ // All rights reserved.
+ //
+ // This source code is licensed under the BSD-style license found in the
+ // LICENSE file in the root directory of this source tree.
+
+ #include "CpuDeviceInterface.h"
+
+ namespace facebook::torchcodec {
+ namespace {
+
+ static bool g_cpu = registerDeviceInterface(
+     DeviceInterfaceKey(torch::kCPU),
+     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
+
+ } // namespace
+
+ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
+     : DeviceInterface(device) {
+   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
+   TORCH_CHECK(
+       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
+ }
+
+ void CpuDeviceInterface::initialize(
+     const AVStream* avStream,
+     [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx,
+     const SharedAVCodecContext& codecContext) {
+   TORCH_CHECK(avStream != nullptr, "avStream is null");
+   codecContext_ = codecContext;
+   timeBase_ = avStream->time_base;
+ }
+
+ void CpuDeviceInterface::initializeVideo(
+     const VideoStreamOptions& videoStreamOptions,
+     const std::vector<std::unique_ptr<Transform>>& transforms,
+     const std::optional<FrameDims>& resizedOutputDims) {
+   avMediaType_ = AVMEDIA_TYPE_VIDEO;
+   videoStreamOptions_ = videoStreamOptions;
+   resizedOutputDims_ = resizedOutputDims;
+
+   // We can use swscale when we have no transforms or a single resize
+   // transform. With a single resize, we use swscale twice: first for color
+   // conversion (YUV->RGB24), then for the resize in RGB24 space.
+   //
+   // Note that this means swscale will not support the case of having several,
+   // back-to-back resizes or other transforms.
+   //
+   // We calculate this value during initialization but we don't refer to it
+   // until getColorConversionLibrary() is called. Calculating this value during
+   // initialization saves us from having to save all of the transforms.
+   areTransformsSwScaleCompatible_ = transforms.empty() ||
+       (transforms.size() == 1 && transforms[0]->isResize());
+
+   // Note that we do not expose this capability in the public API, only through
+   // the core API.
+   //
+   // Same as above, we calculate this value during initialization and refer to
+   // it in getColorConversionLibrary().
+   userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
+       ColorConversionLibrary::SWSCALE;
+
+   // We can only use swscale when we have a single resize transform. Note that
+   // we only decide whether or not to use swscale at the last possible moment,
+   // when we actually convert the frame, because we need to know the actual
+   // frame dimensions.
+   if (transforms.size() == 1 && transforms[0]->isResize()) {
+     auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
+     TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
+     swsFlags_ = resize->getSwsFlags();
+   }
+
+   // If we have any transforms, replace filters_ with the filter strings from
+   // the transforms. As noted above, we decide between swscale and filtergraph
+   // when we actually decode a frame.
+   std::stringstream filters;
+   bool first = true;
+   for (const auto& transform : transforms) {
+     if (!first) {
+       filters << ",";
+     }
+     filters << transform->getFilterGraphCpu();
+     first = false;
+   }
+   if (!transforms.empty()) {
+     // Note [Transform and Format Conversion Order]
+     // We have to ensure that all user filters happen AFTER the explicit format
+     // conversion. That is, we want the filters to be applied in RGB24, not the
+     // pixel format of the input frame.
+     //
+     // The output frame will always be in RGB24, as we specify the sink node
+     // with AV_PIX_FMT_RGB24. Filtergraph will automatically insert a format
+     // conversion to ensure the output frame matches the pixel format
+     // specified in the sink. But by default, it will insert it after the user
+     // filters. We need an explicit format conversion to get the behavior we
+     // want.
+     filters_ = "format=rgb24," + filters.str();
+   }
+
+   initialized_ = true;
+ }
+
+ void CpuDeviceInterface::initializeAudio(
+     const AudioStreamOptions& audioStreamOptions) {
+   avMediaType_ = AVMEDIA_TYPE_AUDIO;
+   audioStreamOptions_ = audioStreamOptions;
+   initialized_ = true;
+ }
+
+ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
+     const FrameDims& outputDims) const {
+   // swscale requires widths to be multiples of 32:
+   // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+   bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
+
+   // We want to use swscale for color conversion if possible because it is
+   // faster than filtergraph. The following are the conditions we need to meet
+   // to use it.
+   //
+   // Note that we treat the transform limitation differently from the width
+   // limitation. That is, we consider the transforms being compatible with
+   // swscale as a hard requirement. If the transforms are not compatible,
+   // then we will end up not applying the transforms, and that is wrong.
+   //
+   // The width requirement, however, is a soft requirement. Even if we don't
+   // meet it, we let the user override it. We have tests that depend on this
+   // behavior. Since we don't expose the ability to choose swscale or
+   // filtergraph in our public API, this is probably okay. It's also the only
+   // way that we can be certain we are testing one versus the other.
+   if (areTransformsSwScaleCompatible_ &&
+       (userRequestedSwScale_ || isWidthSwScaleCompatible)) {
+     return ColorConversionLibrary::SWSCALE;
+   } else {
+     return ColorConversionLibrary::FILTERGRAPH;
+   }
+ }
+
+ void CpuDeviceInterface::convertAVFrameToFrameOutput(
+     UniqueAVFrame& avFrame,
+     FrameOutput& frameOutput,
+     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+   TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
+
+   if (avMediaType_ == AVMEDIA_TYPE_AUDIO) {
+     convertAudioAVFrameToFrameOutput(avFrame, frameOutput);
+   } else {
+     convertVideoAVFrameToFrameOutput(
+         avFrame, frameOutput, preAllocatedOutputTensor);
+   }
+ }
+
+ // Note [preAllocatedOutputTensor with swscale and filtergraph]:
+ // Callers may pass a pre-allocated tensor, where the output.data tensor will
+ // be stored. This parameter is honored in any case, but it only leads to a
+ // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
+ // decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We
+ // haven't yet found a way to do that with filtergraph.
+ // TODO: Figure out whether that's possible!
+ // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
+ // the `dimension_order` parameter. It's up to callers to re-shape it if needed.
+ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
+     UniqueAVFrame& avFrame,
+     FrameOutput& frameOutput,
+     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+   // Note that we ignore the dimensions from the metadata; we don't even bother
+   // storing them. The resized dimensions take priority. If we don't have any,
+   // then we use the dimensions from the actual decoded frame. We use the
+   // actual decoded frame and not the metadata for two reasons:
+   //
+   // 1. Metadata may be wrong. If we have access to more accurate information,
+   //    we should use it.
+   // 2. Video streams can have variable resolution. This fact is not captured
+   //    in the stream metadata.
+   //
+   // Both cases cause problems for our batch APIs, as we allocate
+   // FrameBatchOutputs based on the stream metadata. But single-frame APIs
+   // can still work in such situations, so they should.
+   auto outputDims =
+       resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width));
+
+   if (preAllocatedOutputTensor.has_value()) {
+     auto shape = preAllocatedOutputTensor.value().sizes();
+     TORCH_CHECK(
+         (shape.size() == 3) && (shape[0] == outputDims.height) &&
+             (shape[1] == outputDims.width) && (shape[2] == 3),
+         "Expected pre-allocated tensor of shape ",
+         outputDims.height,
+         "x",
+         outputDims.width,
+         "x3, got ",
+         shape);
+   }
+
+   auto colorConversionLibrary = getColorConversionLibrary(outputDims);
+   torch::Tensor outputTensor;
+
+   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+     outputTensor = preAllocatedOutputTensor.value_or(
+         allocateEmptyHWCTensor(outputDims, torch::kCPU));
+
+     int resultHeight =
+         convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
+
+     // If this check failed, it would mean that the frame wasn't reshaped to
+     // the expected height.
+     // TODO: Can we do the same check for width?
+     TORCH_CHECK(
+         resultHeight == outputDims.height,
+         "resultHeight != outputDims.height: ",
+         resultHeight,
+         " != ",
+         outputDims.height);
+
+     frameOutput.data = outputTensor;
+   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims);
+
+     // Similarly to above, if this check fails it means the frame wasn't
+     // reshaped to its expected dimensions by filtergraph.
+     auto shape = outputTensor.sizes();
+     TORCH_CHECK(
+         (shape.size() == 3) && (shape[0] == outputDims.height) &&
+             (shape[1] == outputDims.width) && (shape[2] == 3),
+         "Expected output tensor of shape ",
+         outputDims.height,
+         "x",
+         outputDims.width,
+         "x3, got ",
+         shape);
+
+     if (preAllocatedOutputTensor.has_value()) {
+       // We have already validated that preAllocatedOutputTensor and
+       // outputTensor have the same shape.
+       preAllocatedOutputTensor.value().copy_(outputTensor);
+       frameOutput.data = preAllocatedOutputTensor.value();
+     } else {
+       frameOutput.data = outputTensor;
+     }
+   } else {
+     TORCH_CHECK(
+         false,
+         "Invalid color conversion library: ",
+         static_cast<int>(colorConversionLibrary));
+   }
+ }
+
+ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
+     const UniqueAVFrame& avFrame,
+     torch::Tensor& outputTensor,
+     const FrameDims& outputDims) {
+   enum AVPixelFormat frameFormat =
+       static_cast<enum AVPixelFormat>(avFrame->format);
+
+   bool needsResize =
+       (avFrame->height != outputDims.height ||
+        avFrame->width != outputDims.width);
+
+   // We need to compare the current frame context with our previous frame
+   // context. If they are different, then we need to re-create our colorspace
+   // conversion objects. We create our colorspace conversion objects late so
+   // that we don't have to depend on the unreliable metadata in the header.
+   // And we sometimes re-create them because it's possible for frame
+   // resolution to change mid-stream. Finally, we want to reuse the colorspace
+   // conversion objects as much as possible for performance reasons.
+   SwsFrameContext swsFrameContext(
+       avFrame->width,
+       avFrame->height,
+       frameFormat,
+       needsResize ? avFrame->width : outputDims.width,
+       needsResize ? avFrame->height : outputDims.height);
+
+   if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+     swsContext_ = createSwsContext(
+         swsFrameContext,
+         avFrame->colorspace,
+
+         // See [Transform and Format Conversion Order] for more on the output
+         // pixel format.
+         /*outputFormat=*/AV_PIX_FMT_RGB24,
+
+         // No flags for color conversion. When resizing is needed, we use a
+         // separate swscale context with the appropriate resize flags.
+         /*swsFlags=*/0);
+     prevSwsFrameContext_ = swsFrameContext;
+   }
+
+   // When resizing is needed, we do sws_scale twice: first convert to RGB24 at
+   // the original resolution, then resize in RGB24 space. This ensures
+   // transforms happen in the output color space (RGB24) rather than the input
+   // color space (YUV).
+   //
+   // When no resize is needed, we do color conversion directly into the output
+   // tensor.
+   torch::Tensor colorConvertedTensor = needsResize
+       ? allocateEmptyHWCTensor(
+             FrameDims(avFrame->height, avFrame->width), torch::kCPU)
+       : outputTensor;
+
+   uint8_t* colorConvertedPointers[4] = {
+       colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+   int colorConvertedWidth = static_cast<int>(colorConvertedTensor.sizes()[1]);
+   int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
+
+   int colorConvertedHeight = sws_scale(
+       swsContext_.get(),
+       avFrame->data,
+       avFrame->linesize,
+       0,
+       avFrame->height,
+       colorConvertedPointers,
+       colorConvertedLinesizes);
+
+   TORCH_CHECK(
+       colorConvertedHeight == avFrame->height,
+       "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
+       colorConvertedHeight,
+       " != ",
+       avFrame->height);
+
+   if (needsResize) {
+     // Use a cached swscale context for resizing, similar to the color
+     // conversion context caching above.
+     SwsFrameContext resizeSwsFrameContext(
+         avFrame->width,
+         avFrame->height,
+         AV_PIX_FMT_RGB24,
+         outputDims.width,
+         outputDims.height);
+
+     if (!resizeSwsContext_ ||
+         prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
+       resizeSwsContext_ = createSwsContext(
+           resizeSwsFrameContext,
+           AVCOL_SPC_RGB,
+           /*outputFormat=*/AV_PIX_FMT_RGB24,
+           /*swsFlags=*/swsFlags_);
+       prevResizeSwsFrameContext_ = resizeSwsFrameContext;
+     }
+
+     uint8_t* srcPointers[4] = {
+         colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+     int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
+
+     uint8_t* dstPointers[4] = {
+         outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+     int expectedOutputWidth = static_cast<int>(outputTensor.sizes()[1]);
+     int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+
+     colorConvertedHeight = sws_scale(
+         resizeSwsContext_.get(),
+         srcPointers,
+         srcLinesizes,
+         0,
+         avFrame->height,
+         dstPointers,
+         dstLinesizes);
+   }
+
+   return colorConvertedHeight;
+ }
+
+ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
+     const UniqueAVFrame& avFrame,
+     const FrameDims& outputDims) {
+   enum AVPixelFormat avFrameFormat =
+       static_cast<enum AVPixelFormat>(avFrame->format);
+
+   FiltersContext filtersContext(
+       avFrame->width,
+       avFrame->height,
+       avFrameFormat,
+       avFrame->sample_aspect_ratio,
+       outputDims.width,
+       outputDims.height,
+       /*outputFormat=*/AV_PIX_FMT_RGB24,
+       filters_,
+       timeBase_);
+
+   if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
+     filterGraph_ =
+         std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
+     prevFiltersContext_ = std::move(filtersContext);
+   }
+   return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
+ }
+
+ void CpuDeviceInterface::convertAudioAVFrameToFrameOutput(
+     UniqueAVFrame& srcAVFrame,
+     FrameOutput& frameOutput) {
+   AVSampleFormat srcSampleFormat =
+       static_cast<AVSampleFormat>(srcAVFrame->format);
+   AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP;
+
+   int srcSampleRate = srcAVFrame->sample_rate;
+   int outSampleRate = audioStreamOptions_.sampleRate.value_or(srcSampleRate);
+
+   int srcNumChannels = getNumChannels(codecContext_);
+   TORCH_CHECK(
+       srcNumChannels == getNumChannels(srcAVFrame),
+       "The frame has ",
+       getNumChannels(srcAVFrame),
+       " channels, expected ",
+       srcNumChannels,
+       ". If you are hitting this, it may be because you are using "
+       "a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
+       "valid scenarios. Try to upgrade FFmpeg?");
+   int outNumChannels = audioStreamOptions_.numChannels.value_or(srcNumChannels);
+
+   bool mustConvert =
+       (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate ||
+        srcNumChannels != outNumChannels);
+
+   UniqueAVFrame convertedAVFrame;
+   if (mustConvert) {
+     if (!swrContext_) {
+       swrContext_.reset(createSwrContext(
+           srcSampleFormat,
+           outSampleFormat,
+           srcSampleRate,
+           outSampleRate,
+           srcAVFrame,
+           outNumChannels));
+     }
+
+     convertedAVFrame = convertAudioAVFrameSamples(
+         swrContext_,
+         srcAVFrame,
+         outSampleFormat,
+         outSampleRate,
+         outNumChannels);
+   }
+   const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
+
+   AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
+   TORCH_CHECK(
+       format == outSampleFormat,
+       "Something went wrong, the frame didn't get converted to the desired format. ",
+       "Desired format = ",
+       av_get_sample_fmt_name(outSampleFormat),
+       ", source format = ",
+       av_get_sample_fmt_name(format));
+
+   int numChannels = getNumChannels(avFrame);
+   TORCH_CHECK(
+       numChannels == outNumChannels,
+       "Something went wrong, the frame didn't get converted to the desired ",
+       "number of channels = ",
+       outNumChannels,
+       ". Got ",
+       numChannels,
+       " instead.");
+
+   auto numSamples = avFrame->nb_samples;
+
+   frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32);
+
+   if (numSamples > 0) {
+     uint8_t* outputChannelData =
+         static_cast<uint8_t*>(frameOutput.data.data_ptr());
+     auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
+     for (auto channel = 0; channel < numChannels;
+          ++channel, outputChannelData += numBytesPerChannel) {
+       std::memcpy(
+           outputChannelData,
+           avFrame->extended_data[channel],
+           numBytesPerChannel);
+     }
+   }
+ }
+
+ std::optional<torch::Tensor> CpuDeviceInterface::maybeFlushAudioBuffers() {
+   // When sample rate conversion is involved, swresample buffers some of the
+   // samples in-between calls to swr_convert (see the libswresample docs).
+   // That's because the last few samples in a given frame require future
+   // samples from the next frame to be properly converted. This function
+   // flushes out the samples that are stored in swresample's buffers.
+   if (!swrContext_) {
+     return std::nullopt;
+   }
+   auto numRemainingSamples = // this is an upper bound
+       swr_get_out_samples(swrContext_.get(), 0);
+
+   if (numRemainingSamples == 0) {
+     return std::nullopt;
+   }
+
+   int numChannels =
+       audioStreamOptions_.numChannels.value_or(getNumChannels(codecContext_));
+   torch::Tensor lastSamples =
+       torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);
+
+   std::vector<uint8_t*> outputBuffers(numChannels);
+   for (auto i = 0; i < numChannels; i++) {
+     outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr());
+   }
+
+   auto actualNumRemainingSamples = swr_convert(
+       swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0);
+
+   return lastSamples.narrow(
+       /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
+ }
+
+ std::string CpuDeviceInterface::getDetails() {
+   return std::string("CPU Device Interface.");
+ }
+
+ } // namespace facebook::torchcodec
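
For reference, the filter-string assembly in initializeVideo() above amounts to joining each transform's filtergraph expression with commas and prepending an explicit format=rgb24 conversion, so that user filters run in RGB24 rather than the decoded frame's pixel format (see Note [Transform and Format Conversion Order]). Below is a minimal standalone sketch of that logic; the transform strings are made-up placeholders, not actual ResizeTransform output.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  // Hypothetical filtergraph expressions produced by individual transforms.
  std::vector<std::string> transformFilters = {"scale=640:360", "crop=320:240"};

  // Join the per-transform filter strings with commas, as initializeVideo()
  // does with transform->getFilterGraphCpu().
  std::stringstream filters;
  bool first = true;
  for (const auto& f : transformFilters) {
    if (!first) {
      filters << ",";
    }
    filters << f;
    first = false;
  }

  // Prepend the explicit format conversion so the user filters operate on
  // RGB24 frames; with no transforms the decoder keeps the default "copy".
  std::string filterSpec =
      transformFilters.empty() ? "copy" : "format=rgb24," + filters.str();

  std::cout << filterSpec << "\n";
  // Prints: format=rgb24,scale=640:360,crop=320:240
  return 0;
}
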
torchcodec/_core/CpuDeviceInterface.h
@@ -0,0 +1,141 @@
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
+ // All rights reserved.
+ //
+ // This source code is licensed under the BSD-style license found in the
+ // LICENSE file in the root directory of this source tree.
+
+ #pragma once
+
+ #include "DeviceInterface.h"
+ #include "FFMPEGCommon.h"
+ #include "FilterGraph.h"
+
+ namespace facebook::torchcodec {
+
+ class CpuDeviceInterface : public DeviceInterface {
+  public:
+   CpuDeviceInterface(const torch::Device& device);
+
+   virtual ~CpuDeviceInterface() {}
+
+   std::optional<const AVCodec*> findCodec(
+       [[maybe_unused]] const AVCodecID& codecId,
+       [[maybe_unused]] bool isDecoder = true) override {
+     return std::nullopt;
+   }
+
+   virtual void initialize(
+       const AVStream* avStream,
+       const UniqueDecodingAVFormatContext& avFormatCtx,
+       const SharedAVCodecContext& codecContext) override;
+
+   virtual void initializeVideo(
+       const VideoStreamOptions& videoStreamOptions,
+       const std::vector<std::unique_ptr<Transform>>& transforms,
+       const std::optional<FrameDims>& resizedOutputDims) override;
+
+   virtual void initializeAudio(
+       const AudioStreamOptions& audioStreamOptions) override;
+
+   virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() override;
+
+   void convertAVFrameToFrameOutput(
+       UniqueAVFrame& avFrame,
+       FrameOutput& frameOutput,
+       std::optional<torch::Tensor> preAllocatedOutputTensor) override;
+
+   std::string getDetails() override;
+
+  private:
+   void convertAudioAVFrameToFrameOutput(
+       UniqueAVFrame& srcAVFrame,
+       FrameOutput& frameOutput);
+
+   void convertVideoAVFrameToFrameOutput(
+       UniqueAVFrame& avFrame,
+       FrameOutput& frameOutput,
+       std::optional<torch::Tensor> preAllocatedOutputTensor);
+
+   int convertAVFrameToTensorUsingSwScale(
+       const UniqueAVFrame& avFrame,
+       torch::Tensor& outputTensor,
+       const FrameDims& outputDims);
+
+   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
+       const UniqueAVFrame& avFrame,
+       const FrameDims& outputDims);
+
+   ColorConversionLibrary getColorConversionLibrary(
+       const FrameDims& inputFrameDims) const;
+
+   VideoStreamOptions videoStreamOptions_;
+   AVRational timeBase_;
+
+   // If the resized output dimensions are present, then we always use those as
+   // the output frame's dimensions. If they are not present, then we use the
+   // dimensions of the raw decoded frame. Note that we do not know the
+   // dimensions of the raw decoded frame until very late; we learn them in
+   // convertAVFrameToFrameOutput(). Deciding the final output frame's actual
+   // dimensions late allows us to handle video streams with variable
+   // resolutions.
+   std::optional<FrameDims> resizedOutputDims_;
+
+   // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
+   // be non-null. Which one we use is determined dynamically in
+   // getColorConversionLibrary() each time we decode a frame.
+   //
+   // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
+   // reuse them across frames. However, it is possible that subsequent frames
+   // are different enough (change in dimensions) that we can't reuse the color
+   // conversion object. We store the relevant frame context from the frame used
+   // to create the object last time. We always compare the current frame's info
+   // against the previous one to determine if we need to recreate the color
+   // conversion object.
+   //
+   // TODO: The names of these fields are confusing, as the actual color
+   //       conversion object for Sws has "context" in the name, and we use
+   //       "context" for the structs we store to know if we need to recreate a
+   //       color conversion object. We should clean that up.
+   std::unique_ptr<FilterGraph> filterGraph_;
+   FiltersContext prevFiltersContext_;
+   UniqueSwsContext swsContext_;
+   SwsFrameContext prevSwsFrameContext_;
+
+   // Cached swscale context for resizing in RGB24 space (used in the double
+   // swscale path). Like the color conversion context above, we cache this to
+   // avoid recreating it for every frame.
+   UniqueSwsContext resizeSwsContext_;
+   SwsFrameContext prevResizeSwsFrameContext_;
+
+   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
+   // of what FFmpeg calls "filters" to apply to decoded frames before returning
+   // them. In the PyTorch ecosystem, we call these "transforms". During
+   // initialization, we convert the user-supplied transforms into this string
+   // of filters.
+   //
+   // Note that if there are no user-supplied transforms, then the default
+   // filter we use is the copy filter, which is just an identity: it emits the
+   // output frame unchanged. We supply such a filter because we can't supply
+   // just the empty string; we must supply SOME filter.
+   //
+   // See also [Transform and Format Conversion Order] for more on filters.
+   std::string filters_ = "copy";
+
+   // Values set during initialization and referred to in
+   // getColorConversionLibrary().
+   bool areTransformsSwScaleCompatible_;
+   bool userRequestedSwScale_;
+
+   // The flags we supply to the resize swscale context. The flags control the
+   // resizing algorithm. We default to bilinear. Users can override this with a
+   // ResizeTransform that specifies a different interpolation mode.
+   int swsFlags_ = SWS_BILINEAR;
+
+   bool initialized_ = false;
+
+   // Audio-specific members
+   AudioStreamOptions audioStreamOptions_;
+   UniqueSwrContext swrContext_;
+ };
+
+ } // namespace facebook::torchcodec
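
The caching strategy the header comments describe for swsContext_ and filterGraph_ (keep the last conversion object and rebuild it only when the incoming frame's properties differ from the previous frame's) can be sketched in isolation as follows; FrameContext and Converter here are illustrative stand-ins, not torchcodec types.

#include <iostream>
#include <memory>

// Stand-in for SwsFrameContext / FiltersContext: the frame properties that
// determine whether an existing conversion object can be reused.
struct FrameContext {
  int width = 0;
  int height = 0;
  int pixelFormat = 0;

  bool operator!=(const FrameContext& other) const {
    return width != other.width || height != other.height ||
        pixelFormat != other.pixelFormat;
  }
};

// Stand-in for an expensive-to-create conversion object (SwsContext,
// FilterGraph, ...).
struct Converter {
  explicit Converter(const FrameContext& ctx) {
    std::cout << "rebuilding converter for " << ctx.width << "x" << ctx.height
              << "\n";
  }
};

class CachedConverter {
 public:
  // Returns a converter valid for the given frame, rebuilding it only when
  // the frame's properties differ from those of the previous frame.
  Converter& getFor(const FrameContext& ctx) {
    if (!converter_ || prevContext_ != ctx) {
      converter_ = std::make_unique<Converter>(ctx);
      prevContext_ = ctx;
    }
    return *converter_;
  }

 private:
  std::unique_ptr<Converter> converter_;
  FrameContext prevContext_;
};

int main() {
  CachedConverter cache;
  cache.getFor({1920, 1080, 0}); // rebuilds
  cache.getFor({1920, 1080, 0}); // reuses the cached converter
  cache.getFor({1280, 720, 0});  // resolution changed mid-stream: rebuilds
  return 0;
}
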