torchcodec 0.10.0__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. torchcodec/__init__.py +27 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +130 -0
  7. torchcodec/_core/AVIOTensorContext.h +44 -0
  8. torchcodec/_core/BetaCudaDeviceInterface.cpp +849 -0
  9. torchcodec/_core/BetaCudaDeviceInterface.h +196 -0
  10. torchcodec/_core/CMakeLists.txt +295 -0
  11. torchcodec/_core/CUDACommon.cpp +330 -0
  12. torchcodec/_core/CUDACommon.h +51 -0
  13. torchcodec/_core/Cache.h +124 -0
  14. torchcodec/_core/CpuDeviceInterface.cpp +509 -0
  15. torchcodec/_core/CpuDeviceInterface.h +141 -0
  16. torchcodec/_core/CudaDeviceInterface.cpp +602 -0
  17. torchcodec/_core/CudaDeviceInterface.h +79 -0
  18. torchcodec/_core/DeviceInterface.cpp +117 -0
  19. torchcodec/_core/DeviceInterface.h +191 -0
  20. torchcodec/_core/Encoder.cpp +1054 -0
  21. torchcodec/_core/Encoder.h +192 -0
  22. torchcodec/_core/FFMPEGCommon.cpp +684 -0
  23. torchcodec/_core/FFMPEGCommon.h +314 -0
  24. torchcodec/_core/FilterGraph.cpp +159 -0
  25. torchcodec/_core/FilterGraph.h +59 -0
  26. torchcodec/_core/Frame.cpp +47 -0
  27. torchcodec/_core/Frame.h +72 -0
  28. torchcodec/_core/Metadata.cpp +124 -0
  29. torchcodec/_core/Metadata.h +92 -0
  30. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  31. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  32. torchcodec/_core/NVDECCache.cpp +60 -0
  33. torchcodec/_core/NVDECCache.h +102 -0
  34. torchcodec/_core/SingleStreamDecoder.cpp +1586 -0
  35. torchcodec/_core/SingleStreamDecoder.h +391 -0
  36. torchcodec/_core/StreamOptions.h +70 -0
  37. torchcodec/_core/Transform.cpp +128 -0
  38. torchcodec/_core/Transform.h +86 -0
  39. torchcodec/_core/ValidationUtils.cpp +35 -0
  40. torchcodec/_core/ValidationUtils.h +21 -0
  41. torchcodec/_core/__init__.py +46 -0
  42. torchcodec/_core/_metadata.py +262 -0
  43. torchcodec/_core/custom_ops.cpp +1090 -0
  44. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +169 -0
  45. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  46. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  47. torchcodec/_core/ops.py +605 -0
  48. torchcodec/_core/pybind_ops.cpp +50 -0
  49. torchcodec/_frame.py +146 -0
  50. torchcodec/_internally_replaced_utils.py +68 -0
  51. torchcodec/_samplers/__init__.py +7 -0
  52. torchcodec/_samplers/video_clip_sampler.py +419 -0
  53. torchcodec/decoders/__init__.py +12 -0
  54. torchcodec/decoders/_audio_decoder.py +185 -0
  55. torchcodec/decoders/_decoder_utils.py +113 -0
  56. torchcodec/decoders/_video_decoder.py +601 -0
  57. torchcodec/encoders/__init__.py +2 -0
  58. torchcodec/encoders/_audio_encoder.py +149 -0
  59. torchcodec/encoders/_video_encoder.py +196 -0
  60. torchcodec/libtorchcodec_core4.so +0 -0
  61. torchcodec/libtorchcodec_core5.so +0 -0
  62. torchcodec/libtorchcodec_core6.so +0 -0
  63. torchcodec/libtorchcodec_core7.so +0 -0
  64. torchcodec/libtorchcodec_core8.so +0 -0
  65. torchcodec/libtorchcodec_custom_ops4.so +0 -0
  66. torchcodec/libtorchcodec_custom_ops5.so +0 -0
  67. torchcodec/libtorchcodec_custom_ops6.so +0 -0
  68. torchcodec/libtorchcodec_custom_ops7.so +0 -0
  69. torchcodec/libtorchcodec_custom_ops8.so +0 -0
  70. torchcodec/libtorchcodec_pybind_ops4.so +0 -0
  71. torchcodec/libtorchcodec_pybind_ops5.so +0 -0
  72. torchcodec/libtorchcodec_pybind_ops6.so +0 -0
  73. torchcodec/libtorchcodec_pybind_ops7.so +0 -0
  74. torchcodec/libtorchcodec_pybind_ops8.so +0 -0
  75. torchcodec/samplers/__init__.py +2 -0
  76. torchcodec/samplers/_common.py +84 -0
  77. torchcodec/samplers/_index_based.py +287 -0
  78. torchcodec/samplers/_time_based.py +358 -0
  79. torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake +76 -0
  80. torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake +122 -0
  81. torchcodec/transforms/__init__.py +12 -0
  82. torchcodec/transforms/_decoder_transforms.py +375 -0
  83. torchcodec/version.py +2 -0
  84. torchcodec-0.10.0.dist-info/METADATA +286 -0
  85. torchcodec-0.10.0.dist-info/RECORD +88 -0
  86. torchcodec-0.10.0.dist-info/WHEEL +5 -0
  87. torchcodec-0.10.0.dist-info/licenses/LICENSE +28 -0
  88. torchcodec-0.10.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,391 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <torch/types.h>
