torchcodec 0.7.0__cp312-cp312-win_amd64.whl → 0.8.1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchcodec might be problematic. Click here for more details.

Files changed (66)
  1. torchcodec/_core/AVIOTensorContext.cpp +23 -16
  2. torchcodec/_core/AVIOTensorContext.h +2 -1
  3. torchcodec/_core/BetaCudaDeviceInterface.cpp +718 -0
  4. torchcodec/_core/BetaCudaDeviceInterface.h +193 -0
  5. torchcodec/_core/CMakeLists.txt +18 -3
  6. torchcodec/_core/CUDACommon.cpp +330 -0
  7. torchcodec/_core/CUDACommon.h +51 -0
  8. torchcodec/_core/Cache.h +6 -20
  9. torchcodec/_core/CpuDeviceInterface.cpp +195 -108
  10. torchcodec/_core/CpuDeviceInterface.h +84 -19
  11. torchcodec/_core/CudaDeviceInterface.cpp +227 -376
  12. torchcodec/_core/CudaDeviceInterface.h +38 -6
  13. torchcodec/_core/DeviceInterface.cpp +57 -19
  14. torchcodec/_core/DeviceInterface.h +97 -16
  15. torchcodec/_core/Encoder.cpp +346 -9
  16. torchcodec/_core/Encoder.h +62 -1
  17. torchcodec/_core/FFMPEGCommon.cpp +190 -3
  18. torchcodec/_core/FFMPEGCommon.h +27 -1
  19. torchcodec/_core/FilterGraph.cpp +30 -22
  20. torchcodec/_core/FilterGraph.h +15 -1
  21. torchcodec/_core/Frame.cpp +22 -7
  22. torchcodec/_core/Frame.h +15 -61
  23. torchcodec/_core/Metadata.h +2 -2
  24. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  25. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  26. torchcodec/_core/NVDECCache.cpp +60 -0
  27. torchcodec/_core/NVDECCache.h +102 -0
  28. torchcodec/_core/SingleStreamDecoder.cpp +196 -201
  29. torchcodec/_core/SingleStreamDecoder.h +42 -15
  30. torchcodec/_core/StreamOptions.h +16 -6
  31. torchcodec/_core/Transform.cpp +87 -0
  32. torchcodec/_core/Transform.h +84 -0
  33. torchcodec/_core/__init__.py +4 -0
  34. torchcodec/_core/custom_ops.cpp +257 -32
  35. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
  36. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  37. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  38. torchcodec/_core/ops.py +147 -44
  39. torchcodec/_core/pybind_ops.cpp +22 -59
  40. torchcodec/_samplers/video_clip_sampler.py +7 -19
  41. torchcodec/decoders/__init__.py +1 -0
  42. torchcodec/decoders/_decoder_utils.py +61 -1
  43. torchcodec/decoders/_video_decoder.py +46 -20
  44. torchcodec/libtorchcodec_core4.dll +0 -0
  45. torchcodec/libtorchcodec_core5.dll +0 -0
  46. torchcodec/libtorchcodec_core6.dll +0 -0
  47. torchcodec/libtorchcodec_core7.dll +0 -0
  48. torchcodec/libtorchcodec_core8.dll +0 -0
  49. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  50. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  51. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  52. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  53. torchcodec/libtorchcodec_custom_ops8.dll +0 -0
  54. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  55. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  56. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  57. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  58. torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
  59. torchcodec/samplers/_time_based.py +8 -0
  60. torchcodec/version.py +1 -1
  61. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/METADATA +29 -16
  62. torchcodec-0.8.1.dist-info/RECORD +82 -0
  63. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/WHEEL +1 -1
  64. torchcodec-0.7.0.dist-info/RECORD +0 -67
  65. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/licenses/LICENSE +0 -0
  66. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/top_level.txt +0 -0
