torchcodec 0.7.0__cp310-cp310-win_amd64.whl → 0.8.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchcodec might be problematic.

Files changed (61)
  1. torchcodec/_core/BetaCudaDeviceInterface.cpp +636 -0
  2. torchcodec/_core/BetaCudaDeviceInterface.h +191 -0
  3. torchcodec/_core/CMakeLists.txt +36 -3
  4. torchcodec/_core/CUDACommon.cpp +315 -0
  5. torchcodec/_core/CUDACommon.h +46 -0
  6. torchcodec/_core/CpuDeviceInterface.cpp +189 -108
  7. torchcodec/_core/CpuDeviceInterface.h +81 -19
  8. torchcodec/_core/CudaDeviceInterface.cpp +211 -368
  9. torchcodec/_core/CudaDeviceInterface.h +33 -6
  10. torchcodec/_core/DeviceInterface.cpp +57 -19
  11. torchcodec/_core/DeviceInterface.h +97 -16
  12. torchcodec/_core/Encoder.cpp +302 -9
  13. torchcodec/_core/Encoder.h +51 -1
  14. torchcodec/_core/FFMPEGCommon.cpp +189 -2
  15. torchcodec/_core/FFMPEGCommon.h +18 -0
  16. torchcodec/_core/FilterGraph.cpp +28 -21
  17. torchcodec/_core/FilterGraph.h +15 -1
  18. torchcodec/_core/Frame.cpp +17 -7
  19. torchcodec/_core/Frame.h +15 -61
  20. torchcodec/_core/Metadata.h +2 -2
  21. torchcodec/_core/NVDECCache.cpp +70 -0
  22. torchcodec/_core/NVDECCache.h +104 -0
  23. torchcodec/_core/SingleStreamDecoder.cpp +202 -198
  24. torchcodec/_core/SingleStreamDecoder.h +39 -14
  25. torchcodec/_core/StreamOptions.h +16 -6
  26. torchcodec/_core/Transform.cpp +60 -0
  27. torchcodec/_core/Transform.h +59 -0
  28. torchcodec/_core/__init__.py +1 -0
  29. torchcodec/_core/custom_ops.cpp +180 -32
  30. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
  31. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  32. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  33. torchcodec/_core/ops.py +86 -43
  34. torchcodec/_core/pybind_ops.cpp +22 -59
  35. torchcodec/_samplers/video_clip_sampler.py +7 -19
  36. torchcodec/decoders/__init__.py +1 -0
  37. torchcodec/decoders/_decoder_utils.py +61 -1
  38. torchcodec/decoders/_video_decoder.py +56 -20
  39. torchcodec/libtorchcodec_core4.dll +0 -0
  40. torchcodec/libtorchcodec_core5.dll +0 -0
  41. torchcodec/libtorchcodec_core6.dll +0 -0
  42. torchcodec/libtorchcodec_core7.dll +0 -0
  43. torchcodec/libtorchcodec_core8.dll +0 -0
  44. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  45. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  46. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  47. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  48. torchcodec/libtorchcodec_custom_ops8.dll +0 -0
  49. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  50. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  51. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  52. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  53. torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
  54. torchcodec/samplers/_time_based.py +8 -0
  55. torchcodec/version.py +1 -1
  56. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/METADATA +24 -13
  57. torchcodec-0.8.0.dist-info/RECORD +80 -0
  58. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/WHEEL +1 -1
  59. torchcodec-0.7.0.dist-info/RECORD +0 -67
  60. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/licenses/LICENSE +0 -0
  61. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/top_level.txt +0 -0
torchcodec/_core/CpuDeviceInterface.cpp

@@ -10,11 +10,23 @@ namespace facebook::torchcodec {
 namespace {
 
 static bool g_cpu = registerDeviceInterface(
-    torch::kCPU,
+    DeviceInterfaceKey(torch::kCPU),
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
 } // namespace
 
+CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
+    int inputWidth,
+    int inputHeight,
+    AVPixelFormat inputFormat,
+    int outputWidth,
+    int outputHeight)
+    : inputWidth(inputWidth),
+      inputHeight(inputHeight),
+      inputFormat(inputFormat),
+      outputWidth(outputWidth),
+      outputHeight(outputHeight) {}
+
 bool CpuDeviceInterface::SwsFrameContext::operator==(
     const CpuDeviceInterface::SwsFrameContext& other) const {
   return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
@@ -34,6 +46,96 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+void CpuDeviceInterface::initialize(
+    const AVStream* avStream,
+    [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {
+  TORCH_CHECK(avStream != nullptr, "avStream is null");
+  timeBase_ = avStream->time_base;
+}
+
+void CpuDeviceInterface::initializeVideo(
+    const VideoStreamOptions& videoStreamOptions,
+    const std::vector<std::unique_ptr<Transform>>& transforms,
+    const std::optional<FrameDims>& resizedOutputDims) {
+  videoStreamOptions_ = videoStreamOptions;
+  resizedOutputDims_ = resizedOutputDims;
+
+  // We can only use swscale when we have a single resize transform. Note that
+  // this means swscale will not support the case of having several,
+  // back-to-back resizes. There's no strong reason to even do that, but if
+  // someone does, it's more correct to implement that with filtergraph.
+  //
+  // We calculate this value during initialization but we don't refer to it until
+  // getColorConversionLibrary() is called. Calculating this value during
+  // initialization saves us from having to save all of the transforms.
+  areTransformsSwScaleCompatible_ = transforms.empty() ||
+      (transforms.size() == 1 && transforms[0]->isResize());
+
+  // Note that we do not expose this capability in the public API, only through
+  // the core API.
+  //
+  // Same as above, we calculate this value during initialization and refer to
+  // it in getColorConversionLibrary().
+  userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
+      ColorConversionLibrary::SWSCALE;
+
+  // We can only use swscale when we have a single resize transform. Note that
+  // we actually decide on whether or not to actually use swscale at the last
+  // possible moment, when we actually convert the frame. This is because we
+  // need to know the actual frame dimensions.
+  if (transforms.size() == 1 && transforms[0]->isResize()) {
+    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
+    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
+    swsFlags_ = resize->getSwsFlags();
+  }
+
+  // If we have any transforms, replace filters_ with the filter strings from
+  // the transforms. As noted above, we decide between swscale and filtergraph
+  // when we actually decode a frame.
+  std::stringstream filters;
+  bool first = true;
+  for (const auto& transform : transforms) {
+    if (!first) {
+      filters << ",";
+    }
+    filters << transform->getFilterGraphCpu();
+    first = false;
+  }
+  if (!transforms.empty()) {
+    filters_ = filters.str();
+  }
+
+  initialized_ = true;
+}
+
+ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
+    const FrameDims& outputDims) const {
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
+
+  // We want to use swscale for color conversion if possible because it is
+  // faster than filtergraph. The following are the conditions we need to meet
+  // to use it.
+  //
+  // Note that we treat the transform limitation differently from the width
+  // limitation. That is, we consider the transforms being compatible with
+  // swscale as a hard requirement. If the transforms are not compatible,
+  // then we will end up not applying the transforms, and that is wrong.
+  //
+  // The width requirement, however, is a soft requirement. Even if we don't
+  // meet it, we let the user override it. We have tests that depend on this
+  // behavior. Since we don't expose the ability to choose swscale or
+  // filtergraph in our public API, this is probably okay. It's also the only
+  // way that we can be certain we are testing one versus the other.
+  if (areTransformsSwScaleCompatible_ &&
+      (userRequestedSwScale_ || isWidthSwScaleCompatible)) {
+    return ColorConversionLibrary::SWSCALE;
+  } else {
+    return ColorConversionLibrary::FILTERGRAPH;
+  }
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
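
Editor's note: the loop in initializeVideo() above joins the per-transform filter strings with commas, which is how FFmpeg chains filters in a linear filtergraph. A minimal, self-contained sketch of that joining step follows; the filter strings used here are only illustrative stand-ins for whatever Transform::getFilterGraphCpu() actually returns.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Join individual filter descriptions into one filtergraph string; FFmpeg
// separates chained filters with commas, e.g. "scale=...,crop=...".
std::string joinFilters(const std::vector<std::string>& filters) {
  std::stringstream out;
  bool first = true;
  for (const auto& filter : filters) {
    if (!first) {
      out << ",";
    }
    out << filter;
    first = false;
  }
  return out.str();
}

int main() {
  // Hypothetical per-transform strings (the "scale=...:sws_flags=bilinear"
  // form mirrors the string the old code built by hand).
  std::vector<std::string> filters = {
      "scale=640:480:sws_flags=bilinear", "crop=320:240"};
  std::cout << joinFilters(filters) << "\n";
  // Prints: scale=640:480:sws_flags=bilinear,crop=320:240
}

With no transforms, the interface keeps the default "copy" filter declared in the header diff below, and the colorspace conversion still happens because the filtergraph sink is set to RGB24.
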
@@ -44,124 +146,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
-  int expectedOutputHeight = frameDims.height;
-  int expectedOutputWidth = frameDims.width;
+  TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
+
+  // Note that we ignore the dimensions from the metadata; we don't even bother
+  // storing them. The resized dimensions take priority. If we don't have any,
+  // then we use the dimensions from the actual decoded frame. We use the actual
+  // decoded frame and not the metadata for two reasons:
+  //
+  //   1. Metadata may be wrong. If we have access to more accurate information,
+  //      we should use it.
+  //   2. Video streams can have variable resolution. This fact is not captured
+  //      in the stream metadata.
+  //
+  // Both cases cause problems for our batch APIs, as we allocate
+  // FrameBatchOutputs based on the stream metadata. But single-frame APIs
+  // can still work in such situations, so they should.
+  auto outputDims =
+      resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width));
 
   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims.height) &&
+            (shape[1] == outputDims.width) && (shape[2] == 3),
         "Expected pre-allocated tensor of shape ",
-        expectedOutputHeight,
+        outputDims.height,
         "x",
-        expectedOutputWidth,
+        outputDims.width,
         "x3, got ",
         shape);
   }
 
+  auto colorConversionLibrary = getColorConversionLibrary(outputDims);
   torch::Tensor outputTensor;
-  enum AVPixelFormat frameFormat =
-      static_cast<enum AVPixelFormat>(avFrame->format);
-
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
-  ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    SwsFrameContext swsFrameContext;
-
-    swsFrameContext.inputWidth = avFrame->width;
-    swsFrameContext.inputHeight = avFrame->height;
-    swsFrameContext.inputFormat = frameFormat;
-    swsFrameContext.outputWidth = expectedOutputWidth;
-    swsFrameContext.outputHeight = expectedOutputHeight;
-
-    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
-
-    if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-      createSwsContext(swsFrameContext, avFrame->colorspace);
-      prevSwsFrameContext_ = swsFrameContext;
-    }
+    outputTensor = preAllocatedOutputTensor.value_or(
+        allocateEmptyHWCTensor(outputDims, torch::kCPU));
+
     int resultHeight =
-        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
+        convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
+
     // If this check failed, it would mean that the frame wasn't reshaped to
     // the expected height.
     // TODO: Can we do the same check for width?
     TORCH_CHECK(
-        resultHeight == expectedOutputHeight,
-        "resultHeight != expectedOutputHeight: ",
+        resultHeight == outputDims.height,
+        "resultHeight != outputDims.height: ",
        resultHeight,
        " != ",
-        expectedOutputHeight);
+        outputDims.height);
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation.
-    FiltersContext filtersContext;
-
-    filtersContext.inputWidth = avFrame->width;
-    filtersContext.inputHeight = avFrame->height;
-    filtersContext.inputFormat = frameFormat;
-    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-    filtersContext.outputWidth = expectedOutputWidth;
-    filtersContext.outputHeight = expectedOutputHeight;
-    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-    filtersContext.timeBase = timeBase;
-
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
-    filtersContext.filtergraphStr = filters.str();
-
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims);
 
     // Similarly to above, if this check fails it means the frame wasn't
     // reshaped to its expected dimensions by filtergraph.
     auto shape = outputTensor.sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims.height) &&
+            (shape[1] == outputDims.width) && (shape[2] == 3),
         "Expected output tensor of shape ",
-        expectedOutputHeight,
+        outputDims.height,
         "x",
-        expectedOutputWidth,
+        outputDims.width,
         "x3, got ",
         shape);
 
@@ -181,9 +233,32 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 }
 
-int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
+int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
     const UniqueAVFrame& avFrame,
-    torch::Tensor& outputTensor) {
+    torch::Tensor& outputTensor,
+    const FrameDims& outputDims) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  SwsFrameContext swsFrameContext(
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      outputDims.width,
+      outputDims.height);
+
+  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+    createSwsContext(swsFrameContext, avFrame->colorspace);
+    prevSwsFrameContext_ = swsFrameContext;
+  }
+
   uint8_t* pointers[4] = {
       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
   int expectedOutputWidth = outputTensor.sizes()[1];
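
Editor's note: the rebuild-only-on-change caching used above for swsContext_ (and further down for filterGraph_) is worth seeing in isolation. The sketch below is self-contained and uses stand-in types (FrameContext, Converter), not torchcodec's own; it only illustrates the pattern of comparing the current frame's parameters against the ones the conversion object was built with.

#include <iostream>
#include <memory>

// Stand-in for SwsFrameContext / FiltersContext: the parameters whose change
// between frames forces a rebuild of the conversion object.
struct FrameContext {
  int width = 0;
  int height = 0;
  bool operator==(const FrameContext& other) const {
    return width == other.width && height == other.height;
  }
  bool operator!=(const FrameContext& other) const {
    return !(*this == other);
  }
};

// Stand-in for the expensive-to-create SwsContext or FilterGraph.
struct Converter {
  explicit Converter(const FrameContext& ctx) {
    std::cout << "rebuilding converter for " << ctx.width << "x" << ctx.height
              << "\n";
  }
};

int main() {
  std::unique_ptr<Converter> converter;
  FrameContext prev;
  const FrameContext frames[] = {{640, 480}, {640, 480}, {1280, 720}};
  for (const auto& frame : frames) {
    // Rebuild only when the geometry differs from the previous frame: this
    // handles mid-stream resolution changes while reusing the object for the
    // common case of constant resolution.
    if (!converter || prev != frame) {
      converter = std::make_unique<Converter>(frame);
      prev = frame;
    }
    // ... convert `frame` with *converter ...
  }
}
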
@@ -199,25 +274,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {
@@ -228,7 +284,7 @@ void CpuDeviceInterface::createSwsContext(
       swsFrameContext.outputWidth,
       swsFrameContext.outputHeight,
       AV_PIX_FMT_RGB24,
-      SWS_BILINEAR,
+      swsFlags_,
       nullptr,
       nullptr,
       nullptr);
@@ -263,4 +319,29 @@ void CpuDeviceInterface::createSwsContext(
   swsContext_.reset(swsContext);
 }
 
+torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
+    const UniqueAVFrame& avFrame,
+    const FrameDims& outputDims) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+
+  FiltersContext filtersContext(
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      outputDims.width,
+      outputDims.height,
+      AV_PIX_FMT_RGB24,
+      filters_,
+      timeBase_);
+
+  if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
+    filterGraph_ =
+        std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
+    prevFiltersContext_ = std::move(filtersContext);
+  }
+  return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
+}
+
 } // namespace facebook::torchcodec
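
Editor's note: the swscale path wrapped by createSwsContext() and convertAVFrameToTensorUsingSwScale() above sits on top of FFmpeg's libswscale. For orientation, here is a generic sketch of that underlying API (not the torchcodec wrapper itself), converting one decoded AVFrame into packed RGB24; the colorspace handling that createSwsContext() performs with its colorspace argument is intentionally omitted.

extern "C" {
#include <libavutil/frame.h>
#include <libavutil/pixfmt.h>
#include <libswscale/swscale.h>
}
#include <cstdint>
#include <vector>

// Convert a decoded AVFrame into a tightly packed HWC RGB24 buffer.
// Error handling is minimal; this only shows the shape of the API.
std::vector<uint8_t> toRGB24(const AVFrame* frame, int outWidth, int outHeight) {
  SwsContext* ctx = sws_getContext(
      frame->width,
      frame->height,
      static_cast<AVPixelFormat>(frame->format),
      outWidth,
      outHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR, // same default the interface keeps in swsFlags_
      nullptr,
      nullptr,
      nullptr);
  if (ctx == nullptr) {
    return {};
  }

  std::vector<uint8_t> rgb(static_cast<size_t>(outWidth) * outHeight * 3);
  uint8_t* dstData[4] = {rgb.data(), nullptr, nullptr, nullptr};
  int dstLinesize[4] = {outWidth * 3, 0, 0, 0};

  // sws_scale returns the height of the output slice it wrote; the code above
  // checks that this matches the expected output height.
  sws_scale(
      ctx, frame->data, frame->linesize, 0, frame->height, dstData, dstLinesize);

  sws_freeContext(ctx);
  return rgb;
}
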
torchcodec/_core/CpuDeviceInterface.h

@@ -23,31 +23,48 @@ class CpuDeviceInterface : public DeviceInterface {
     return std::nullopt;
   }
 
-  void initializeContext(
-      [[maybe_unused]] AVCodecContext* codecContext) override {}
+  virtual void initialize(
+      const AVStream* avStream,
+      const UniqueDecodingAVFormatContext& avFormatCtx) override;
 
-  void convertAVFrameToFrameOutput(
+  virtual void initializeVideo(
       const VideoStreamOptions& videoStreamOptions,
-      const AVRational& timeBase,
+      const std::vector<std::unique_ptr<Transform>>& transforms,
+      const std::optional<FrameDims>& resizedOutputDims) override;
+
+  void convertAVFrameToFrameOutput(
       UniqueAVFrame& avFrame,
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor =
          std::nullopt) override;
 
  private:
-  int convertAVFrameToTensorUsingSwsScale(
+  int convertAVFrameToTensorUsingSwScale(
       const UniqueAVFrame& avFrame,
-      torch::Tensor& outputTensor);
+      torch::Tensor& outputTensor,
+      const FrameDims& outputDims);
 
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
-      const UniqueAVFrame& avFrame);
+      const UniqueAVFrame& avFrame,
+      const FrameDims& outputDims);
+
+  ColorConversionLibrary getColorConversionLibrary(
+      const FrameDims& inputFrameDims) const;
 
   struct SwsFrameContext {
-    int inputWidth;
-    int inputHeight;
-    AVPixelFormat inputFormat;
-    int outputWidth;
-    int outputHeight;
+    int inputWidth = 0;
+    int inputHeight = 0;
+    AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+    int outputWidth = 0;
+    int outputHeight = 0;
+
+    SwsFrameContext() = default;
+    SwsFrameContext(
+        int inputWidth,
+        int inputHeight,
+        AVPixelFormat inputFormat,
+        int outputWidth,
+        int outputHeight);
     bool operator==(const SwsFrameContext&) const;
     bool operator!=(const SwsFrameContext&) const;
   };
@@ -56,15 +73,60 @@ class CpuDeviceInterface : public DeviceInterface {
       const SwsFrameContext& swsFrameContext,
       const enum AVColorSpace colorspace);
 
-  // color-conversion fields. Only one of FilterGraphContext and
-  // UniqueSwsContext should be non-null.
-  std::unique_ptr<FilterGraph> filterGraphContext_;
-  UniqueSwsContext swsContext_;
+  VideoStreamOptions videoStreamOptions_;
+  AVRational timeBase_;
 
-  // Used to know whether a new FilterGraphContext or UniqueSwsContext should
-  // be created before decoding a new frame.
-  SwsFrameContext prevSwsFrameContext_;
+  // If the resized output dimensions are present, then we always use those as
+  // the output frame's dimensions. If they are not present, then we use the
+  // dimensions of the raw decoded frame. Note that we do not know the
+  // dimensions of the raw decoded frame until very late; we learn it in
+  // convertAVFrameToFrameOutput(). Deciding the final output frame's actual
+  // dimensions late allows us to handle video streams with variable
+  // resolutions.
+  std::optional<FrameDims> resizedOutputDims_;
+
+  // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
+  // be non-null. Which one we use is determined dynamically in
+  // getColorConversionLibrary() each time we decode a frame.
+  //
+  // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
+  // reuse them across frames. However, it is possible that subsequent frames
+  // are different enough (change in dimensions) that we can't reuse the color
+  // conversion object. We store the relevant frame context from the frame used
+  // to create the object last time. We always compare the current frame's info
+  // against the previous one to determine if we need to recreate the color
+  // conversion object.
+  //
+  // TODO: The names of these fields are confusing, as the actual color
+  //       conversion object for Sws has "context" in the name, and we use
+  //       "context" for the structs we store to know if we need to recreate a
+  //       color conversion object. We should clean that up.
+  std::unique_ptr<FilterGraph> filterGraph_;
   FiltersContext prevFiltersContext_;
+  UniqueSwsContext swsContext_;
+  SwsFrameContext prevSwsFrameContext_;
+
+  // The filter we supply to filterGraph_, if it is used. The default is the
+  // copy filter, which just copies the input to the output. Computationally, it
+  // should be a no-op. If we get no user-provided transforms, we will use the
+  // copy filter. Otherwise, we will construct the string from the transforms.
+  //
+  // Note that even if we only use the copy filter, we still get the desired
+  // colorspace conversion. We construct the filtergraph with its output sink
+  // set to RGB24.
+  std::string filters_ = "copy";
+
+  // The flags we supply to swsContext_, if it is used. The flags control the
+  // resizing algorithm. We default to bilinear. Users can override this with a
+  // ResizeTransform.
+  int swsFlags_ = SWS_BILINEAR;
+
+  // Values set during initialization and referred to in
+  // getColorConversionLibrary().
+  bool areTransformsSwScaleCompatible_;
+  bool userRequestedSwScale_;
+
+  bool initialized_ = false;
 };
 
 } // namespace facebook::torchcodec
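
Editor's note: the selection rule that getColorConversionLibrary() introduces (swscale-compatible transforms are a hard requirement; the width-multiple-of-32 rule is soft and can be overridden by an explicit swscale request) can be exercised in isolation. The sketch below mirrors that rule with a free function and hypothetical example widths; it is an illustration, not torchcodec code.

#include <iostream>

enum class Library { SWSCALE, FILTERGRAPH };

// Mirrors the decision rule from getColorConversionLibrary(): transform
// compatibility is required; the width rule only applies when the user did
// not explicitly request swscale.
Library choose(bool transformsOk, bool userRequestedSwScale, int outputWidth) {
  const bool widthOk = (outputWidth % 32) == 0;
  return (transformsOk && (userRequestedSwScale || widthOk))
      ? Library::SWSCALE
      : Library::FILTERGRAPH;
}

int main() {
  // 640 is a multiple of 32, 642 is not.
  std::cout << (choose(true, false, 640) == Library::SWSCALE) << "\n";  // 1
  std::cout << (choose(true, false, 642) == Library::SWSCALE) << "\n";  // 0
  std::cout << (choose(true, true, 642) == Library::SWSCALE) << "\n";   // 1
  std::cout << (choose(false, true, 640) == Library::SWSCALE) << "\n";  // 0
}
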