torchcodec 0.7.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchcodec might be problematic. Click here for more details.

Files changed (67) hide show
  1. torchcodec/__init__.py +16 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +123 -0
  7. torchcodec/_core/AVIOTensorContext.h +43 -0
  8. torchcodec/_core/CMakeLists.txt +292 -0
  9. torchcodec/_core/Cache.h +138 -0
  10. torchcodec/_core/CpuDeviceInterface.cpp +266 -0
  11. torchcodec/_core/CpuDeviceInterface.h +70 -0
  12. torchcodec/_core/CudaDeviceInterface.cpp +514 -0
  13. torchcodec/_core/CudaDeviceInterface.h +37 -0
  14. torchcodec/_core/DeviceInterface.cpp +79 -0
  15. torchcodec/_core/DeviceInterface.h +67 -0
  16. torchcodec/_core/Encoder.cpp +514 -0
  17. torchcodec/_core/Encoder.h +123 -0
  18. torchcodec/_core/FFMPEGCommon.cpp +421 -0
  19. torchcodec/_core/FFMPEGCommon.h +227 -0
  20. torchcodec/_core/FilterGraph.cpp +142 -0
  21. torchcodec/_core/FilterGraph.h +45 -0
  22. torchcodec/_core/Frame.cpp +32 -0
  23. torchcodec/_core/Frame.h +118 -0
  24. torchcodec/_core/Metadata.h +72 -0
  25. torchcodec/_core/SingleStreamDecoder.cpp +1715 -0
  26. torchcodec/_core/SingleStreamDecoder.h +380 -0
  27. torchcodec/_core/StreamOptions.h +53 -0
  28. torchcodec/_core/ValidationUtils.cpp +35 -0
  29. torchcodec/_core/ValidationUtils.h +21 -0
  30. torchcodec/_core/__init__.py +40 -0
  31. torchcodec/_core/_metadata.py +317 -0
  32. torchcodec/_core/custom_ops.cpp +727 -0
  33. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +300 -0
  34. torchcodec/_core/ops.py +455 -0
  35. torchcodec/_core/pybind_ops.cpp +87 -0
  36. torchcodec/_frame.py +145 -0
  37. torchcodec/_internally_replaced_utils.py +67 -0
  38. torchcodec/_samplers/__init__.py +7 -0
  39. torchcodec/_samplers/video_clip_sampler.py +430 -0
  40. torchcodec/decoders/__init__.py +11 -0
  41. torchcodec/decoders/_audio_decoder.py +177 -0
  42. torchcodec/decoders/_decoder_utils.py +52 -0
  43. torchcodec/decoders/_video_decoder.py +464 -0
  44. torchcodec/encoders/__init__.py +1 -0
  45. torchcodec/encoders/_audio_encoder.py +150 -0
  46. torchcodec/libtorchcodec_core4.dll +0 -0
  47. torchcodec/libtorchcodec_core5.dll +0 -0
  48. torchcodec/libtorchcodec_core6.dll +0 -0
  49. torchcodec/libtorchcodec_core7.dll +0 -0
  50. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  51. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  52. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  53. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  54. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  55. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  56. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  57. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  58. torchcodec/samplers/__init__.py +2 -0
  59. torchcodec/samplers/_common.py +84 -0
  60. torchcodec/samplers/_index_based.py +287 -0
  61. torchcodec/samplers/_time_based.py +350 -0
  62. torchcodec/version.py +2 -0
  63. torchcodec-0.7.0.dist-info/METADATA +242 -0
  64. torchcodec-0.7.0.dist-info/RECORD +67 -0
  65. torchcodec-0.7.0.dist-info/WHEEL +5 -0
  66. torchcodec-0.7.0.dist-info/licenses/LICENSE +28 -0
  67. torchcodec-0.7.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,79 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include "src/torchcodec/_core/DeviceInterface.h"