@@ -10,11 +10,23 @@ namespace facebook::torchcodec {
10
10
  namespace {
11
11
 
12
12
  static bool g_cpu = registerDeviceInterface(
13
- torch::kCPU,
13
+ DeviceInterfaceKey(torch::kCPU),
14
14
  [](const torch::Device& device) { return new CpuDeviceInterface(device); });
15
15
 
16
16
  } // namespace
17
17
 
18
+ CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
19
+ int inputWidth,
20
+ int inputHeight,
21
+ AVPixelFormat inputFormat,
22
+ int outputWidth,
23
+ int outputHeight)
24
+ : inputWidth(inputWidth),
25
+ inputHeight(inputHeight),
26
+ inputFormat(inputFormat),
27
+ outputWidth(outputWidth),
28
+ outputHeight(outputHeight) {}
29
+
18
30
  bool CpuDeviceInterface::SwsFrameContext::operator==(
19
31
  const CpuDeviceInterface::SwsFrameContext& other) const {
20
32
  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
@@ -34,6 +46,98 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
34
46
  device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
35
47
  }
36
48
 
49
+ void CpuDeviceInterface::initialize(
50
+ const AVStream* avStream,
51
+ [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx,
52
+ const SharedAVCodecContext& codecContext) {
53
+ TORCH_CHECK(avStream != nullptr, "avStream is null");
54
+ codecContext_ = codecContext;
55
+ timeBase_ = avStream->time_base;
56
+ }
57
+
58
+ void CpuDeviceInterface::initializeVideo(
59
+ const VideoStreamOptions& videoStreamOptions,
60
+ const std::vector<std::unique_ptr<Transform>>& transforms,
61
+ const std::optional<FrameDims>& resizedOutputDims) {
62
+ videoStreamOptions_ = videoStreamOptions;
63
+ resizedOutputDims_ = resizedOutputDims;
64
+
65
+ // We can only use swscale when we have a single resize transform. Note that
66
+ // this means swscale will not support the case of having several,
67
+ // back-to-back resizes. There's no strong reason to even do that, but if
68
+ // someone does, it's more correct to implement that with filtergraph.
69
+ //
70
+ // We calculate this value during initialization but we don't refer to it until
71
+ // getColorConversionLibrary() is called. Calculating this value during
72
+ // initialization saves us from having to save all of the transforms.
73
+ areTransformsSwScaleCompatible_ = transforms.empty() ||
74
+ (transforms.size() == 1 && transforms[0]->isResize());
75
+
76
+ // Note that we do not expose this capability in the public API, only through
77
+ // the core API.
78
+ //
79
+ // Same as above, we calculate this value during initialization and refer to
80
+ // it in getColorConversionLibrary().
81
+ userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
82
+ ColorConversionLibrary::SWSCALE;
83
+
84
+ // We can only use swscale when we have a single resize transform. Note that
85
+ // we actually decide on whether or not to actually use swscale at the last
86
+ // possible moment, when we actually convert the frame. This is because we
87
+ // need to know the actual frame dimensions.
88
+ if (transforms.size() == 1 && transforms[0]->isResize()) {
89
+ auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
90
+ TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
91
+ swsFlags_ = resize->getSwsFlags();
92
+ }
93
+
94
+ // If we have any transforms, replace filters_ with the filter strings from
95
+ // the transforms. As noted above, we decide between swscale and filtergraph
96
+ // when we actually decode a frame.
97
+ std::stringstream filters;
98
+ bool first = true;
99
+ for (const auto& transform : transforms) {
100
+ if (!first) {
101
+ filters << ",";
102
+ }
103
+ filters << transform->getFilterGraphCpu();
104
+ first = false;
105
+ }
106
+ if (!transforms.empty()) {
107
+ filters_ = filters.str();
108
+ }
109
+
110
+ initialized_ = true;
111
+ }
112
+
113
+ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
114
+ const FrameDims& outputDims) const {
115
+ // swscale requires widths to be multiples of 32:
116
+ // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
117
+ bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
118
+
119
+ // We want to use swscale for color conversion if possible because it is
120
+ // faster than filtergraph. The following are the conditions we need to meet
121
+ // to use it.
122
+ //
123
+ // Note that we treat the transform limitation differently from the width
124
+ // limitation. That is, we consider the transforms being compatible with
125
+ // swscale as a hard requirement. If the transforms are not compatible,
126
+ // then we will end up not applying the transforms, and that is wrong.
127
+ //
128
+ // The width requirement, however, is a soft requirement. Even if we don't
129
+ // meet it, we let the user override it. We have tests that depend on this
130
+ // behavior. Since we don't expose the ability to choose swscale or
131
+ // filtergraph in our public API, this is probably okay. It's also the only
132
+ // way that we can be certain we are testing one versus the other.
133
+ if (areTransformsSwScaleCompatible_ &&
134
+ (userRequestedSwScale_ || isWidthSwScaleCompatible)) {
135
+ return ColorConversionLibrary::SWSCALE;
136
+ } else {
137
+ return ColorConversionLibrary::FILTERGRAPH;
138
+ }
139
+ }
140
+
37
141
  // Note [preAllocatedOutputTensor with swscale and filtergraph]:
