torchcodec 0.10.0__cp312-cp312-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchcodec/__init__.py +27 -0
- torchcodec/_core/AVIOContextHolder.cpp +60 -0
- torchcodec/_core/AVIOContextHolder.h +64 -0
- torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
- torchcodec/_core/AVIOFileLikeContext.h +55 -0
- torchcodec/_core/AVIOTensorContext.cpp +130 -0
- torchcodec/_core/AVIOTensorContext.h +44 -0
- torchcodec/_core/BetaCudaDeviceInterface.cpp +849 -0
- torchcodec/_core/BetaCudaDeviceInterface.h +196 -0
- torchcodec/_core/CMakeLists.txt +295 -0
- torchcodec/_core/CUDACommon.cpp +330 -0
- torchcodec/_core/CUDACommon.h +51 -0
- torchcodec/_core/Cache.h +124 -0
- torchcodec/_core/CpuDeviceInterface.cpp +509 -0
- torchcodec/_core/CpuDeviceInterface.h +141 -0
- torchcodec/_core/CudaDeviceInterface.cpp +602 -0
- torchcodec/_core/CudaDeviceInterface.h +79 -0
- torchcodec/_core/DeviceInterface.cpp +117 -0
- torchcodec/_core/DeviceInterface.h +191 -0
- torchcodec/_core/Encoder.cpp +1054 -0
- torchcodec/_core/Encoder.h +192 -0
- torchcodec/_core/FFMPEGCommon.cpp +684 -0
- torchcodec/_core/FFMPEGCommon.h +314 -0
- torchcodec/_core/FilterGraph.cpp +159 -0
- torchcodec/_core/FilterGraph.h +59 -0
- torchcodec/_core/Frame.cpp +47 -0
- torchcodec/_core/Frame.h +72 -0
- torchcodec/_core/Metadata.cpp +124 -0
- torchcodec/_core/Metadata.h +92 -0
- torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
- torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
- torchcodec/_core/NVDECCache.cpp +60 -0
- torchcodec/_core/NVDECCache.h +102 -0
- torchcodec/_core/SingleStreamDecoder.cpp +1586 -0
- torchcodec/_core/SingleStreamDecoder.h +391 -0
- torchcodec/_core/StreamOptions.h +70 -0
- torchcodec/_core/Transform.cpp +128 -0
- torchcodec/_core/Transform.h +86 -0
- torchcodec/_core/ValidationUtils.cpp +35 -0
- torchcodec/_core/ValidationUtils.h +21 -0
- torchcodec/_core/__init__.py +46 -0
- torchcodec/_core/_metadata.py +262 -0
- torchcodec/_core/custom_ops.cpp +1090 -0
- torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +169 -0
- torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
- torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
- torchcodec/_core/ops.py +605 -0
- torchcodec/_core/pybind_ops.cpp +50 -0
- torchcodec/_frame.py +146 -0
- torchcodec/_internally_replaced_utils.py +68 -0
- torchcodec/_samplers/__init__.py +7 -0
- torchcodec/_samplers/video_clip_sampler.py +419 -0
- torchcodec/decoders/__init__.py +12 -0
- torchcodec/decoders/_audio_decoder.py +185 -0
- torchcodec/decoders/_decoder_utils.py +113 -0
- torchcodec/decoders/_video_decoder.py +601 -0
- torchcodec/encoders/__init__.py +2 -0
- torchcodec/encoders/_audio_encoder.py +149 -0
- torchcodec/encoders/_video_encoder.py +196 -0
- torchcodec/libtorchcodec_core4.so +0 -0
- torchcodec/libtorchcodec_core5.so +0 -0
- torchcodec/libtorchcodec_core6.so +0 -0
- torchcodec/libtorchcodec_core7.so +0 -0
- torchcodec/libtorchcodec_core8.so +0 -0
- torchcodec/libtorchcodec_custom_ops4.so +0 -0
- torchcodec/libtorchcodec_custom_ops5.so +0 -0
- torchcodec/libtorchcodec_custom_ops6.so +0 -0
- torchcodec/libtorchcodec_custom_ops7.so +0 -0
- torchcodec/libtorchcodec_custom_ops8.so +0 -0
- torchcodec/libtorchcodec_pybind_ops4.so +0 -0
- torchcodec/libtorchcodec_pybind_ops5.so +0 -0
- torchcodec/libtorchcodec_pybind_ops6.so +0 -0
- torchcodec/libtorchcodec_pybind_ops7.so +0 -0
- torchcodec/libtorchcodec_pybind_ops8.so +0 -0
- torchcodec/samplers/__init__.py +2 -0
- torchcodec/samplers/_common.py +84 -0
- torchcodec/samplers/_index_based.py +287 -0
- torchcodec/samplers/_time_based.py +358 -0
- torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake +76 -0
- torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake +122 -0
- torchcodec/transforms/__init__.py +12 -0
- torchcodec/transforms/_decoder_transforms.py +375 -0
- torchcodec/version.py +2 -0
- torchcodec-0.10.0.dist-info/METADATA +286 -0
- torchcodec-0.10.0.dist-info/RECORD +88 -0
- torchcodec-0.10.0.dist-info/WHEEL +5 -0
- torchcodec-0.10.0.dist-info/licenses/LICENSE +28 -0
- torchcodec-0.10.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1054 @@
|
|
|
1
|
+
#include <sstream>
|
|
2
|
+
|
|
3
|
+
#include "AVIOTensorContext.h"
|
|
4
|
+
#include "Encoder.h"
|
|
5
|
+
#include "torch/types.h"
|
|
6
|
+
|
|
7
|
+
extern "C" {
|
|
8
|
+
#include <libavutil/hwcontext.h>
|
|
9
|
+
#include <libavutil/opt.h>
|
|
10
|
+
#include <libavutil/pixdesc.h>
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
namespace facebook::torchcodec {
|
|
14
|
+
|
|
15
|
+
namespace {
|
|
16
|
+
|
|
17
|
+
torch::Tensor validateSamples(const torch::Tensor& samples) {
|
|
18
|
+
TORCH_CHECK(
|
|
19
|
+
samples.dtype() == torch::kFloat32,
|
|
20
|
+
"samples must have float32 dtype, got ",
|
|
21
|
+
samples.dtype());
|
|
22
|
+
TORCH_CHECK(
|
|
23
|
+
samples.dim() == 2,
|
|
24
|
+
"samples must have 2 dimensions, got ",
|
|
25
|
+
samples.dim());
|
|
26
|
+
|
|
27
|
+
// We enforce this, but if we get user reports we should investigate whether
|
|
28
|
+
// that's actually needed.
|
|
29
|
+
int numChannels = static_cast<int>(samples.sizes()[0]);
|
|
30
|
+
TORCH_CHECK(
|
|
31
|
+
numChannels <= AV_NUM_DATA_POINTERS,
|
|
32
|
+
"Trying to encode ",
|
|
33
|
+
numChannels,
|
|
34
|
+
" channels, but FFmpeg only supports ",
|
|
35
|
+
AV_NUM_DATA_POINTERS,
|
|
36
|
+
" channels per frame.");
|
|
37
|
+
|
|
38
|
+
return samples.contiguous();
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
|
|
42
|
+
const int* supportedSampleRates = getSupportedSampleRates(avCodec);
|
|
43
|
+
if (supportedSampleRates == nullptr) {
|
|
44
|
+
return;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
for (auto i = 0; supportedSampleRates[i] != 0; ++i) {
|
|
48
|
+
if (sampleRate == supportedSampleRates[i]) {
|
|
49
|
+
return;
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
std::stringstream supportedRates;
|
|
53
|
+
for (auto i = 0; supportedSampleRates[i] != 0; ++i) {
|
|
54
|
+
if (i > 0) {
|
|
55
|
+
supportedRates << ", ";
|
|
56
|
+
}
|
|
57
|
+
supportedRates << supportedSampleRates[i];
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
TORCH_CHECK(
|
|
61
|
+
false,
|
|
62
|
+
"invalid sample rate=",
|
|
63
|
+
sampleRate,
|
|
64
|
+
". Supported sample rate values are: ",
|
|
65
|
+
supportedRates.str());
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
static const std::vector<AVSampleFormat> preferredFormatsOrder = {
|
|
69
|
+
AV_SAMPLE_FMT_FLTP,
|
|
70
|
+
AV_SAMPLE_FMT_FLT,
|
|
71
|
+
AV_SAMPLE_FMT_DBLP,
|
|
72
|
+
AV_SAMPLE_FMT_DBL,
|
|
73
|
+
AV_SAMPLE_FMT_S64P,
|
|
74
|
+
AV_SAMPLE_FMT_S64,
|
|
75
|
+
AV_SAMPLE_FMT_S32P,
|
|
76
|
+
AV_SAMPLE_FMT_S32,
|
|
77
|
+
AV_SAMPLE_FMT_S16P,
|
|
78
|
+
AV_SAMPLE_FMT_S16,
|
|
79
|
+
AV_SAMPLE_FMT_U8P,
|
|
80
|
+
AV_SAMPLE_FMT_U8};
|
|
81
|
+
|
|
82
|
+
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
|
|
83
|
+
const AVSampleFormat* supportedSampleFormats =
|
|
84
|
+
getSupportedOutputSampleFormats(avCodec);
|
|
85
|
+
|
|
86
|
+
// Find a sample format that the encoder supports. We prefer using FLT[P],
|
|
87
|
+
// since this is the format of the input samples. If FLTP isn't supported
|
|
88
|
+
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
|
|
89
|
+
// into the format with the highest resolution.
|
|
90
|
+
if (supportedSampleFormats == nullptr) {
|
|
91
|
+
// Can't really validate anything in this case, best we can do is hope that
|
|
92
|
+
// FLTP is supported by the encoder. If not, FFmpeg will raise.
|
|
93
|
+
return AV_SAMPLE_FMT_FLTP;
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
|
|
97
|
+
for (int i = 0; supportedSampleFormats[i] != -1; ++i) {
|
|
98
|
+
if (supportedSampleFormats[i] == preferredFormat) {
|
|
99
|
+
return preferredFormat;
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
// We should always find a match in preferredFormatsOrder, so we should always
|
|
104
|
+
// return earlier. But in the event that a future FFmpeg version defines an
|
|
105
|
+
// additional sample format that isn't in preferredFormatsOrder, we fallback:
|
|
106
|
+
return supportedSampleFormats[0];
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
void closeAVIOContext(
|
|
110
|
+
AVFormatContext* avFormatContext,
|
|
111
|
+
AVIOContextHolder* avioContextHolder) {
|
|
112
|
+
if (!avFormatContext || !avFormatContext->pb) {
|
|
113
|
+
return;
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
if (avFormatContext->pb->error == 0) {
|
|
117
|
+
avio_flush(avFormatContext->pb);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
if (!avioContextHolder) {
|
|
121
|
+
if (avFormatContext->pb->error == 0) {
|
|
122
|
+
avio_close(avFormatContext->pb);
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
avFormatContext->pb = nullptr;
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
} // namespace
|
|
130
|
+
|
|
131
|
+
AudioEncoder::~AudioEncoder() {
|
|
132
|
+
closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get());
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
AudioEncoder::AudioEncoder(
|
|
136
|
+
const torch::Tensor& samples,
|
|
137
|
+
int sampleRate,
|
|
138
|
+
std::string_view fileName,
|
|
139
|
+
const AudioStreamOptions& audioStreamOptions)
|
|
140
|
+
: samples_(validateSamples(samples)), inSampleRate_(sampleRate) {
|
|
141
|
+
setFFmpegLogLevel();
|
|
142
|
+
AVFormatContext* avFormatContext = nullptr;
|
|
143
|
+
int status = avformat_alloc_output_context2(
|
|
144
|
+
&avFormatContext, nullptr, nullptr, fileName.data());
|
|
145
|
+
|
|
146
|
+
TORCH_CHECK(
|
|
147
|
+
avFormatContext != nullptr,
|
|
148
|
+
"Couldn't allocate AVFormatContext. ",
|
|
149
|
+
"The destination file is ",
|
|
150
|
+
fileName,
|
|
151
|
+
", check the desired extension? ",
|
|
152
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
153
|
+
avFormatContext_.reset(avFormatContext);
|
|
154
|
+
|
|
155
|
+
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
|
|
156
|
+
TORCH_CHECK(
|
|
157
|
+
status >= 0,
|
|
158
|
+
"avio_open failed. The destination file is ",
|
|
159
|
+
fileName,
|
|
160
|
+
", make sure it's a valid path? ",
|
|
161
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
162
|
+
|
|
163
|
+
initializeEncoder(audioStreamOptions);
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
AudioEncoder::AudioEncoder(
|
|
167
|
+
const torch::Tensor& samples,
|
|
168
|
+
int sampleRate,
|
|
169
|
+
std::string_view formatName,
|
|
170
|
+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
|
|
171
|
+
const AudioStreamOptions& audioStreamOptions)
|
|
172
|
+
: samples_(validateSamples(samples)),
|
|
173
|
+
inSampleRate_(sampleRate),
|
|
174
|
+
avioContextHolder_(std::move(avioContextHolder)) {
|
|
175
|
+
setFFmpegLogLevel();
|
|
176
|
+
AVFormatContext* avFormatContext = nullptr;
|
|
177
|
+
int status = avformat_alloc_output_context2(
|
|
178
|
+
&avFormatContext, nullptr, formatName.data(), nullptr);
|
|
179
|
+
|
|
180
|
+
TORCH_CHECK(
|
|
181
|
+
avFormatContext != nullptr,
|
|
182
|
+
"Couldn't allocate AVFormatContext. ",
|
|
183
|
+
"Check the desired format? Got format=",
|
|
184
|
+
formatName,
|
|
185
|
+
". ",
|
|
186
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
187
|
+
avFormatContext_.reset(avFormatContext);
|
|
188
|
+
|
|
189
|
+
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
|
|
190
|
+
|
|
191
|
+
initializeEncoder(audioStreamOptions);
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
void AudioEncoder::initializeEncoder(
|
|
195
|
+
const AudioStreamOptions& audioStreamOptions) {
|
|
196
|
+
// We use the AVFormatContext's default codec for that
|
|
197
|
+
// specific format/container.
|
|
198
|
+
const AVCodec* avCodec =
|
|
199
|
+
avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
|
|
200
|
+
TORCH_CHECK(avCodec != nullptr, "Codec not found");
|
|
201
|
+
|
|
202
|
+
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
|
|
203
|
+
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
|
|
204
|
+
avCodecContext_.reset(avCodecContext);
|
|
205
|
+
|
|
206
|
+
auto desiredBitRate = audioStreamOptions.bitRate;
|
|
207
|
+
if (desiredBitRate.has_value()) {
|
|
208
|
+
TORCH_CHECK(
|
|
209
|
+
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
|
|
210
|
+
}
|
|
211
|
+
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
|
|
212
|
+
// well when "-b:a" isn't specified.
|
|
213
|
+
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
|
|
214
|
+
|
|
215
|
+
outNumChannels_ = static_cast<int>(
|
|
216
|
+
audioStreamOptions.numChannels.value_or(samples_.sizes()[0]));
|
|
217
|
+
validateNumChannels(*avCodec, outNumChannels_);
|
|
218
|
+
// The avCodecContext layout defines the layout of the encoded output, it's
|
|
219
|
+
// not related to the input sampes.
|
|
220
|
+
setDefaultChannelLayout(avCodecContext_, outNumChannels_);
|
|
221
|
+
|
|
222
|
+
outSampleRate_ = audioStreamOptions.sampleRate.value_or(inSampleRate_);
|
|
223
|
+
validateSampleRate(*avCodec, outSampleRate_);
|
|
224
|
+
avCodecContext_->sample_rate = outSampleRate_;
|
|
225
|
+
|
|
226
|
+
// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
|
|
227
|
+
// may need to convert the samples into a supported output sample format,
|
|
228
|
+
// which is what the `.sample_fmt` defines.
|
|
229
|
+
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
|
|
230
|
+
|
|
231
|
+
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
|
|
232
|
+
TORCH_CHECK(
|
|
233
|
+
status == AVSUCCESS,
|
|
234
|
+
"avcodec_open2 failed: ",
|
|
235
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
236
|
+
|
|
237
|
+
// We're allocating the stream here. Streams are meant to be freed by
|
|
238
|
+
// avformat_free_context(avFormatContext), which we call in the
|
|
239
|
+
// avFormatContext_'s destructor.
|
|
240
|
+
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
|
|
241
|
+
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
|
|
242
|
+
status = avcodec_parameters_from_context(
|
|
243
|
+
avStream->codecpar, avCodecContext_.get());
|
|
244
|
+
TORCH_CHECK(
|
|
245
|
+
status == AVSUCCESS,
|
|
246
|
+
"avcodec_parameters_from_context failed: ",
|
|
247
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
248
|
+
streamIndex_ = avStream->index;
|
|
249
|
+
|
|
250
|
+
// If sample rate conversion is needed and the encoder doesn't support
|
|
251
|
+
// variable frame size, we need to create an intermediate FIFO. See
|
|
252
|
+
// [Encoding loop, sample rate conversion and FIFO].
|
|
253
|
+
if (((avCodec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) == 0) &&
|
|
254
|
+
(inSampleRate_ != outSampleRate_)) {
|
|
255
|
+
// frame_size * 2 is a decent default size. FFmpeg automatically
|
|
256
|
+
// re-allocates the fifo if more space is needed.
|
|
257
|
+
auto avAudioFifo = av_audio_fifo_alloc(
|
|
258
|
+
avCodecContext_->sample_fmt,
|
|
259
|
+
outNumChannels_,
|
|
260
|
+
avCodecContext_->frame_size * 2);
|
|
261
|
+
TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo.");
|
|
262
|
+
avAudioFifo_.reset(avAudioFifo);
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
torch::Tensor AudioEncoder::encodeToTensor() {
|
|
267
|
+
TORCH_CHECK(
|
|
268
|
+
avioContextHolder_ != nullptr,
|
|
269
|
+
"Cannot encode to tensor, avio tensor context doesn't exist.");
|
|
270
|
+
encode();
|
|
271
|
+
auto avioToTensorContext =
|
|
272
|
+
dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
|
|
273
|
+
TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
|
|
274
|
+
return avioToTensorContext->getOutputTensor();
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
void AudioEncoder::encode() {
|
|
278
|
+
// To be on the safe side we enforce that encode() can only be called once on
|
|
279
|
+
// an encoder object. Whether this is actually necessary is unknown, so this
|
|
280
|
+
// may be relaxed if needed.
|
|
281
|
+
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
|
|
282
|
+
encodeWasCalled_ = true;
|
|
283
|
+
|
|
284
|
+
// Default to 256 like in torchaudio
|
|
285
|
+
int numSamplesAllocatedPerFrame =
|
|
286
|
+
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
|
|
287
|
+
UniqueAVFrame avFrame = allocateAVFrame(
|
|
288
|
+
numSamplesAllocatedPerFrame,
|
|
289
|
+
inSampleRate_,
|
|
290
|
+
static_cast<int>(samples_.sizes()[0]),
|
|
291
|
+
AV_SAMPLE_FMT_FLTP);
|
|
292
|
+
avFrame->pts = 0;
|
|
293
|
+
|
|
294
|
+
AutoAVPacket autoAVPacket;
|
|
295
|
+
|
|
296
|
+
uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr());
|
|
297
|
+
int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel
|
|
298
|
+
int numEncodedSamples = 0; // per channel
|
|
299
|
+
int numBytesPerSample = static_cast<int>(samples_.element_size());
|
|
300
|
+
int numBytesPerChannel = numSamples * numBytesPerSample;
|
|
301
|
+
|
|
302
|
+
auto status = avformat_write_header(avFormatContext_.get(), nullptr);
|
|
303
|
+
TORCH_CHECK(
|
|
304
|
+
status == AVSUCCESS,
|
|
305
|
+
"Error in avformat_write_header: ",
|
|
306
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
307
|
+
|
|
308
|
+
while (numEncodedSamples < numSamples) {
|
|
309
|
+
int numSamplesToEncode =
|
|
310
|
+
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
|
|
311
|
+
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
|
|
312
|
+
|
|
313
|
+
for (int ch = 0; ch < samples_.sizes()[0]; ch++) {
|
|
314
|
+
std::memcpy(
|
|
315
|
+
avFrame->data[ch],
|
|
316
|
+
psamples + ch * numBytesPerChannel,
|
|
317
|
+
numBytesToEncode);
|
|
318
|
+
}
|
|
319
|
+
psamples += numBytesToEncode;
|
|
320
|
+
|
|
321
|
+
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
|
|
322
|
+
// that the frame buffers are allocated to a big enough size. Here, we reset
|
|
323
|
+
// it to the exact number of samples that need to be encoded, otherwise the
|
|
324
|
+
// encoded frame would contain more samples than necessary and our results
|
|
325
|
+
// wouldn't match the ffmpeg CLI.
|
|
326
|
+
avFrame->nb_samples = numSamplesToEncode;
|
|
327
|
+
|
|
328
|
+
UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame);
|
|
329
|
+
encodeFrameThroughFifo(autoAVPacket, convertedAVFrame);
|
|
330
|
+
|
|
331
|
+
numEncodedSamples += numSamplesToEncode;
|
|
332
|
+
}
|
|
333
|
+
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
|
|
334
|
+
|
|
335
|
+
flushBuffers();
|
|
336
|
+
|
|
337
|
+
status = av_write_trailer(avFormatContext_.get());
|
|
338
|
+
TORCH_CHECK(
|
|
339
|
+
status == AVSUCCESS,
|
|
340
|
+
"Error in: av_write_trailer",
|
|
341
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
342
|
+
|
|
343
|
+
closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get());
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
|
|
347
|
+
if (static_cast<AVSampleFormat>(avFrame->format) ==
|
|
348
|
+
avCodecContext_->sample_fmt &&
|
|
349
|
+
getNumChannels(avFrame) == outNumChannels_ &&
|
|
350
|
+
avFrame->sample_rate == outSampleRate_) {
|
|
351
|
+
// Note: the clone references the same underlying data, it's a cheap copy.
|
|
352
|
+
return UniqueAVFrame(av_frame_clone(avFrame.get()));
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
if (!swrContext_) {
|
|
356
|
+
swrContext_.reset(createSwrContext(
|
|
357
|
+
static_cast<AVSampleFormat>(avFrame->format),
|
|
358
|
+
avCodecContext_->sample_fmt,
|
|
359
|
+
avFrame->sample_rate,
|
|
360
|
+
outSampleRate_,
|
|
361
|
+
avFrame,
|
|
362
|
+
outNumChannels_));
|
|
363
|
+
}
|
|
364
|
+
UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples(
|
|
365
|
+
swrContext_,
|
|
366
|
+
avFrame,
|
|
367
|
+
avCodecContext_->sample_fmt,
|
|
368
|
+
outSampleRate_,
|
|
369
|
+
outNumChannels_);
|
|
370
|
+
|
|
371
|
+
if (avFrame->sample_rate == outSampleRate_) {
|
|
372
|
+
TORCH_CHECK(
|
|
373
|
+
convertedAVFrame->nb_samples == avFrame->nb_samples,
|
|
374
|
+
"convertedAVFrame->nb_samples=",
|
|
375
|
+
convertedAVFrame->nb_samples,
|
|
376
|
+
" differs from ",
|
|
377
|
+
"avFrame->nb_samples=",
|
|
378
|
+
avFrame->nb_samples,
|
|
379
|
+
"This is unexpected, please report on the TorchCodec bug tracker.");
|
|
380
|
+
}
|
|
381
|
+
return convertedAVFrame;
|
|
382
|
+
}
|
|
383
|
+
|
|
384
|
+
void AudioEncoder::encodeFrameThroughFifo(
|
|
385
|
+
AutoAVPacket& autoAVPacket,
|
|
386
|
+
const UniqueAVFrame& avFrame,
|
|
387
|
+
// flushFifo is only set to true in maybeFlushSwrBuffers(), i.e. at the very
|
|
388
|
+
// end of the encoding process when we're flushing buffers. We also want to
|
|
389
|
+
// flush the FIFO so as to not leave any remaining samples in it.
|
|
390
|
+
bool flushFifo) {
|
|
391
|
+
if (avAudioFifo_ == nullptr) {
|
|
392
|
+
encodeFrame(autoAVPacket, avFrame);
|
|
393
|
+
return;
|
|
394
|
+
}
|
|
395
|
+
int numSamplesWritten = av_audio_fifo_write(
|
|
396
|
+
avAudioFifo_.get(),
|
|
397
|
+
reinterpret_cast<void**>(avFrame->data),
|
|
398
|
+
avFrame->nb_samples);
|
|
399
|
+
TORCH_CHECK(
|
|
400
|
+
numSamplesWritten == avFrame->nb_samples,
|
|
401
|
+
"Tried to write ",
|
|
402
|
+
avFrame->nb_samples,
|
|
403
|
+
" samples, but only wrote ",
|
|
404
|
+
numSamplesWritten);
|
|
405
|
+
|
|
406
|
+
UniqueAVFrame newavFrame = allocateAVFrame(
|
|
407
|
+
avCodecContext_->frame_size,
|
|
408
|
+
outSampleRate_,
|
|
409
|
+
outNumChannels_,
|
|
410
|
+
avCodecContext_->sample_fmt);
|
|
411
|
+
|
|
412
|
+
// Explaining the while bound:
|
|
413
|
+
// - if we're not flushing the FIFO, i.e. in most cases, we want to pull
|
|
414
|
+
// exactly `frame_size` samples from the FIFO, so we have to stop before it
|
|
415
|
+
// contains less than `frame_size` samples.
|
|
416
|
+
// - if we're flushing the FIFO, we want to read from the FIFO until the very
|
|
417
|
+
// last sample it contains.
|
|
418
|
+
//
|
|
419
|
+
// In both cases, for as long as we can, we're trying to pull exatly
|
|
420
|
+
// `frame_size` samples from the FIFO and send each `frame_size`-sized avFrame
|
|
421
|
+
// to encodeFrame(). Only the very last avFrame of the encoding process is
|
|
422
|
+
// allowed to contained less than frame_size samples. That only happens when
|
|
423
|
+
// flushFifo is true.
|
|
424
|
+
while (av_audio_fifo_size(avAudioFifo_.get()) >=
|
|
425
|
+
(flushFifo ? 1 : avCodecContext_->frame_size)) {
|
|
426
|
+
int samplesToRead = std::min(
|
|
427
|
+
av_audio_fifo_size(avAudioFifo_.get()), newavFrame->nb_samples);
|
|
428
|
+
int numSamplesRead = av_audio_fifo_read(
|
|
429
|
+
avAudioFifo_.get(),
|
|
430
|
+
reinterpret_cast<void**>(newavFrame->data),
|
|
431
|
+
samplesToRead);
|
|
432
|
+
TORCH_CHECK(
|
|
433
|
+
numSamplesRead == samplesToRead,
|
|
434
|
+
"Tried to read ",
|
|
435
|
+
samplesToRead,
|
|
436
|
+
" samples, but only read ",
|
|
437
|
+
numSamplesRead);
|
|
438
|
+
|
|
439
|
+
newavFrame->nb_samples = numSamplesRead;
|
|
440
|
+
encodeFrame(autoAVPacket, newavFrame);
|
|
441
|
+
}
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
void AudioEncoder::encodeFrame(
|
|
445
|
+
AutoAVPacket& autoAVPacket,
|
|
446
|
+
const UniqueAVFrame& avFrame) {
|
|
447
|
+
if (avFrame != nullptr) {
|
|
448
|
+
avFrame->pts = lastEncodedAVFramePts_;
|
|
449
|
+
lastEncodedAVFramePts_ += avFrame->nb_samples;
|
|
450
|
+
}
|
|
451
|
+
|
|
452
|
+
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
|
|
453
|
+
TORCH_CHECK(
|
|
454
|
+
status == AVSUCCESS,
|
|
455
|
+
"Error while sending frame: ",
|
|
456
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
457
|
+
|
|
458
|
+
while (status >= 0) {
|
|
459
|
+
ReferenceAVPacket packet(autoAVPacket);
|
|
460
|
+
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
|
|
461
|
+
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
|
|
462
|
+
if (status == AVERROR_EOF) {
|
|
463
|
+
// Flush the packets that were potentially buffered by
|
|
464
|
+
// av_interleaved_write_frame(). See corresponding block in
|
|
465
|
+
// TorchAudio:
|
|
466
|
+
// https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
|
|
467
|
+
status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
|
|
468
|
+
TORCH_CHECK(
|
|
469
|
+
status == AVSUCCESS,
|
|
470
|
+
"Failed to flush packet: ",
|
|
471
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
472
|
+
}
|
|
473
|
+
return;
|
|
474
|
+
}
|
|
475
|
+
TORCH_CHECK(
|
|
476
|
+
status >= 0,
|
|
477
|
+
"Error receiving packet: ",
|
|
478
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
479
|
+
|
|
480
|
+
packet->stream_index = streamIndex_;
|
|
481
|
+
|
|
482
|
+
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
|
|
483
|
+
TORCH_CHECK(
|
|
484
|
+
status == AVSUCCESS,
|
|
485
|
+
"Error in av_interleaved_write_frame: ",
|
|
486
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) {
|
|
491
|
+
// Similar to the decoder's method with the same name, but for encoding this
|
|
492
|
+
// time. That is, when sample conversion is involved, libswresample may have
|
|
493
|
+
// buffered some samples that we now need to flush and send to the encoder.
|
|
494
|
+
if (swrContext_ == nullptr && inSampleRate_ == outSampleRate_) {
|
|
495
|
+
return;
|
|
496
|
+
}
|
|
497
|
+
TORCH_CHECK(
|
|
498
|
+
swrContext_ != nullptr,
|
|
499
|
+
"swrContext is null, but sample rate conversion is needed. ",
|
|
500
|
+
"This is unexpected, please report on the TorchCodec bug tracker.");
|
|
501
|
+
|
|
502
|
+
int numRemainingSamples = // this is an upper bound
|
|
503
|
+
swr_get_out_samples(swrContext_.get(), 0);
|
|
504
|
+
if (numRemainingSamples == 0) {
|
|
505
|
+
return;
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
UniqueAVFrame avFrame = allocateAVFrame(
|
|
509
|
+
numRemainingSamples,
|
|
510
|
+
outSampleRate_,
|
|
511
|
+
outNumChannels_,
|
|
512
|
+
avCodecContext_->sample_fmt);
|
|
513
|
+
int actualNumRemainingSamples = swr_convert(
|
|
514
|
+
swrContext_.get(), avFrame->data, avFrame->nb_samples, nullptr, 0);
|
|
515
|
+
avFrame->nb_samples = actualNumRemainingSamples;
|
|
516
|
+
|
|
517
|
+
// We're potentially sending avFrame through the FIFO (if it exists), in which
|
|
518
|
+
// case we also want to flush the FIFO itself.
|
|
519
|
+
encodeFrameThroughFifo(autoAVPacket, avFrame, /*flushFifo=*/true);
|
|
520
|
+
}
|
|
521
|
+
|
|
522
|
+
void AudioEncoder::flushBuffers() {
|
|
523
|
+
AutoAVPacket autoAVPacket;
|
|
524
|
+
maybeFlushSwrBuffers(autoAVPacket);
|
|
525
|
+
|
|
526
|
+
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
namespace {
|
|
530
|
+
|
|
531
|
+
torch::Tensor validateFrames(const torch::Tensor& frames) {
|
|
532
|
+
TORCH_CHECK(
|
|
533
|
+
frames.dtype() == torch::kUInt8,
|
|
534
|
+
"frames must have uint8 dtype, got ",
|
|
535
|
+
frames.dtype());
|
|
536
|
+
TORCH_CHECK(
|
|
537
|
+
frames.dim() == 4,
|
|
538
|
+
"frames must have 4 dimensions (N, C, H, W), got ",
|
|
539
|
+
frames.dim());
|
|
540
|
+
TORCH_CHECK(
|
|
541
|
+
frames.sizes()[1] == 3,
|
|
542
|
+
"frame must have 3 channels (R, G, B), got ",
|
|
543
|
+
frames.sizes()[1]);
|
|
544
|
+
return frames.contiguous();
|
|
545
|
+
}
|
|
546
|
+
|
|
547
|
+
AVPixelFormat validatePixelFormat(
|
|
548
|
+
const AVCodec& avCodec,
|
|
549
|
+
const std::string& targetPixelFormat) {
|
|
550
|
+
AVPixelFormat pixelFormat = av_get_pix_fmt(targetPixelFormat.c_str());
|
|
551
|
+
|
|
552
|
+
// Validate that the encoder supports this pixel format
|
|
553
|
+
const AVPixelFormat* supportedFormats = getSupportedPixelFormats(avCodec);
|
|
554
|
+
if (supportedFormats != nullptr) {
|
|
555
|
+
for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
|
|
556
|
+
if (supportedFormats[i] == pixelFormat) {
|
|
557
|
+
return pixelFormat;
|
|
558
|
+
}
|
|
559
|
+
}
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
std::stringstream errorMsg;
|
|
563
|
+
// av_get_pix_fmt failed to find a pix_fmt
|
|
564
|
+
if (pixelFormat == AV_PIX_FMT_NONE) {
|
|
565
|
+
errorMsg << "Unknown pixel format: " << targetPixelFormat;
|
|
566
|
+
} else {
|
|
567
|
+
errorMsg << "Specified pixel format " << targetPixelFormat
|
|
568
|
+
<< " is not supported by the " << avCodec.name << " encoder.";
|
|
569
|
+
}
|
|
570
|
+
// Build error message, similar to FFmpeg's error log
|
|
571
|
+
errorMsg << "\nSupported pixel formats for " << avCodec.name << ":";
|
|
572
|
+
for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
|
|
573
|
+
errorMsg << " " << av_get_pix_fmt_name(supportedFormats[i]);
|
|
574
|
+
}
|
|
575
|
+
TORCH_CHECK(false, errorMsg.str());
|
|
576
|
+
}
|
|
577
|
+
|
|
578
|
+
void tryToValidateCodecOption(
|
|
579
|
+
const AVCodec& avCodec,
|
|
580
|
+
const char* optionName,
|
|
581
|
+
const std::string& value) {
|
|
582
|
+
if (!avCodec.priv_class) {
|
|
583
|
+
return;
|
|
584
|
+
}
|
|
585
|
+
const AVOption* option = av_opt_find2(
|
|
586
|
+
// Convert obj arg from const AVClass* const* to non-const void*
|
|
587
|
+
// First cast to remove const, then cast to void*
|
|
588
|
+
const_cast<void*>(static_cast<const void*>(&avCodec.priv_class)),
|
|
589
|
+
optionName,
|
|
590
|
+
nullptr,
|
|
591
|
+
0,
|
|
592
|
+
AV_OPT_SEARCH_FAKE_OBJ,
|
|
593
|
+
nullptr);
|
|
594
|
+
// If option is not found we cannot validate it, let FFmpeg handle it
|
|
595
|
+
if (!option) {
|
|
596
|
+
return;
|
|
597
|
+
}
|
|
598
|
+
// Validate if option is defined as a numeric type
|
|
599
|
+
if (option->type == AV_OPT_TYPE_INT || option->type == AV_OPT_TYPE_INT64 ||
|
|
600
|
+
option->type == AV_OPT_TYPE_FLOAT || option->type == AV_OPT_TYPE_DOUBLE) {
|
|
601
|
+
try {
|
|
602
|
+
double numericValue = std::stod(value);
|
|
603
|
+
TORCH_CHECK(
|
|
604
|
+
numericValue >= option->min && numericValue <= option->max,
|
|
605
|
+
optionName,
|
|
606
|
+
"=",
|
|
607
|
+
numericValue,
|
|
608
|
+
" is out of valid range [",
|
|
609
|
+
option->min,
|
|
610
|
+
", ",
|
|
611
|
+
option->max,
|
|
612
|
+
"] for this codec. For more details, run 'ffmpeg -h encoder=",
|
|
613
|
+
avCodec.name,
|
|
614
|
+
"'");
|
|
615
|
+
} catch (const std::invalid_argument&) {
|
|
616
|
+
TORCH_CHECK(
|
|
617
|
+
false,
|
|
618
|
+
"Option ",
|
|
619
|
+
optionName,
|
|
620
|
+
" expects a numeric value but got '",
|
|
621
|
+
value,
|
|
622
|
+
"'");
|
|
623
|
+
}
|
|
624
|
+
}
|
|
625
|
+
}
|
|
626
|
+
|
|
627
|
+
void sortCodecOptions(
|
|
628
|
+
const std::map<std::string, std::string>& extraOptions,
|
|
629
|
+
UniqueAVDictionary& codecDict,
|
|
630
|
+
UniqueAVDictionary& formatDict) {
|
|
631
|
+
// Accepts a map of options as input, then sorts them into codec options and
|
|
632
|
+
// format options. The sorted options are returned into two separate dicts.
|
|
633
|
+
const AVClass* formatClass = avformat_get_class();
|
|
634
|
+
for (const auto& [key, value] : extraOptions) {
|
|
635
|
+
const AVOption* fmtOpt = av_opt_find2(
|
|
636
|
+
&formatClass,
|
|
637
|
+
key.c_str(),
|
|
638
|
+
nullptr,
|
|
639
|
+
0,
|
|
640
|
+
AV_OPT_SEARCH_CHILDREN | AV_OPT_SEARCH_FAKE_OBJ,
|
|
641
|
+
nullptr);
|
|
642
|
+
if (fmtOpt) {
|
|
643
|
+
av_dict_set(formatDict.getAddress(), key.c_str(), value.c_str(), 0);
|
|
644
|
+
} else {
|
|
645
|
+
// Default to codec option (includes AVCodecContext + encoder-private)
|
|
646
|
+
av_dict_set(codecDict.getAddress(), key.c_str(), value.c_str(), 0);
|
|
647
|
+
}
|
|
648
|
+
}
|
|
649
|
+
}
|
|
650
|
+
} // namespace
|
|
651
|
+
|
|
652
|
+
VideoEncoder::~VideoEncoder() {
|
|
653
|
+
closeAVIOContext(avFormatContext_.get(), avioContextHolder_.get());
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
// File-path constructor: encodes `frames` into the file at `fileName`,
// letting FFmpeg infer the container format from the file extension.
//
// frames:             validated by validateFrames() before being stored.
// frameRate:          input frame rate, also used as the output frame rate.
// fileName:           destination path; its extension selects the muxer.
// videoStreamOptions: codec/crf/preset/pixel-format/extra options.
VideoEncoder::VideoEncoder(
    const torch::Tensor& frames,
    double frameRate,
    std::string_view fileName,
    const VideoStreamOptions& videoStreamOptions)
    : frames_(validateFrames(frames)), inFrameRate_(frameRate) {
  setFFmpegLogLevel();

  // std::string_view::data() is NOT guaranteed to be NUL-terminated, but the
  // FFmpeg C APIs below require a C string. Make an owned, NUL-terminated
  // copy instead of passing fileName.data() directly.
  const std::string fileNameStr(fileName);

  // Allocate output format context; the muxer is chosen from the extension.
  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "The destination file is ",
      fileName,
      ", check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  // Open the output file for writing; avFormatContext_ owns pb from here on.
  status =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      status >= 0,
      "avio_open failed. The destination file is ",
      fileName,
      ", make sure it's a valid path? ",
      getFFMPEGErrorStringFromErrorCode(status));
  initializeEncoder(videoStreamOptions);
}
|
|
687
|
+
|
|
688
|
+
// AVIO constructor: encodes `frames` into a caller-provided AVIO context
// (e.g. an in-memory tensor sink) instead of a file on disk. The container
// is selected explicitly via `formatName` since there is no file extension.
//
// frames:             validated by validateFrames() before being stored.
// frameRate:          input frame rate, also used as the output frame rate.
// formatName:         container name ("mp4", "webm", "mkv", ...).
// avioContextHolder:  ownership is transferred to this encoder.
// videoStreamOptions: codec/crf/preset/pixel-format/extra options.
VideoEncoder::VideoEncoder(
    const torch::Tensor& frames,
    double frameRate,
    std::string_view formatName,
    std::unique_ptr<AVIOContextHolder> avioContextHolder,
    const VideoStreamOptions& videoStreamOptions)
    : frames_(validateFrames(frames)),
      inFrameRate_(frameRate),
      avioContextHolder_(std::move(avioContextHolder)) {
  setFFmpegLogLevel();
  // Map mkv -> matroska when used as format name. Also make an owned,
  // NUL-terminated copy: std::string_view::data() is not guaranteed to be
  // NUL-terminated, but avformat_alloc_output_context2 expects a C string.
  const std::string formatNameStr(
      formatName == "mkv" ? std::string_view("matroska") : formatName);
  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, formatNameStr.c_str(), nullptr);

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "Check the desired format? Got format=",
      formatNameStr,
      ". ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  // Route all muxer output through the caller-provided AVIO context.
  avFormatContext_->pb = avioContextHolder_->getAVIOContext();

  initializeEncoder(videoStreamOptions);
}
|
|
717
|
+
|
|
718
|
+
// Shared setup for both constructors: selects the encoder codec, derives
// input/output geometry and pixel formats from the frames tensor and the
// options, builds the codec/format option dictionaries, opens the codec,
// and creates the output stream. Must run before encode().
void VideoEncoder::initializeEncoder(
    const VideoStreamOptions& videoStreamOptions) {
  // Only create device interface when frames are on a CUDA device.
  // Encoding on CPU is implemented in this file.
  if (frames_.device().is_cuda()) {
    deviceInterface_ = createDeviceInterface(frames_.device());
  }
  const AVCodec* avCodec = nullptr;
  // If codec arg is provided, find codec using logic similar to FFmpeg:
  // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835
  if (videoStreamOptions.codec.has_value()) {
    const std::string& codec = videoStreamOptions.codec.value();
    // Try to find codec by name ("libx264", "libsvtav1")
    avCodec = avcodec_find_encoder_by_name(codec.c_str());
    // Try to find by codec descriptor ("h264", "av1")
    if (!avCodec) {
      const AVCodecDescriptor* desc =
          avcodec_descriptor_get_by_name(codec.c_str());
      if (desc) {
        avCodec = avcodec_find_encoder(desc->id);
      }
    }
  } else {
    // No explicit codec: fall back to the container's default video codec.
    TORCH_CHECK(
        avFormatContext_->oformat != nullptr,
        "Output format is null, unable to find default codec.");
    // If frames are on a CUDA device, try to substitute the default codec
    // with its hardware equivalent
    if (frames_.device().is_cuda()) {
      TORCH_CHECK(
          deviceInterface_ != nullptr,
          "Device interface is undefined when input frames are on a CUDA device. This should never happen, please report this to the TorchCodec repo.");
      auto hwCodec = deviceInterface_->findCodec(
          avFormatContext_->oformat->video_codec, /*isDecoder=*/false);
      if (hwCodec.has_value()) {
        avCodec = hwCodec.value();
      }
    }
    // CPU path, or no hardware equivalent found: use the software encoder.
    if (!avCodec) {
      avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
    }
  }
  TORCH_CHECK(
      avCodec != nullptr,
      "Video codec ",
      videoStreamOptions.codec.has_value()
          ? videoStreamOptions.codec.value() + " "
          : "",
      "not found. To see available codecs, run: ffmpeg -encoders");

  AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
  TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
  avCodecContext_.reset(avCodecContext);

  // Store dimension order and input pixel format
  // TODO-VideoEncoder: Remove assumption that tensor in NCHW format
  auto sizes = frames_.sizes();
  inPixelFormat_ = AV_PIX_FMT_GBRP;
  inHeight_ = static_cast<int>(sizes[2]);
  inWidth_ = static_cast<int>(sizes[3]);

  // Use specified dimensions or input dimensions
  // TODO-VideoEncoder: Allow height and width to be set
  outWidth_ = inWidth_;
  outHeight_ = inHeight_;

  if (videoStreamOptions.pixelFormat.has_value()) {
    // TODO-VideoEncoder: Enable pixel formats to be set by user
    // and handled with the appropriate NPP function on GPU.
    if (frames_.device().is_cuda()) {
      TORCH_CHECK(
          false,
          "Video encoding on GPU currently only supports the nv12 pixel format. "
          "Do not set pixel_format to use nv12 by default.");
    }
    outPixelFormat_ =
        validatePixelFormat(*avCodec, videoStreamOptions.pixelFormat.value());
  } else {
    if (frames_.device().is_cuda()) {
      // Default to nv12 pixel format when encoding on GPU.
      outPixelFormat_ = DeviceInterface::CUDA_ENCODING_PIXEL_FORMAT;
    } else {
      const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec);
      // Use first listed pixel format as default (often yuv420p).
      // This is similar to FFmpeg's logic:
      // https://www.ffmpeg.org/doxygen/4.0/decode_8c_source.html#l01087
      // If pixel formats are undefined for some reason, try yuv420p
      outPixelFormat_ = (formats && formats[0] != AV_PIX_FMT_NONE)
          ? formats[0]
          : AV_PIX_FMT_YUV420P;
    }
  }

  // Configure codec parameters
  avCodecContext_->codec_id = avCodec->id;
  avCodecContext_->width = outWidth_;
  avCodecContext_->height = outHeight_;
  avCodecContext_->pix_fmt = outPixelFormat_;
  // TODO-VideoEncoder: Add and utilize output frame_rate option
  // time_base is the inverse of the frame rate, so pts can be frame indices.
  avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX);
  avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate);

  // Set flag for containers that require extradata to be in the codec context
  if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
    avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
  }

  // Apply videoStreamOptions
  UniqueAVDictionary avCodecOptions;
  if (videoStreamOptions.extraOptions.has_value()) {
    // Validate each option's value first, then split them between the codec
    // dictionary (used by avcodec_open2 below) and avFormatOptions_ (used
    // later by avformat_write_header in encode()).
    for (const auto& [key, value] : videoStreamOptions.extraOptions.value()) {
      tryToValidateCodecOption(*avCodec, key.c_str(), value);
    }
    sortCodecOptions(
        videoStreamOptions.extraOptions.value(),
        avCodecOptions,
        avFormatOptions_);
  }

  if (videoStreamOptions.crf.has_value()) {
    std::string crfValue = std::to_string(videoStreamOptions.crf.value());
    tryToValidateCodecOption(*avCodec, "crf", crfValue);
    av_dict_set(avCodecOptions.getAddress(), "crf", crfValue.c_str(), 0);
  }
  if (videoStreamOptions.preset.has_value()) {
    // Preset is passed through unvalidated; the codec rejects bad values at
    // avcodec_open2 time.
    av_dict_set(
        avCodecOptions.getAddress(),
        "preset",
        videoStreamOptions.preset.value().c_str(),
        0);
  }

  // When frames are on a CUDA device, deviceInterface_ will be defined.
  if (frames_.device().is_cuda() && deviceInterface_) {
    deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
    deviceInterface_->setupHardwareFrameContextForEncoding(
        avCodecContext_.get());
  }

  int status = avcodec_open2(
      avCodecContext_.get(), avCodec, avCodecOptions.getAddress());

  TORCH_CHECK(
      status == AVSUCCESS,
      "avcodec_open2 failed: ",
      getFFMPEGErrorStringFromErrorCode(status));

  avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
  TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");

  // Set the stream time base to encode correct frame timestamps
  avStream_->time_base = avCodecContext_->time_base;
  // Set the stream frame rate to store correct frame durations for some
  // containers (webm, mkv)
  avStream_->r_frame_rate = avCodecContext_->framerate;

  // Copy the opened codec's parameters onto the stream for the muxer.
  status = avcodec_parameters_from_context(
      avStream_->codecpar, avCodecContext_.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "avcodec_parameters_from_context failed: ",
      getFFMPEGErrorStringFromErrorCode(status));
}
|
|
881
|
+
|
|
882
|
+
// Drives the full encode: writes the container header, converts and encodes
// each frame in frames_, flushes the encoder, then writes the trailer.
// One-shot: a second call is rejected.
void VideoEncoder::encode() {
  // To be on the safe side we enforce that encode() can only be called once
  TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
  encodeWasCalled_ = true;

  // avFormatOptions_ was populated from extraOptions in initializeEncoder();
  // avformat_write_header consumes the entries it recognizes.
  int status = avformat_write_header(
      avFormatContext_.get(), avFormatOptions_.getAddress());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error in avformat_write_header: ",
      getFFMPEGErrorStringFromErrorCode(status));

  AutoAVPacket autoAVPacket;
  // frames_ is indexed on dim 0; each slice is one frame (NCHW assumption).
  int numFrames = static_cast<int>(frames_.sizes()[0]);
  for (int i = 0; i < numFrames; ++i) {
    torch::Tensor currFrame = frames_[i];
    UniqueAVFrame avFrame;
    if (frames_.device().is_cuda() && deviceInterface_) {
      // GPU path: the device interface produces a hardware AVFrame directly.
      auto cudaFrame = deviceInterface_->convertCUDATensorToAVFrameForEncoding(
          currFrame, i, avCodecContext_.get());
      TORCH_CHECK(
          cudaFrame != nullptr,
          "convertCUDATensorToAVFrameForEncoding failed for frame ",
          i,
          " on device: ",
          frames_.device());
      avFrame = std::move(cudaFrame);
    } else {
      // CPU path: convert/scale the tensor into the output pixel format.
      avFrame = convertTensorToAVFrame(currFrame, i);
    }
    encodeFrame(autoAVPacket, avFrame);
  }

  // Drain any frames still buffered inside the encoder.
  flushBuffers();

  status = av_write_trailer(avFormatContext_.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error in av_write_trailer: ",
      getFFMPEGErrorStringFromErrorCode(status));
}
|
|
923
|
+
|
|
924
|
+
// Converts one CHW uint8 RGB tensor into an AVFrame in outPixelFormat_,
// scaling/converting via swscale. `frameIndex` becomes the frame's pts
// (valid because time_base == 1/framerate).
// NOTE(review): the plane pointers below alias the tensor's memory directly,
// which assumes `frame` is contiguous — presumably guaranteed by
// validateFrames(); confirm if that helper changes.
UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
    const torch::Tensor& frame,
    int frameIndex) {
  // Initialize and cache scaling context if it does not exist
  if (!swsContext_) {
    swsContext_.reset(sws_getContext(
        inWidth_,
        inHeight_,
        inPixelFormat_,
        outWidth_,
        outHeight_,
        outPixelFormat_,
        SWS_BICUBIC, // Used by FFmpeg CLI
        nullptr,
        nullptr,
        nullptr));
    TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
  }

  UniqueAVFrame avFrame(av_frame_alloc());
  TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");

  // Set output frame properties
  avFrame->format = outPixelFormat_;
  avFrame->width = outWidth_;
  avFrame->height = outHeight_;
  avFrame->pts = frameIndex;

  // Allocate the destination buffers that sws_scale will write into.
  int status = av_frame_get_buffer(avFrame.get(), 0);
  TORCH_CHECK(status >= 0, "Failed to allocate frame buffer");

  // Need to convert/scale the frame
  // Create temporary frame with input format
  UniqueAVFrame inputFrame(av_frame_alloc());
  TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame");

  inputFrame->format = inPixelFormat_;
  inputFrame->width = inWidth_;
  inputFrame->height = inHeight_;

  // No av_frame_get_buffer here: the input frame borrows the tensor's
  // memory for the duration of sws_scale (no copy, no ownership).
  uint8_t* tensorData = static_cast<uint8_t*>(frame.data_ptr());

  // TODO-VideoEncoder: Reorder tensor if in NHWC format
  int channelSize = inHeight_ * inWidth_;
  // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
  // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
  // GBRP plane order is data[0]=G, data[1]=B, data[2]=R; the tensor's plane
  // order is R (offset 0), G (offset channelSize), B (offset 2*channelSize).
  inputFrame->data[0] = tensorData + channelSize;
  inputFrame->data[1] = tensorData + (2 * channelSize);
  inputFrame->data[2] = tensorData;

  // Planes are tightly packed: one byte per pixel per plane.
  inputFrame->linesize[0] = inWidth_;
  inputFrame->linesize[1] = inWidth_;
  inputFrame->linesize[2] = inWidth_;

  status = sws_scale(
      swsContext_.get(),
      inputFrame->data,
      inputFrame->linesize,
      0,
      inputFrame->height,
      avFrame->data,
      avFrame->linesize);
  // sws_scale returns the number of output rows written on success.
  TORCH_CHECK(status == outHeight_, "sws_scale failed");
  return avFrame;
}
|
|
989
|
+
|
|
990
|
+
// Runs the full encode and returns the encoded bytes as a tensor. Only valid
// when this encoder was constructed with a tensor-backed AVIO context.
torch::Tensor VideoEncoder::encodeToTensor() {
  TORCH_CHECK(
      avioContextHolder_ != nullptr,
      "Cannot encode to tensor, avio tensor context doesn't exist.");
  encode();
  // The holder must be the tensor-backed flavor; downcast to retrieve the
  // bytes written during encode().
  auto* tensorContext =
      dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
  TORCH_CHECK(tensorContext != nullptr, "Invalid AVIO context holder.");
  return tensorContext->getOutputTensor();
}
|
|
1000
|
+
|
|
1001
|
+
// Sends one frame to the encoder and writes every packet it produces to the
// muxer. Called with a null avFrame to drain the encoder at end of stream
// (see flushBuffers()).
void VideoEncoder::encodeFrame(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  // A null frame here signals end-of-input to the encoder.
  auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(status));

  // Drain all packets the encoder has ready after this send.
  while (status >= 0) {
    ReferenceAVPacket packet(autoAVPacket);
    status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
    if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
      // EAGAIN: encoder needs more input; EOF: fully drained after a flush.
      if (status == AVERROR_EOF) {
        // Flush remaining buffered packets
        status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
        TORCH_CHECK(
            status == AVSUCCESS,
            "Failed to flush packet: ",
            getFFMPEGErrorStringFromErrorCode(status));
      }
      return;
    }
    TORCH_CHECK(
        status >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(status));

    // The code below is borrowed from torchaudio:
    // https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46
    // Setting packet->duration to 1 allows the last frame to be properly
    // encoded, and needs to be set before calling av_packet_rescale_ts.
    if (packet->duration == 0) {
      packet->duration = 1;
    }
    // Convert packet timestamps from codec time base to stream time base.
    av_packet_rescale_ts(
        packet.get(), avCodecContext_->time_base, avStream_->time_base);
    packet->stream_index = avStream_->index;

    status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        status == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(status));
  }
}
|
|
1047
|
+
|
|
1048
|
+
void VideoEncoder::flushBuffers() {
|
|
1049
|
+
AutoAVPacket autoAVPacket;
|
|
1050
|
+
// Send null frame to signal end of input
|
|
1051
|
+
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
|
|
1052
|
+
}
|
|
1053
|
+
|
|
1054
|
+
} // namespace facebook::torchcodec
|