8
+ #include <map>
9
+ #include <mutex>
10
+
11
+ namespace facebook::torchcodec {
12
+
13
+ namespace {
14
+ using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
15
+ static std::mutex g_interface_mutex;
16
+
17
+ DeviceInterfaceMap& getDeviceMap() {
18
+ static DeviceInterfaceMap deviceMap;
19
+ return deviceMap;
20
+ }
21
+
22
// Returns the device-type part of a device string, i.e. everything before the
// first ':' ("cuda" for "cuda:0"), or the whole string when there is no ':'.
std::string getDeviceType(const std::string& device) {
  // When ':' is absent, find() yields npos and substr() clamps the count to
  // the end of the string, returning it unchanged.
  const auto colonPos = device.find(':');
  return device.substr(0, colonPos);
}
29
+
30
+ } // namespace
31
+
32
+ bool registerDeviceInterface(
33
+ torch::DeviceType deviceType,
34
+ CreateDeviceInterfaceFn createInterface) {
35
+ std::scoped_lock lock(g_interface_mutex);
36
+ DeviceInterfaceMap& deviceMap = getDeviceMap();
37
+
38
+ TORCH_CHECK(
39
+ deviceMap.find(deviceType) == deviceMap.end(),
40
+ "Device interface already registered for ",
41
+ deviceType);
42
+ deviceMap.insert({deviceType, createInterface});
43
+
44
+ return true;
45
+ }
46
+
47
+ torch::Device createTorchDevice(const std::string device) {
48
+ std::scoped_lock lock(g_interface_mutex);
49
+ std::string deviceType = getDeviceType(device);
50
+ DeviceInterfaceMap& deviceMap = getDeviceMap();
51
+
52
+ auto deviceInterface = std::find_if(
53
+ deviceMap.begin(),
54
+ deviceMap.end(),
55
+ [&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
56
+ return device.rfind(
57
+ torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
58
+ });
59
+ TORCH_CHECK(
60
+ deviceInterface != deviceMap.end(), "Unsupported device: ", device);
61
+
62
+ return torch::Device(device);
63
+ }
64
+
65
+ std::unique_ptr<DeviceInterface> createDeviceInterface(
66
+ const torch::Device& device) {
67
+ auto deviceType = device.type();
68
+ std::scoped_lock lock(g_interface_mutex);
69
+ DeviceInterfaceMap& deviceMap = getDeviceMap();
70
+
71
+ TORCH_CHECK(
72
+ deviceMap.find(deviceType) != deviceMap.end(),
73
+ "Unsupported device: ",
74
+ device);
75
+
76
+ return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
77
+ }
78
+
79
+ } // namespace facebook::torchcodec
@@ -0,0 +1,67 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <torch/types.h>
10
+ #include <functional>
11
+ #include <memory>
12
+ #include <stdexcept>
13
+ #include <string>
14
+ #include "FFMPEGCommon.h"
15
+ #include "src/torchcodec/_core/Frame.h"
16
+ #include "src/torchcodec/_core/StreamOptions.h"
17
+
18
+ namespace facebook::torchcodec {
19
+
20
+ // Note that all these device functions should only be called if the device is
21
+ // not a CPU device. CPU device functions are already implemented in the
22
+ // SingleStreamDecoder implementation.
23
+ // These functions should only be called from within an if block like this:
24
+ // if (device.type() != torch::kCPU) {
25
+ // deviceFunction(device, ...);
26
+ // }
27
+
28
+ class DeviceInterface {
29
+ public:
30
+ DeviceInterface(const torch::Device& device) : device_(device) {}
31
+
32
+ virtual ~DeviceInterface(){};
33
+
34
+ torch::Device& device() {
35
+ return device_;
36
+ };
37
+
38
+ virtual std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) = 0;
39
+
40
+ // Initialize the hardware device that is specified in `device`. Some builds
41
+ // support CUDA and others only support CPU.
42
+ virtual void initializeContext(AVCodecContext* codecContext) = 0;
43
+
44
+ virtual void convertAVFrameToFrameOutput(
45
+ const VideoStreamOptions& videoStreamOptions,
46
+ const AVRational& timeBase,
47
+ UniqueAVFrame& avFrame,
48
+ FrameOutput& frameOutput,
49
+ std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
50
+
51
+ protected:
52
+ torch::Device device_;
53
+ };
54
+
55
+ using CreateDeviceInterfaceFn =
56
+ std::function<DeviceInterface*(const torch::Device& device)>;
57
+
58
+ bool registerDeviceInterface(
59
+ torch::DeviceType deviceType,
60
+ const CreateDeviceInterfaceFn createInterface);
61
+
62
+ torch::Device createTorchDevice(const std::string device);
63
+
64
+ std::unique_ptr<DeviceInterface> createDeviceInterface(
65
+ const torch::Device& device);
66
+
67
+ } // namespace facebook::torchcodec
@@ -0,0 +1,514 @@
1
+ #include <sstream>
2
+
3
+ #include "src/torchcodec/_core/AVIOTensorContext.h"
4
+ #include "src/torchcodec/_core/Encoder.h"
5
+ #include "torch/types.h"
6
+
7
+ namespace facebook::torchcodec {
8
+
9
+ namespace {
10
+
11
+ torch::Tensor validateSamples(const torch::Tensor& samples) {
12
+ TORCH_CHECK(
13
+ samples.dtype() == torch::kFloat32,
14
+ "samples must have float32 dtype, got ",
15
+ samples.dtype());
16
+ TORCH_CHECK(
17
+ samples.dim() == 2,
18
+ "samples must have 2 dimensions, got ",
19
+ samples.dim());
20
+
21
+ // We enforce this, but if we get user reports we should investigate whether
22
+ // that's actually needed.
23
+ int numChannels = static_cast<int>(samples.sizes()[0]);
24
+ TORCH_CHECK(
25
+ numChannels <= AV_NUM_DATA_POINTERS,
26
+ "Trying to encode ",
27
+ numChannels,
28
+ " channels, but FFmpeg only supports ",
29
+ AV_NUM_DATA_POINTERS,
30
+ " channels per frame.");
31
+
32
+ return samples.contiguous();
33
+ }
34
+
35
+ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
36
+ if (avCodec.supported_samplerates == nullptr) {
37
+ return;
38
+ }
39
+
40
+ for (auto i = 0; avCodec.supported_samplerates[i] != 0; ++i) {
41
+ if (sampleRate == avCodec.supported_samplerates[i]) {
42
+ return;
43
+ }
44
+ }
45
+ std::stringstream supportedRates;
46
+ for (auto i = 0; avCodec.supported_samplerates[i] != 0; ++i) {
47
+ if (i > 0) {
48
+ supportedRates << ", ";
49
+ }
50
+ supportedRates << avCodec.supported_samplerates[i];
51
+ }
52
+
53
+ TORCH_CHECK(
54
+ false,
55
+ "invalid sample rate=",
56
+ sampleRate,
57
+ ". Supported sample rate values are: ",
58
+ supportedRates.str());
59
+ }
60
+
61
+ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
62
+ AV_SAMPLE_FMT_FLTP,
63
+ AV_SAMPLE_FMT_FLT,
64
+ AV_SAMPLE_FMT_DBLP,
65
+ AV_SAMPLE_FMT_DBL,
66
+ AV_SAMPLE_FMT_S64P,
67
+ AV_SAMPLE_FMT_S64,
68
+ AV_SAMPLE_FMT_S32P,
69
+ AV_SAMPLE_FMT_S32,
70
+ AV_SAMPLE_FMT_S16P,
71
+ AV_SAMPLE_FMT_S16,
72
+ AV_SAMPLE_FMT_U8P,
73
+ AV_SAMPLE_FMT_U8};
74
+
75
+ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
76
+ // Find a sample format that the encoder supports. We prefer using FLT[P],
77
+ // since this is the format of the input samples. If FLTP isn't supported
78
+ // then we'll need to convert the AVFrame's format. Our heuristic is to encode
79
+ // into the format with the highest resolution.
80
+ if (avCodec.sample_fmts == nullptr) {
81
+ // Can't really validate anything in this case, best we can do is hope that
82
+ // FLTP is supported by the encoder. If not, FFmpeg will raise.
83
+ return AV_SAMPLE_FMT_FLTP;
84
+ }
85
+
86
+ for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
87
+ for (int i = 0; avCodec.sample_fmts[i] != -1; ++i) {
88
+ if (avCodec.sample_fmts[i] == preferredFormat) {
89
+ return preferredFormat;
90
+ }
91
+ }
92
+ }
93
+ // We should always find a match in preferredFormatsOrder, so we should always
94
+ // return earlier. But in the event that a future FFmpeg version defines an
95
+ // additional sample format that isn't in preferredFormatsOrder, we fallback:
96
+ return avCodec.sample_fmts[0];
97
+ }
98
+
99
+ } // namespace
100
+
101
+ AudioEncoder::~AudioEncoder() {
102
+ close_avio();
103
+ }
104
+
105
+ void AudioEncoder::close_avio() {
106
+ if (avFormatContext_ && avFormatContext_->pb) {
107
+ if (avFormatContext_->pb->error == 0) {
108
+ avio_flush(avFormatContext_->pb);
109
+ }
110
+
111
+ if (!avioContextHolder_) {
112
+ if (avFormatContext_->pb->error == 0) {
113
+ avio_close(avFormatContext_->pb);
114
+ }
115
+ // avoids closing again in destructor, which would segfault.
116
+ avFormatContext_->pb = nullptr;
117
+ }
118
+ }
119
+ }
120
+
121
+ AudioEncoder::AudioEncoder(
122
+ const torch::Tensor& samples,
123
+ int sampleRate,
124
+ std::string_view fileName,
125
+ const AudioStreamOptions& audioStreamOptions)
126
+ : samples_(validateSamples(samples)), inSampleRate_(sampleRate) {
127
+ setFFmpegLogLevel();
128
+ AVFormatContext* avFormatContext = nullptr;
129
+ int status = avformat_alloc_output_context2(
130
+ &avFormatContext, nullptr, nullptr, fileName.data());
131
+
132
+ TORCH_CHECK(
133
+ avFormatContext != nullptr,
134
+ "Couldn't allocate AVFormatContext. ",
135
+ "The destination file is ",
136
+ fileName,
137
+ ", check the desired extension? ",
138
+ getFFMPEGErrorStringFromErrorCode(status));
139
+ avFormatContext_.reset(avFormatContext);
140
+
141
+ status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
142
+ TORCH_CHECK(
143
+ status >= 0,
144
+ "avio_open failed. The destination file is ",
145
+ fileName,
146
+ ", make sure it's a valid path? ",
147
+ getFFMPEGErrorStringFromErrorCode(status));
148
+
149
+ initializeEncoder(audioStreamOptions);
150
+ }
151
+
152
+ AudioEncoder::AudioEncoder(
153
+ const torch::Tensor& samples,
154
+ int sampleRate,
155
+ std::string_view formatName,
156
+ std::unique_ptr<AVIOContextHolder> avioContextHolder,
157
+ const AudioStreamOptions& audioStreamOptions)
158
+ : samples_(validateSamples(samples)),
159
+ inSampleRate_(sampleRate),
160
+ avioContextHolder_(std::move(avioContextHolder)) {
161
+ setFFmpegLogLevel();
162
+ AVFormatContext* avFormatContext = nullptr;
163
+ int status = avformat_alloc_output_context2(
164
+ &avFormatContext, nullptr, formatName.data(), nullptr);
165
+
166
+ TORCH_CHECK(
167
+ avFormatContext != nullptr,
168
+ "Couldn't allocate AVFormatContext. ",
169
+ "Check the desired format? Got format=",
170
+ formatName,
171
+ ". ",
172
+ getFFMPEGErrorStringFromErrorCode(status));
173
+ avFormatContext_.reset(avFormatContext);
174
+
175
+ avFormatContext_->pb = avioContextHolder_->getAVIOContext();
176
+
177
+ initializeEncoder(audioStreamOptions);
178
+ }
179
+
180
+ void AudioEncoder::initializeEncoder(
181
+ const AudioStreamOptions& audioStreamOptions) {
182
+ // We use the AVFormatContext's default codec for that
183
+ // specific format/container.
184
+ const AVCodec* avCodec =
185
+ avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
186
+ TORCH_CHECK(avCodec != nullptr, "Codec not found");
187
+
188
+ AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
189
+ TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
190
+ avCodecContext_.reset(avCodecContext);
191
+
192
+ auto desiredBitRate = audioStreamOptions.bitRate;
193
+ if (desiredBitRate.has_value()) {
194
+ TORCH_CHECK(
195
+ *desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
196
+ }
197
+ // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
198
+ // well when "-b:a" isn't specified.
199
+ avCodecContext_->bit_rate = desiredBitRate.value_or(0);
200
+
201
+ outNumChannels_ = static_cast<int>(
202
+ audioStreamOptions.numChannels.value_or(samples_.sizes()[0]));
203
+ validateNumChannels(*avCodec, outNumChannels_);
204
+ // The avCodecContext layout defines the layout of the encoded output, it's
205
+ // not related to the input sampes.
206
+ setDefaultChannelLayout(avCodecContext_, outNumChannels_);
207
+
208
+ outSampleRate_ = audioStreamOptions.sampleRate.value_or(inSampleRate_);
209
+ validateSampleRate(*avCodec, outSampleRate_);
210
+ avCodecContext_->sample_rate = outSampleRate_;
211
+
212
+ // Input samples are expected to be FLTP. Not all encoders support FLTP, so we
213
+ // may need to convert the samples into a supported output sample format,
214
+ // which is what the `.sample_fmt` defines.
215
+ avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
216
+
217
+ int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
218
+ TORCH_CHECK(
219
+ status == AVSUCCESS,
220
+ "avcodec_open2 failed: ",
221
+ getFFMPEGErrorStringFromErrorCode(status));
222
+
223
+ // We're allocating the stream here. Streams are meant to be freed by
224
+ // avformat_free_context(avFormatContext), which we call in the
225
+ // avFormatContext_'s destructor.
226
+ AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
227
+ TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
228
+ status = avcodec_parameters_from_context(
229
+ avStream->codecpar, avCodecContext_.get());
230
+ TORCH_CHECK(
231
+ status == AVSUCCESS,
232
+ "avcodec_parameters_from_context failed: ",
233
+ getFFMPEGErrorStringFromErrorCode(status));
234
+ streamIndex_ = avStream->index;
235
+
236
+ // If sample rate conversion is needed and the encoder doesn't support
237
+ // variable frame size, we need to create an intermediate FIFO. See
238
+ // [Encoding loop, sample rate conversion and FIFO].
239
+ if (((avCodec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) == 0) &&
240
+ (inSampleRate_ != outSampleRate_)) {
241
+ // frame_size * 2 is a decent default size. FFmpeg automatically
242
+ // re-allocates the fifo if more space is needed.
243
+ auto avAudioFifo = av_audio_fifo_alloc(
244
+ avCodecContext_->sample_fmt,
245
+ outNumChannels_,
246
+ avCodecContext_->frame_size * 2);
247
+ TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo.");
248
+ avAudioFifo_.reset(avAudioFifo);
249
+ }
250
+ }
251
+
252
+ torch::Tensor AudioEncoder::encodeToTensor() {
253
+ TORCH_CHECK(
254
+ avioContextHolder_ != nullptr,
255
+ "Cannot encode to tensor, avio tensor context doesn't exist.");
256
+ encode();
257
+ auto avioToTensorContext =
258
+ dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
259
+ TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
260
+ return avioToTensorContext->getOutputTensor();
261
+ }
262
+
263
+ void AudioEncoder::encode() {
264
+ // To be on the safe side we enforce that encode() can only be called once on
265
+ // an encoder object. Whether this is actually necessary is unknown, so this
266
+ // may be relaxed if needed.
267
+ TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
268
+ encodeWasCalled_ = true;
269
+
270
+ // Default to 256 like in torchaudio
271
+ int numSamplesAllocatedPerFrame =
272
+ avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
273
+ UniqueAVFrame avFrame = allocateAVFrame(
274
+ numSamplesAllocatedPerFrame,
275
+ inSampleRate_,
276
+ static_cast<int>(samples_.sizes()[0]),
277
+ AV_SAMPLE_FMT_FLTP);
278
+ avFrame->pts = 0;
279
+
280
+ AutoAVPacket autoAVPacket;
281
+
282
+ uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr());
283
+ int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel
284
+ int numEncodedSamples = 0; // per channel
285
+ int numBytesPerSample = static_cast<int>(samples_.element_size());
286
+ int numBytesPerChannel = numSamples * numBytesPerSample;
287
+
288
+ auto status = avformat_write_header(avFormatContext_.get(), nullptr);
289
+ TORCH_CHECK(
290
+ status == AVSUCCESS,
291
+ "Error in avformat_write_header: ",
292
+ getFFMPEGErrorStringFromErrorCode(status));
293
+
294
+ while (numEncodedSamples < numSamples) {
295
+ int numSamplesToEncode =
296
+ std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
297
+ int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
298
+
299
+ for (int ch = 0; ch < samples_.sizes()[0]; ch++) {
300
+ std::memcpy(
301
+ avFrame->data[ch],
302
+ psamples + ch * numBytesPerChannel,
303
+ numBytesToEncode);
304
+ }
305
+ psamples += numBytesToEncode;
306
+
307
+ // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
308
+ // that the frame buffers are allocated to a big enough size. Here, we reset
309
+ // it to the exact number of samples that need to be encoded, otherwise the
310
+ // encoded frame would contain more samples than necessary and our results
311
+ // wouldn't match the ffmpeg CLI.
312
+ avFrame->nb_samples = numSamplesToEncode;
313
+
314
+ UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame);
315
+ encodeFrameThroughFifo(autoAVPacket, convertedAVFrame);
316
+
317
+ numEncodedSamples += numSamplesToEncode;
318
+ }
319
+ TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
320
+
321
+ flushBuffers();
322
+
323
+ status = av_write_trailer(avFormatContext_.get());
324
+ TORCH_CHECK(
325
+ status == AVSUCCESS,
326
+ "Error in: av_write_trailer",
327
+ getFFMPEGErrorStringFromErrorCode(status));
328
+
329
+ close_avio();
330
+ }
331
+
332
+ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
333
+ if (static_cast<AVSampleFormat>(avFrame->format) ==
334
+ avCodecContext_->sample_fmt &&
335
+ getNumChannels(avFrame) == outNumChannels_ &&
336
+ avFrame->sample_rate == outSampleRate_) {
337
+ // Note: the clone references the same underlying data, it's a cheap copy.
338
+ return UniqueAVFrame(av_frame_clone(avFrame.get()));
339
+ }
340
+
341
+ if (!swrContext_) {
342
+ swrContext_.reset(createSwrContext(
343
+ static_cast<AVSampleFormat>(avFrame->format),
344
+ avCodecContext_->sample_fmt,
345
+ avFrame->sample_rate,
346
+ outSampleRate_,
347
+ avFrame,
348
+ outNumChannels_));
349
+ }
350
+ UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples(
351
+ swrContext_,
352
+ avFrame,
353
+ avCodecContext_->sample_fmt,
354
+ outSampleRate_,
355
+ outNumChannels_);
356
+
357
+ if (avFrame->sample_rate == outSampleRate_) {
358
+ TORCH_CHECK(
359
+ convertedAVFrame->nb_samples == avFrame->nb_samples,
360
+ "convertedAVFrame->nb_samples=",
361
+ convertedAVFrame->nb_samples,
362
+ " differs from ",
363
+ "avFrame->nb_samples=",
364
+ avFrame->nb_samples,
365
+ "This is unexpected, please report on the TorchCodec bug tracker.");
366
+ }
367
+ return convertedAVFrame;
368
+ }
369
+
370
+ void AudioEncoder::encodeFrameThroughFifo(
371
+ AutoAVPacket& autoAVPacket,
372
+ const UniqueAVFrame& avFrame,
373
+ // flushFifo is only set to true in maybeFlushSwrBuffers(), i.e. at the very
374
+ // end of the encoding process when we're flushing buffers. We also want to
375
+ // flush the FIFO so as to not leave any remaining samples in it.
376
+ bool flushFifo) {
377
+ if (avAudioFifo_ == nullptr) {
378
+ encodeFrame(autoAVPacket, avFrame);
379
+ return;
380
+ }
381
+ int numSamplesWritten = av_audio_fifo_write(
382
+ avAudioFifo_.get(),
383
+ reinterpret_cast<void**>(avFrame->data),
384
+ avFrame->nb_samples);
385
+ TORCH_CHECK(
386
+ numSamplesWritten == avFrame->nb_samples,
387
+ "Tried to write ",
388
+ avFrame->nb_samples,
389
+ " samples, but only wrote ",
390
+ numSamplesWritten);
391
+
392
+ UniqueAVFrame newavFrame = allocateAVFrame(
393
+ avCodecContext_->frame_size,
394
+ outSampleRate_,
395
+ outNumChannels_,
396
+ avCodecContext_->sample_fmt);
397
+
398
+ // Explaining the while bound:
399
+ // - if we're not flushing the FIFO, i.e. in most cases, we want to pull
400
+ // exactly `frame_size` samples from the FIFO, so we have to stop before it
401
+ // contains less than `frame_size` samples.
402
+ // - if we're flushing the FIFO, we want to read from the FIFO until the very
403
+ // last sample it contains.
404
+ //
405
+ // In both cases, for as long as we can, we're trying to pull exatly
406
+ // `frame_size` samples from the FIFO and send each `frame_size`-sized avFrame
407
+ // to encodeFrame(). Only the very last avFrame of the encoding process is
408
+ // allowed to contained less than frame_size samples. That only happens when
409
+ // flushFifo is true.
410
+ while (av_audio_fifo_size(avAudioFifo_.get()) >=
411
+ (flushFifo ? 1 : avCodecContext_->frame_size)) {
412
+ int samplesToRead = std::min(
413
+ av_audio_fifo_size(avAudioFifo_.get()), newavFrame->nb_samples);
414
+ int numSamplesRead = av_audio_fifo_read(
415
+ avAudioFifo_.get(),
416
+ reinterpret_cast<void**>(newavFrame->data),
417
+ samplesToRead);
418
+ TORCH_CHECK(
419
+ numSamplesRead == samplesToRead,
420
+ "Tried to read ",
421
+ samplesToRead,
422
+ " samples, but only read ",
423
+ numSamplesRead);
424
+
425
+ newavFrame->nb_samples = numSamplesRead;
426
+ encodeFrame(autoAVPacket, newavFrame);
427
+ }
428
+ }
429
+
430
+ void AudioEncoder::encodeFrame(
431
+ AutoAVPacket& autoAVPacket,
432
+ const UniqueAVFrame& avFrame) {
433
+ if (avFrame != nullptr) {
434
+ avFrame->pts = lastEncodedAVFramePts_;
435
+ lastEncodedAVFramePts_ += avFrame->nb_samples;
436
+ }
437
+
438
+ auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
439
+ TORCH_CHECK(
440
+ status == AVSUCCESS,
441
+ "Error while sending frame: ",
442
+ getFFMPEGErrorStringFromErrorCode(status));
443
+
444
+ while (status >= 0) {
445
+ ReferenceAVPacket packet(autoAVPacket);
446
+ status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
447
+ if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
448
+ if (status == AVERROR_EOF) {
449
+ // Flush the packets that were potentially buffered by
450
+ // av_interleaved_write_frame(). See corresponding block in
451
+ // TorchAudio:
452
+ // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
453
+ status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
454
+ TORCH_CHECK(
455
+ status == AVSUCCESS,
456
+ "Failed to flush packet: ",
457
+ getFFMPEGErrorStringFromErrorCode(status));
458
+ }
459
+ return;
460
+ }
461
+ TORCH_CHECK(
462
+ status >= 0,
463
+ "Error receiving packet: ",
464
+ getFFMPEGErrorStringFromErrorCode(status));
465
+
466
+ packet->stream_index = streamIndex_;
467
+
468
+ status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
469
+ TORCH_CHECK(
470
+ status == AVSUCCESS,
471
+ "Error in av_interleaved_write_frame: ",
472
+ getFFMPEGErrorStringFromErrorCode(status));
473
+ }
474
+ }
475
+
476
+ void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) {
477
+ // Similar to the decoder's method with the same name, but for encoding this
478
+ // time. That is, when sample conversion is involved, libswresample may have
479
+ // buffered some samples that we now need to flush and send to the encoder.
480
+ if (swrContext_ == nullptr && inSampleRate_ == outSampleRate_) {
481
+ return;
482
+ }
483
+ TORCH_CHECK(
484
+ swrContext_ != nullptr,
485
+ "swrContext is null, but sample rate conversion is needed. ",
486
+ "This is unexpected, please report on the TorchCodec bug tracker.");
487
+
488
+ int numRemainingSamples = // this is an upper bound
489
+ swr_get_out_samples(swrContext_.get(), 0);
490
+ if (numRemainingSamples == 0) {
491
+ return;
492
+ }
493
+
494
+ UniqueAVFrame avFrame = allocateAVFrame(
495
+ numRemainingSamples,
496
+ outSampleRate_,
497
+ outNumChannels_,
498
+ avCodecContext_->sample_fmt);
499
+ int actualNumRemainingSamples = swr_convert(
500
+ swrContext_.get(), avFrame->data, avFrame->nb_samples, nullptr, 0);
501
+ avFrame->nb_samples = actualNumRemainingSamples;
502
+
503
+ // We're potentially sending avFrame through the FIFO (if it exists), in which
504
+ // case we also want to flush the FIFO itself.
505
+ encodeFrameThroughFifo(autoAVPacket, avFrame, /*flushFifo=*/true);
506
+ }
507
+
508
+ void AudioEncoder::flushBuffers() {
509
+ AutoAVPacket autoAVPacket;
510
+ maybeFlushSwrBuffers(autoAVPacket);
511
+
512
+ encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
513
+ }
514
+ } // namespace facebook::torchcodec