38
142
  // Callers may pass a pre-allocated tensor, where the output.data tensor will
39
143
  // be stored. This parameter is honored in any case, but it only leads to a
@@ -44,124 +148,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
44
148
  // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
45
149
  // `dimension_order` parameter. It's up to callers to re-shape it if needed.
46
150
  void CpuDeviceInterface::convertAVFrameToFrameOutput(
47
- const VideoStreamOptions& videoStreamOptions,
48
- const AVRational& timeBase,
49
151
  UniqueAVFrame& avFrame,
50
152
  FrameOutput& frameOutput,
51
153
  std::optional<torch::Tensor> preAllocatedOutputTensor) {
52
- auto frameDims =
53
- getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
54
- int expectedOutputHeight = frameDims.height;
55
- int expectedOutputWidth = frameDims.width;
154
+ TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
155
+
156
+ // Note that we ignore the dimensions from the metadata; we don't even bother
157
+ // storing them. The resized dimensions take priority. If we don't have any,
158
+ // then we use the dimensions from the actual decoded frame. We use the actual
159
+ // decoded frame and not the metadata for two reasons:
160
+ //
161
+ // 1. Metadata may be wrong. If we have access to more accurate information, we
162
+ // should use it.
163
+ // 2. Video streams can have variable resolution. This fact is not captured
164
+ // in the stream metadata.
165
+ //
166
+ // Both cases cause problems for our batch APIs, as we allocate
167
+ // FrameBatchOutputs based on the stream metadata. But single-frame APIs
168
+ // can still work in such situations, so they should.
169
+ auto outputDims =
170
+ resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width));
56
171
 
57
172
  if (preAllocatedOutputTensor.has_value()) {
58
173
  auto shape = preAllocatedOutputTensor.value().sizes();
59
174
  TORCH_CHECK(
60
- (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
61
- (shape[1] == expectedOutputWidth) && (shape[2] == 3),
175
+ (shape.size() == 3) && (shape[0] == outputDims.height) &&
176
+ (shape[1] == outputDims.width) && (shape[2] == 3),
62
177
  "Expected pre-allocated tensor of shape ",
63
- expectedOutputHeight,
178
+ outputDims.height,
64
179
  "x",
65
- expectedOutputWidth,
180
+ outputDims.width,
66
181
  "x3, got ",
67
182
  shape);
68
183
  }
69
184
 
185
+ auto colorConversionLibrary = getColorConversionLibrary(outputDims);
70
186
  torch::Tensor outputTensor;
71
- enum AVPixelFormat frameFormat =
72
- static_cast<enum AVPixelFormat>(avFrame->format);
73
-
74
- // By default, we want to use swscale for color conversion because it is
75
- // faster. However, it has width requirements, so we may need to fall back
76
- // to filtergraph. We also need to respect what was requested from the
77
- // options; we respect the options unconditionally, so it's possible for
78
- // swscale's width requirements to be violated. We don't expose the ability to
79
- // choose color conversion library publicly; we only use this ability
80
- // internally.
81
-
82
- // swscale requires widths to be multiples of 32:
83
- // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
84
- // so we fall back to filtergraph if the width is not a multiple of 32.
85
- auto defaultLibrary = (expectedOutputWidth % 32 == 0)
86
- ? ColorConversionLibrary::SWSCALE
87
- : ColorConversionLibrary::FILTERGRAPH;
88
-
89
- ColorConversionLibrary colorConversionLibrary =
90
- videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
91
187
 
92
188
  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
93
- // We need to compare the current frame context with our previous frame
94
- // context. If they are different, then we need to re-create our colorspace
95
- // conversion objects. We create our colorspace conversion objects late so
96
- // that we don't have to depend on the unreliable metadata in the header.
97
- // And we sometimes re-create them because it's possible for frame
98
- // resolution to change mid-stream. Finally, we want to reuse the colorspace
99
- // conversion objects as much as possible for performance reasons.
100
- SwsFrameContext swsFrameContext;
101
-
102
- swsFrameContext.inputWidth = avFrame->width;
103
- swsFrameContext.inputHeight = avFrame->height;
104
- swsFrameContext.inputFormat = frameFormat;
105
- swsFrameContext.outputWidth = expectedOutputWidth;
106
- swsFrameContext.outputHeight = expectedOutputHeight;
107
-
108
- outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
109
- expectedOutputHeight, expectedOutputWidth, torch::kCPU));
110
-
111
- if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
112
- createSwsContext(swsFrameContext, avFrame->colorspace);
113
- prevSwsFrameContext_ = swsFrameContext;
114
- }
189
+ outputTensor = preAllocatedOutputTensor.value_or(
190
+ allocateEmptyHWCTensor(outputDims, torch::kCPU));
191
+
115
192
  int resultHeight =