10
+ #include <cstdint>
11
+ #include <memory>
12
+ #include <ostream>
13
+ #include <string_view>
14
+
15
+ #include "AVIOContextHolder.h"
16
+ #include "DeviceInterface.h"
17
+ #include "FFMPEGCommon.h"
18
+ #include "Frame.h"
19
+ #include "Metadata.h"
20
+ #include "StreamOptions.h"
21
+ #include "Transform.h"
22
+
23
+ namespace facebook::torchcodec {
24
+
25
// The SingleStreamDecoder class can be used to decode video frames to Tensors.
// Note that SingleStreamDecoder is not thread-safe.
// Do not call non-const APIs concurrently on the same object.
class SingleStreamDecoder {
 public:
  // --------------------------------------------------------------------------
  // CONSTRUCTION API
  // --------------------------------------------------------------------------

  // Creates a SingleStreamDecoder from the video at videoFilePath.
  explicit SingleStreamDecoder(
      const std::string& videoFilePath,
      SeekMode seekMode = SeekMode::exact);

  // Creates a SingleStreamDecoder using the provided AVIOContext inside the
  // AVIOContextHolder. The AVIOContextHolder is the base class, and the
  // derived class will have specialized how the custom read, seek and writes
  // work.
  explicit SingleStreamDecoder(
      std::unique_ptr<AVIOContextHolder> context,
      SeekMode seekMode = SeekMode::exact);

  // --------------------------------------------------------------------------
  // VIDEO METADATA QUERY API
  // --------------------------------------------------------------------------

  // Updates the metadata of the video to accurate values obtained by scanning
  // the contents of the video file. Also updates each StreamInfo's index, i.e.
  // the allFrames and keyFrames vectors.
  void scanFileAndUpdateMetadataAndIndex();

  // Sorts the keyFrames and allFrames vectors in each StreamInfo by pts.
  void sortAllFrames();

  // Returns the metadata for the container.
  ContainerMetadata getContainerMetadata() const;

  // Returns the seek mode of this decoder.
  SeekMode getSeekMode() const;

  // Returns the active stream index. Returns NO_ACTIVE_STREAM (-2) if no
  // stream is active.
  int getActiveStreamIndex() const;

  // Returns the key frame indices as a tensor. The tensor is 1D and contains
  // int64 values, where each value is the frame index for a key frame.
  torch::Tensor getKeyFrameIndices();

  // --------------------------------------------------------------------------
  // ADDING STREAMS API
  // --------------------------------------------------------------------------

  // FrameMappings is used for the custom_frame_mappings seek mode to store
  // metadata of frames in a stream. The size of all tensors in this struct
  // must match.
  struct FrameMappings {
    // 1D tensor of int64, each value is the PTS of a frame in timebase units.
    torch::Tensor all_frames;
    // 1D tensor of bool, each value indicates if the corresponding frame in
    // all_frames is a key frame.
    torch::Tensor is_key_frame;
    // 1D tensor of int64, each value is the duration of the corresponding
    // frame in all_frames in timebase units.
    torch::Tensor duration;
  };

  void addVideoStream(
      int streamIndex,
      std::vector<Transform*>& transforms,
      const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(),
      std::optional<FrameMappings> customFrameMappings = std::nullopt);
  void addAudioStream(
      int streamIndex,
      const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());

  // --------------------------------------------------------------------------
  // DECODING AND SEEKING APIs
  // --------------------------------------------------------------------------

  // Places the cursor at the first frame on or after the position in seconds.
  // Calling getNextFrame() will return the first frame at
  // or after this position.
  void setCursorPtsInSeconds(double seconds);

  // Decodes the frame where the current cursor position is. It also advances
  // the cursor to the next frame.
  FrameOutput getNextFrame();

  FrameOutput getFrameAtIndex(int64_t frameIndex);

  // Returns frames at the given indices for a given stream as a single stacked
  // Tensor.
  FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices);

  // Returns frames within a given range. The range is defined by [start, stop).
  // The values retrieved from the range are: [start, start+step,
  // start+(2*step), start+(3*step), ..., stop). The default for step is 1.
  FrameBatchOutput getFramesInRange(int64_t start, int64_t stop, int64_t step);

  // Decodes the first frame in any added stream that is visible at a given
  // timestamp. Frames in the video have a presentation timestamp and a
  // duration. For example, if a frame has presentation timestamp of 5.0s and a
  // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
  // i.e. it will be returned when this function is called with seconds=5.0 or
  // seconds=5.999, etc.
  FrameOutput getFramePlayedAt(double seconds);

  FrameBatchOutput getFramesPlayedAt(const torch::Tensor& timestamps);

  // Returns frames within a given pts range. The range is defined by
  // [startSeconds, stopSeconds) with respect to the pts values for frames. The
  // returned frames are in pts order.
  //
  // Note that while stopSeconds is excluded in the half open range, this really
  // only makes a difference when stopSeconds is exactly the pts value for a
  // frame. Otherwise, the moment in time immediately before stopSeconds is in
  // the range, and that time maps to the same frame as stopSeconds.
  //
  // The frames returned are the frames that would be played by our abstract
  // player. Our abstract player displays frames based on pts only. It displays
  // frame i starting at the pts for frame i, and stops at the pts for frame
  // i+1. This model ignores a frame's reported duration.
  //
  // Valid values for startSeconds and stopSeconds are:
  //
  // [beginStreamPtsSecondsFromContent, endStreamPtsSecondsFromContent)
  FrameBatchOutput getFramesPlayedInRange(
      double startSeconds,
      double stopSeconds);

  AudioFramesOutput getFramesPlayedInRangeAudio(
      double startSeconds,
      std::optional<double> stopSecondsOptional = std::nullopt);

  // Thrown by the decoding APIs when the end of the stream is reached.
  class EndOfFileException : public std::runtime_error {
   public:
    explicit EndOfFileException(const std::string& msg)
        : std::runtime_error(msg) {}
  };

  // --------------------------------------------------------------------------
  // MORALLY PRIVATE APIS
  // --------------------------------------------------------------------------
  // These are APIs that should be private, but that are effectively exposed for
  // practical reasons, typically for testing purposes.

  // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
  // can move it back to private.
  FrameOutput getFrameAtIndexInternal(
      int64_t frameIndex,
      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

  // Exposed for _test_frame_pts_equality, which is used to test non-regression
  // of pts resolution (64 to 32 bit floats)
  double getPtsSecondsForFrame(int64_t frameIndex);

  // Exposed for performance testing.
  struct DecodeStats {
    int64_t numSeeksAttempted = 0;
    int64_t numSeeksDone = 0;
    int64_t numSeeksSkipped = 0;
    int64_t numPacketsRead = 0;
    int64_t numPacketsSentToDecoder = 0;
    int64_t numFramesReceivedByDecoder = 0;
    int64_t numFlushes = 0;
  };

  DecodeStats getDecodeStats() const;
  void resetDecodeStats();

  std::string getDeviceInterfaceDetails() const;

 private:
  // --------------------------------------------------------------------------
  // STREAMINFO AND ASSOCIATED STRUCTS
  // --------------------------------------------------------------------------

  struct FrameInfo {
    int64_t pts = 0;

    // The value of the nextPts default is important: the last frame's nextPts
    // will be INT64_MAX, which ensures that the allFrames vec contains
    // FrameInfo structs with *increasing* nextPts values. That's a necessary
    // condition for the binary searches on those values to work properly (as
    // typically done during pts -> index conversions).
    // TODO: This field is unset (left to the default) for entries in the
    // keyFrames vec!
    int64_t nextPts = INT64_MAX;

    // Note that frameIndex is ALWAYS the index into all of the frames in that
    // stream, even when the FrameInfo is part of the key frame index. Given a
    // FrameInfo for a key frame, the frameIndex allows us to know which frame
    // that is in the stream.
    int64_t frameIndex = 0;

    // Indicates whether a frame is a key frame. It may appear redundant as it's
    // only true for FrameInfos in the keyFrames index, but it is needed to
    // correctly map frames between allFrames and keyFrames during the scan.
    bool isKeyFrame = false;
  };

  struct StreamInfo {
    int streamIndex = -1;
    AVStream* stream = nullptr;
    AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN;

    AVRational timeBase = {};
    SharedAVCodecContext codecContext;

    // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
    // called.
    std::vector<FrameInfo> keyFrames;
    std::vector<FrameInfo> allFrames;

    VideoStreamOptions videoStreamOptions;
    AudioStreamOptions audioStreamOptions;
  };

  // --------------------------------------------------------------------------
  // INITIALIZERS
  // --------------------------------------------------------------------------

  void initializeDecoder();

  // Reads the user provided frame index and updates each StreamInfo's index,
  // i.e. the allFrames and keyFrames vectors, and
  // endStreamPtsSecondsFromContent
  void readCustomFrameMappingsUpdateMetadataAndIndex(
      int streamIndex,
      FrameMappings customFrameMappings);
  // --------------------------------------------------------------------------
  // DECODING APIS AND RELATED UTILS
  // --------------------------------------------------------------------------

  void setCursor(int64_t pts);
  void setCursor(double) = delete; // prevent calls with doubles and floats
  bool canWeAvoidSeeking() const;

  void maybeSeekToBeforeDesiredPts();

  UniqueAVFrame decodeAVFrame(
      std::function<bool(const UniqueAVFrame&)> filterFunction);

  FrameOutput getNextFrameInternal(
      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

  torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);

  FrameOutput convertAVFrameToFrameOutput(
      UniqueAVFrame& avFrame,
      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

  void convertAVFrameToFrameOutputOnCPU(
      UniqueAVFrame& avFrame,
      FrameOutput& frameOutput,
      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

  // --------------------------------------------------------------------------
  // PTS <-> INDEX CONVERSIONS
  // --------------------------------------------------------------------------

  // Returns the key frame index of the presentation timestamp using FFMPEG's
  // index. Note that this index may be truncated for some files.
  // NOTE(review): the two lines above originally sat over getBestStreamIndex()
  // below, which returns a stream index, not a key frame index; they appear to
  // describe this function instead -- confirm against the .cpp.
  int getKeyFrameIdentifier(int64_t pts) const;

  // Returns the key frame index of the presentation timestamp using our index.
  // We build this index by scanning the file in
  // scanFileAndUpdateMetadataAndIndex
  int getKeyFrameIndexForPtsUsingScannedIndex(
      const std::vector<SingleStreamDecoder::FrameInfo>& keyFrames,
      int64_t pts) const;

  int64_t secondsToIndexLowerBound(double seconds);

  int64_t secondsToIndexUpperBound(double seconds);

  int64_t getPts(int64_t frameIndex);

  // --------------------------------------------------------------------------
  // STREAM AND METADATA APIS
  // --------------------------------------------------------------------------

  void addStream(
      int streamIndex,
      AVMediaType mediaType,
      const torch::Device& device = torch::kCPU,
      const std::string_view deviceVariant = "ffmpeg",
      std::optional<int> ffmpegThreadCount = std::nullopt);

  // Returns the "best" stream index for a given media type. The "best" is
  // determined by various heuristics in FFMPEG.
  // See
  // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
  // for more details about the heuristics.
  int getBestStreamIndex(AVMediaType mediaType);

  // --------------------------------------------------------------------------
  // VALIDATION UTILS
  // --------------------------------------------------------------------------

  void validateActiveStream(
      std::optional<AVMediaType> avMediaType = std::nullopt);
  void validateScannedAllStreams(const std::string& msg);
  void validateFrameIndex(
      const StreamMetadata& streamMetadata,
      int64_t frameIndex);

  // --------------------------------------------------------------------------
  // ATTRIBUTES
  // --------------------------------------------------------------------------

  SeekMode seekMode_;
  ContainerMetadata containerMetadata_;
  UniqueDecodingAVFormatContext formatContext_;
  std::unique_ptr<DeviceInterface> deviceInterface_;
  std::map<int, StreamInfo> streamInfos_;
  const int NO_ACTIVE_STREAM = -2;
  int activeStreamIndex_ = NO_ACTIVE_STREAM;

  // The desired position of the cursor in the stream. We send frames >= this
  // pts to the user when they request a frame.
  int64_t cursor_ = INT64_MIN;
  bool cursorWasJustSet_ = false;
  int64_t lastDecodedAvFramePts_ = 0;
  int64_t lastDecodedAvFrameDuration_ = 0;
  int64_t lastDecodedFrameIndex_ = INT64_MIN;

  // Stores various internal decoding stats.
  DecodeStats decodeStats_;

  // Stores the AVIOContext for the input buffer.
  std::unique_ptr<AVIOContextHolder> avioContextHolder_;

  // We will receive a vector of transforms upon adding a stream and store it
  // here. However, we need to know if any of those operations change the
  // dimensions of the output frame. If they do, we need to figure out what are
  // the final dimensions of the output frame after ALL transformations. We
  // figure this out as soon as we receive the transforms. If any of the
  // transforms change the final output frame dimensions, we store that in
  // resizedOutputDims_. If resizedOutputDims_ has no value, that means there
  // are no transforms that change the output frame dimensions.
  //
  // The priority order for output frame dimension is:
  //
  // 1. resizedOutputDims_; the resize requested by the user always takes
  // priority.
  // 2. The dimensions of the actual decoded AVFrame. This can change
  // per-decoded frame, and is unknown in SingleStreamDecoder. Only the
  // DeviceInterface learns it immediately after decoding a raw frame but
  // before the color transformation.
  // 3. metadataDims_; the dimensions we learned from the metadata.
  std::vector<std::unique_ptr<Transform>> transforms_;
  std::optional<FrameDims> resizedOutputDims_;
  FrameDims metadataDims_;

  // Whether or not we have already scanned all streams to update the metadata.
  bool scannedAllStreams_ = false;

  // Tracks that we've already been initialized.
  bool initialized_ = false;
};
385
+
386
// Prints the SingleStreamDecoder::DecodeStats to the ostream, e.g. for
// logging and performance debugging.
std::ostream& operator<<(
    std::ostream& os,
    const SingleStreamDecoder::DecodeStats& stats);
390
+
391
+ } // namespace facebook::torchcodec
@@ -0,0 +1,70 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <torch/types.h>
10
+ #include <map>
11
+ #include <optional>
12
+ #include <string>
13
+ #include <string_view>
14
+
15
+ namespace facebook::torchcodec {
16
+
17
// Selects which library performs the color conversion step when producing
// output frames.
enum ColorConversionLibrary {
  // Use the libavfilter library for color conversion.
  FILTERGRAPH,
  // Use the libswscale library for color conversion.
  SWSCALE
};
23
+
24
// Per-stream options controlling how a video stream is decoded and encoded.
struct VideoStreamOptions {
  VideoStreamOptions() {}

  // Number of threads we pass to FFMPEG for decoding.
  // 0 means FFMPEG will choose the number of threads automatically to fully
  // utilize all cores. If not set, it will be the default FFMPEG behavior for
  // the given codec.
  std::optional<int> ffmpegThreadCount;

  // Currently the dimension order can be either NHWC or NCHW.
  // H=height, W=width, C=channel.
  std::string dimensionOrder = "NCHW";

  // By default we have to use filtergraph, as it is more general. We can only
  // use swscale when we have met strict requirements. See
  // CpuDeviceInterface::initialize() for the logic.
  ColorConversionLibrary colorConversionLibrary =
      ColorConversionLibrary::FILTERGRAPH;

  // By default we use CPU for decoding for both C++ and python users.
  // Note: This is not used for video encoding, because device is determined by
  // the device of the input frame tensor.
  torch::Device device = torch::kCPU;
  // Device variant (e.g., "ffmpeg", "beta", etc.)
  std::string_view deviceVariant = "ffmpeg";

  // Encoding options
  // Encoder (codec) name to use; if unset the default is chosen elsewhere.
  std::optional<std::string> codec;
  // Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
  // If not specified, uses codec's default format.
  std::optional<std::string> pixelFormat;
  // Constant rate factor (quality vs. size trade-off), forwarded to the
  // encoder when set.
  std::optional<double> crf;
  // Encoder preset name, forwarded to the encoder when set.
  std::optional<std::string> preset;
  // Extra codec-specific key/value options forwarded to the encoder.
  std::optional<std::map<std::string, std::string>> extraOptions;
};
59
+
60
// Per-stream options controlling how an audio stream is decoded and encoded.
struct AudioStreamOptions {
  AudioStreamOptions() {}

  // Encoding only
  std::optional<int> bitRate;
  // Decoding and encoding:
  std::optional<int> numChannels;
  std::optional<int> sampleRate;
};
69
+
70
+ } // namespace facebook::torchcodec
@@ -0,0 +1,128 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include "Transform.h"
8
+ #include <torch/types.h>
9
+ #include "FFMPEGCommon.h"
10
+
11
+ namespace facebook::torchcodec {
12
+
13
+ namespace {
14
+
15
+ std::string toFilterGraphInterpolation(
16
+ ResizeTransform::InterpolationMode mode) {
17
+ switch (mode) {
18
+ case ResizeTransform::InterpolationMode::BILINEAR:
19
+ return "bilinear";
20
+ default:
21
+ TORCH_CHECK(
22
+ false,
23
+ "Unknown interpolation mode: " +
24
+ std::to_string(static_cast<int>(mode)));
25
+ }
26
+ }
27
+
28
+ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
29
+ switch (mode) {
30
+ case ResizeTransform::InterpolationMode::BILINEAR:
31
+ return SWS_BILINEAR;
32
+ default:
33
+ TORCH_CHECK(
34
+ false,
35
+ "Unknown interpolation mode: " +
36
+ std::to_string(static_cast<int>(mode)));
37
+ }
38
+ }
39
+
40
+ } // namespace
41
+
42
+ std::string ResizeTransform::getFilterGraphCpu() const {
43
+ return "scale=" + std::to_string(outputDims_.width) + ":" +
44
+ std::to_string(outputDims_.height) +
45
+ ":flags=" + toFilterGraphInterpolation(interpolationMode_);
46
+ }
47
+
48
+ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
49
+ return outputDims_;
50
+ }
51
+
52
// Identifies this transform as a resize; see Transform::isResize() -- resize
// is the only transform swscale can handle.
bool ResizeTransform::isResize() const {
  return true;
}
55
+
56
// Returns the libswscale flag (e.g. SWS_BILINEAR) matching this transform's
// interpolation mode.
int ResizeTransform::getSwsFlags() const {
  return toSwsInterpolation(interpolationMode_);
}
59
+
60
// Center crop: with no x/y stored, getFilterGraphCpu() omits the coordinates
// and the FFmpeg crop filter defaults to a center crop.
CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {}
61
+
62
// Crop with an explicit top-left corner (x, y). Negative coordinates are
// rejected here; whether the rectangle fits inside the input frame is checked
// later in validate().
CropTransform::CropTransform(const FrameDims& dims, int x, int y)
    : outputDims_(dims), x_(x), y_(y) {
  TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
  TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_);
}
67
+
68
+ std::string CropTransform::getFilterGraphCpu() const {
69
+ // For the FFmpeg filter crop, if the x and y coordinates are left
70
+ // unspecified, it defaults to a center crop.
71
+ std::string coordinates = x_.has_value()
72
+ ? (":" + std::to_string(x_.value()) + ":" + std::to_string(y_.value()))
73
+ : "";
74
+ return "crop=" + std::to_string(outputDims_.width) + ":" +
75
+ std::to_string(outputDims_.height) + coordinates + ":exact=1";
76
+ }
77
+
78
+ std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
79
+ return outputDims_;
80
+ }
81
+
82
// Checks that the requested crop rectangle fits inside a frame of the given
// input dimensions. Throws via TORCH_CHECK on any violation; returns nothing
// on success. The check order determines which error message is reported
// first when several constraints are violated.
void CropTransform::validate(const FrameDims& inputDims) const {
  TORCH_CHECK(
      outputDims_.height <= inputDims.height,
      "Crop output height (",
      outputDims_.height,
      ") is greater than input height (",
      inputDims.height,
      ")");
  TORCH_CHECK(
      outputDims_.width <= inputDims.width,
      "Crop output width (",
      outputDims_.width,
      ") is greater than input width (",
      inputDims.width,
      ")");
  // x and y must be set together; both unset means a center crop (see
  // getFilterGraphCpu()).
  TORCH_CHECK(
      x_.has_value() == y_.has_value(),
      "Crop x and y values must be both set or both unset");
  if (x_.has_value()) {
    // Start and end positions are checked separately so that the error message
    // pinpoints which bound was violated.
    TORCH_CHECK(
        x_.value() <= inputDims.width,
        "Crop x start position, ",
        x_.value(),
        ", out of bounds of input width, ",
        inputDims.width);
    TORCH_CHECK(
        x_.value() + outputDims_.width <= inputDims.width,
        "Crop x end position, ",
        x_.value() + outputDims_.width,
        ", out of bounds of input width ",
        inputDims.width);
    TORCH_CHECK(
        y_.value() <= inputDims.height,
        "Crop y start position, ",
        y_.value(),
        ", out of bounds of input height, ",
        inputDims.height);
    TORCH_CHECK(
        y_.value() + outputDims_.height <= inputDims.height,
        "Crop y end position, ",
        y_.value() + outputDims_.height,
        ", out of bounds of input height ",
        inputDims.height);
  }
}
127
+
128
+ } // namespace facebook::torchcodec
@@ -0,0 +1,86 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <optional>
10
+ #include <string>
11
+ #include "Frame.h"
12
+ #include "Metadata.h"
13
+
14
+ namespace facebook::torchcodec {
15
+
16
// Abstract base class for frame transformations (resize, crop, ...) that are
// applied to decoded frames. Concrete transforms describe themselves as an
// FFmpeg filtergraph expression via getFilterGraphCpu().
class Transform {
 public:
  // Returns the FFmpeg filter expression (e.g. "scale=...", "crop=...") that
  // implements this transform on the CPU.
  virtual std::string getFilterGraphCpu() const = 0;
  virtual ~Transform() = default;

  // If the transformation does not change the output frame dimensions, then
  // there is no need to override this member function. The default
  // implementation returns an empty optional, indicating that the output frame
  // has the same dimensions as the input frame.
  //
  // If the transformation does change the output frame dimensions, then it
  // must override this member function and return the output frame dimensions.
  virtual std::optional<FrameDims> getOutputFrameDims() const {
    return std::nullopt;
  }

  // The ResizeTransform is special because it is the only transform
  // that swscale can handle.
  virtual bool isResize() const {
    return false;
  }

  // The validity of some transforms depends on the characteristics of the
  // AVStream they're being applied to. For example, some transforms will
  // specify coordinates inside a frame, we need to validate that those are
  // within the frame's bounds.
  //
  // Note that the validation function does not return anything. We expect
  // invalid configurations to throw an exception.
  virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {}
};
47
+
48
+ class ResizeTransform : public Transform {
49
+ public:
50
+ enum class InterpolationMode { BILINEAR };
51
+
52
+ explicit ResizeTransform(const FrameDims& dims)
53
+ : outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {}
54
+
55
+ ResizeTransform(const FrameDims& dims, InterpolationMode interpolationMode)
56
+ : outputDims_(dims), interpolationMode_(interpolationMode) {}
57
+
58
+ std::string getFilterGraphCpu() const override;
59
+ std::optional<FrameDims> getOutputFrameDims() const override;
60
+ bool isResize() const override;
61
+
62
+ int getSwsFlags() const;
63
+
64
+ private:
65
+ FrameDims outputDims_;
66
+ InterpolationMode interpolationMode_;
67
+ };
68
+
69
// Crops decoded frames to fixed output dimensions, either at an explicit
// (x, y) top-left corner or centered when no corner is given.
class CropTransform : public Transform {
 public:
  CropTransform(const FrameDims& dims, int x, int y);

  // Becomes a center crop if x and y are not specified.
  explicit CropTransform(const FrameDims& dims);

  std::string getFilterGraphCpu() const override;
  std::optional<FrameDims> getOutputFrameDims() const override;
  void validate(const FrameDims& inputDims) const override;

 private:
  FrameDims outputDims_;
  // Top-left corner of the crop rectangle; both set or both unset (unset
  // means center crop).
  std::optional<int> x_;
  std::optional<int> y_;
};
85
+
86
+ } // namespace facebook::torchcodec