116
- convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
193
+ convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
194
+
117
195
  // If this check failed, it would mean that the frame wasn't reshaped to
118
196
  // the expected height.
119
197
  // TODO: Can we do the same check for width?
120
198
  TORCH_CHECK(
121
- resultHeight == expectedOutputHeight,
122
- "resultHeight != expectedOutputHeight: ",
199
+ resultHeight == outputDims.height,
200
+ "resultHeight != outputDims.height: ",
123
201
  resultHeight,
124
202
  " != ",
125
- expectedOutputHeight);
203
+ outputDims.height);
126
204
 
127
205
  frameOutput.data = outputTensor;
128
206
  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
129
- // See comment above in swscale branch about the filterGraphContext_
130
- // creation. creation
131
- FiltersContext filtersContext;
132
-
133
- filtersContext.inputWidth = avFrame->width;
134
- filtersContext.inputHeight = avFrame->height;
135
- filtersContext.inputFormat = frameFormat;
136
- filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
137
- filtersContext.outputWidth = expectedOutputWidth;
138
- filtersContext.outputHeight = expectedOutputHeight;
139
- filtersContext.outputFormat = AV_PIX_FMT_RGB24;
140
- filtersContext.timeBase = timeBase;
141
-
142
- std::stringstream filters;
143
- filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
144
- filters << ":sws_flags=bilinear";
145
-
146
- filtersContext.filtergraphStr = filters.str();
147
-
148
- if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
149
- filterGraphContext_ =
150
- std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
151
- prevFiltersContext_ = std::move(filtersContext);
152
- }
153
- outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
207
+ outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims);
154
208
 
155
209
  // Similarly to above, if this check fails it means the frame wasn't
156
210
  // reshaped to its expected dimensions by filtergraph.
157
211
  auto shape = outputTensor.sizes();
158
212
  TORCH_CHECK(
159
- (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
160
- (shape[1] == expectedOutputWidth) && (shape[2] == 3),
213
+ (shape.size() == 3) && (shape[0] == outputDims.height) &&
214
+ (shape[1] == outputDims.width) && (shape[2] == 3),
161
215
  "Expected output tensor of shape ",
162
- expectedOutputHeight,
216
+ outputDims.height,
163
217
  "x",
164
- expectedOutputWidth,
218
+ outputDims.width,
165
219
  "x3, got ",
166
220
  shape);
167
221
 
@@ -181,9 +235,32 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
181
235
  }
182
236
  }
183
237
 
184
- int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
238
+ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
185
239
  const UniqueAVFrame& avFrame,
186
- torch::Tensor& outputTensor) {
240
+ torch::Tensor& outputTensor,
241
+ const FrameDims& outputDims) {
242
+ enum AVPixelFormat frameFormat =
243
+ static_cast<enum AVPixelFormat>(avFrame->format);
244
+
245
+ // We need to compare the current frame context with our previous frame
246
+ // context. If they are different, then we need to re-create our colorspace
247
+ // conversion objects. We create our colorspace conversion objects late so
248
+ // that we don't have to depend on the unreliable metadata in the header.
249
+ // And we sometimes re-create them because it's possible for frame
250
+ // resolution to change mid-stream. Finally, we want to reuse the colorspace
251
+ // conversion objects as much as possible for performance reasons.
252
+ SwsFrameContext swsFrameContext(
253
+ avFrame->width,
254
+ avFrame->height,
255
+ frameFormat,
256
+ outputDims.width,
257
+ outputDims.height);
258
+
259
+ if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
260
+ createSwsContext(swsFrameContext, avFrame->colorspace);
261
+ prevSwsFrameContext_ = swsFrameContext;
262
+ }
263
+
187
264
  uint8_t* pointers[4] = {
188
265
  outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
189
266
  int expectedOutputWidth = outputTensor.sizes()[1];
@@ -199,25 +276,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
199
276
  return resultHeight;
200
277
  }
201
278
 
202
- torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
203
- const UniqueAVFrame& avFrame) {
204
- UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
205
-
206
- TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
207
-
208
- auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
209
- int height = frameDims.height;
210
- int width = frameDims.width;
211
- std::vector<int64_t> shape = {height, width, 3};
212
- std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
213
- AVFrame* filteredAVFramePtr = filteredAVFrame.release();
214
- auto deleter = [filteredAVFramePtr](void*) {
215
- UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
216
- };
217
- return torch::from_blob(
218
- filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
219
- }
220
-
221
279
  void CpuDeviceInterface::createSwsContext(
222
280
  const SwsFrameContext& swsFrameContext,
223
281
  const enum AVColorSpace colorspace) {
@@ -228,7 +286,7 @@ void CpuDeviceInterface::createSwsContext(
228
286
  swsFrameContext.outputWidth,
229
287
  swsFrameContext.outputHeight,
230
288
  AV_PIX_FMT_RGB24,
231
- SWS_BILINEAR,
289
+ swsFlags_,
232
290
  nullptr,
233
291
  nullptr,
234
292
  nullptr);
@@ -263,4 +321,33 @@ void CpuDeviceInterface::createSwsContext(
263
321
  swsContext_.reset(swsContext);
264
322
  }
265
323
 
324
+ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
325
+ const UniqueAVFrame& avFrame,
326
+ const FrameDims& outputDims) {
327
+ enum AVPixelFormat frameFormat =
328
+ static_cast<enum AVPixelFormat>(avFrame->format);
329
+
330
+ FiltersContext filtersContext(
331
+ avFrame->width,
332
+ avFrame->height,
333
+ frameFormat,
334
+ avFrame->sample_aspect_ratio,
335
+ outputDims.width,
336
+ outputDims.height,
337
+ AV_PIX_FMT_RGB24,
338
+ filters_,
339
+ timeBase_);
340
+
341
+ if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
342
+ filterGraph_ =
343
+ std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
344
+ prevFiltersContext_ = std::move(filtersContext);
345
+ }
346
+ return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
347
+ }
348
+
349
+ std::string CpuDeviceInterface::getDetails() {
350
+ return std::string("CPU Device Interface.");
351
+ }
352
+
266
353
  } // namespace facebook::torchcodec
@@ -23,31 +23,51 @@ class CpuDeviceInterface : public DeviceInterface {
23
23
  return std::nullopt;
24
24
  }
25
25
 
26
- void initializeContext(
27
- [[maybe_unused]] AVCodecContext* codecContext) override {}
26
+ virtual void initialize(
27
+ const AVStream* avStream,
28
+ const UniqueDecodingAVFormatContext& avFormatCtx,
29
+ const SharedAVCodecContext& codecContext) override;
28
30
 
29
- void convertAVFrameToFrameOutput(
31
+ virtual void initializeVideo(
30
32
  const VideoStreamOptions& videoStreamOptions,
31
- const AVRational& timeBase,
33
+ const std::vector<std::unique_ptr<Transform>>& transforms,
34
+ const std::optional<FrameDims>& resizedOutputDims) override;
35
+
36
+ void convertAVFrameToFrameOutput(
32
37
  UniqueAVFrame& avFrame,
33
38
  FrameOutput& frameOutput,
34
39
  std::optional<torch::Tensor> preAllocatedOutputTensor =
35
40
  std::nullopt) override;
36
41
 
42
+ std::string getDetails() override;
43
+
37
44
  private:
38
- int convertAVFrameToTensorUsingSwsScale(
45
+ int convertAVFrameToTensorUsingSwScale(
39
46
  const UniqueAVFrame& avFrame,
40
- torch::Tensor& outputTensor);
47
+ torch::Tensor& outputTensor,
48
+ const FrameDims& outputDims);
41
49
 
42
50
  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
43
- const UniqueAVFrame& avFrame);
51
+ const UniqueAVFrame& avFrame,
52
+ const FrameDims& outputDims);
53
+
54
+ ColorConversionLibrary getColorConversionLibrary(
55
+ const FrameDims& inputFrameDims) const;
44
56
 
45
57
  struct SwsFrameContext {
46
- int inputWidth;
47
- int inputHeight;
48
- AVPixelFormat inputFormat;
49
- int outputWidth;
50
- int outputHeight;
58
+ int inputWidth = 0;
59
+ int inputHeight = 0;
60
+ AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
61
+ int outputWidth = 0;
62
+ int outputHeight = 0;
63
+
64
+ SwsFrameContext() = default;
65
+ SwsFrameContext(
66
+ int inputWidth,
67
+ int inputHeight,
68
+ AVPixelFormat inputFormat,
69
+ int outputWidth,
70
+ int outputHeight);
51
71
  bool operator==(const SwsFrameContext&) const;
52
72
  bool operator!=(const SwsFrameContext&) const;
53
73
  };
@@ -56,15 +76,60 @@ class CpuDeviceInterface : public DeviceInterface {
56
76
  const SwsFrameContext& swsFrameContext,
57
77
  const enum AVColorSpace colorspace);
58
78
 
59
- // color-conversion fields. Only one of FilterGraphContext and
60
- // UniqueSwsContext should be non-null.
61
- std::unique_ptr<FilterGraph> filterGraphContext_;
62
- UniqueSwsContext swsContext_;
79
+ VideoStreamOptions videoStreamOptions_;
80
+ AVRational timeBase_;
63
81
 
64
- // Used to know whether a new FilterGraphContext or UniqueSwsContext should
65
- // be created before decoding a new frame.
66
- SwsFrameContext prevSwsFrameContext_;
82
+ // If the resized output dimensions are present, then we always use those as
83
+ // the output frame's dimensions. If they are not present, then we use the
84
+ // dimensions of the raw decoded frame. Note that we do not know the
85
+ // dimensions of the raw decoded frame until very late; we learn it in
86
+ // convertAVFrameToFrameOutput(). Deciding the final output frame's actual
87
+ // dimensions late allows us to handle video streams with variable
88
+ // resolutions.
89
+ std::optional<FrameDims> resizedOutputDims_;
90
+
91
+ // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
92
+ // be non-null. Which one we use is determined dynamically in
93
+ // getColorConversionLibrary() each time we decode a frame.
94
+ //
95
+ // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
96
+ // reuse them across frames. However, it is possible that subsequent frames
97
+ // are different enough (change in dimensions) that we can't reuse the color
98
+ // conversion object. We store the relevant frame context from the frame used
99
+ // to create the object last time. We always compare the current frame's info
100
+ // against the previous one to determine if we need to recreate the color
101
+ // conversion object.
102
+ //
103
+ // TODO: The names of these fields are confusing, as the actual color
104
+ // conversion object for Sws has "context" in the name, and we use
105
+ // "context" for the structs we store to know if we need to recreate a
106
+ // color conversion object. We should clean that up.
107
+ std::unique_ptr<FilterGraph> filterGraph_;
67
108
  FiltersContext prevFiltersContext_;
109
+ UniqueSwsContext swsContext_;
110
+ SwsFrameContext prevSwsFrameContext_;
111
+
112
+ // The filter we supply to filterGraph_, if it is used. The default is the
113
+ // copy filter, which just copies the input to the output. Computationally, it
114
+ // should be a no-op. If we get no user-provided transforms, we will use the
115
+ // copy filter. Otherwise, we will construct the string from the transforms.
116
+ //
117
+ // Note that even if we only use the copy filter, we still get the desired
118
+ // colorspace conversion. We construct the filtergraph with its output sink
119
+ // set to RGB24.
120
+ std::string filters_ = "copy";
121
+
122
+ // The flags we supply to swsContext_, if it is used. The flags control the
123
+ // resizing algorithm. We default to bilinear. Users can override this with a
124
+ // ResizeTransform.
125
+ int swsFlags_ = SWS_BILINEAR;
126
+
127
+ // Values set during initialization and referred to in
128
+ // getColorConversionLibrary().
129
+ bool areTransformsSwScaleCompatible_;
130
+ bool userRequestedSwScale_;
131
+
132
+ bool initialized_ = false;
68
133
  };
69
134
 
70
135
  } // namespace facebook::torchcodec