torchcodec 0.7.0__cp310-cp310-win_amd64.whl → 0.8.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchcodec might be problematic. Click here for more details.
- torchcodec/_core/BetaCudaDeviceInterface.cpp +636 -0
- torchcodec/_core/BetaCudaDeviceInterface.h +191 -0
- torchcodec/_core/CMakeLists.txt +36 -3
- torchcodec/_core/CUDACommon.cpp +315 -0
- torchcodec/_core/CUDACommon.h +46 -0
- torchcodec/_core/CpuDeviceInterface.cpp +189 -108
- torchcodec/_core/CpuDeviceInterface.h +81 -19
- torchcodec/_core/CudaDeviceInterface.cpp +211 -368
- torchcodec/_core/CudaDeviceInterface.h +33 -6
- torchcodec/_core/DeviceInterface.cpp +57 -19
- torchcodec/_core/DeviceInterface.h +97 -16
- torchcodec/_core/Encoder.cpp +302 -9
- torchcodec/_core/Encoder.h +51 -1
- torchcodec/_core/FFMPEGCommon.cpp +189 -2
- torchcodec/_core/FFMPEGCommon.h +18 -0
- torchcodec/_core/FilterGraph.cpp +28 -21
- torchcodec/_core/FilterGraph.h +15 -1
- torchcodec/_core/Frame.cpp +17 -7
- torchcodec/_core/Frame.h +15 -61
- torchcodec/_core/Metadata.h +2 -2
- torchcodec/_core/NVDECCache.cpp +70 -0
- torchcodec/_core/NVDECCache.h +104 -0
- torchcodec/_core/SingleStreamDecoder.cpp +202 -198
- torchcodec/_core/SingleStreamDecoder.h +39 -14
- torchcodec/_core/StreamOptions.h +16 -6
- torchcodec/_core/Transform.cpp +60 -0
- torchcodec/_core/Transform.h +59 -0
- torchcodec/_core/__init__.py +1 -0
- torchcodec/_core/custom_ops.cpp +180 -32
- torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
- torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
- torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
- torchcodec/_core/ops.py +86 -43
- torchcodec/_core/pybind_ops.cpp +22 -59
- torchcodec/_samplers/video_clip_sampler.py +7 -19
- torchcodec/decoders/__init__.py +1 -0
- torchcodec/decoders/_decoder_utils.py +61 -1
- torchcodec/decoders/_video_decoder.py +56 -20
- torchcodec/libtorchcodec_core4.dll +0 -0
- torchcodec/libtorchcodec_core5.dll +0 -0
- torchcodec/libtorchcodec_core6.dll +0 -0
- torchcodec/libtorchcodec_core7.dll +0 -0
- torchcodec/libtorchcodec_core8.dll +0 -0
- torchcodec/libtorchcodec_custom_ops4.dll +0 -0
- torchcodec/libtorchcodec_custom_ops5.dll +0 -0
- torchcodec/libtorchcodec_custom_ops6.dll +0 -0
- torchcodec/libtorchcodec_custom_ops7.dll +0 -0
- torchcodec/libtorchcodec_custom_ops8.dll +0 -0
- torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
- torchcodec/samplers/_time_based.py +8 -0
- torchcodec/version.py +1 -1
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/METADATA +24 -13
- torchcodec-0.8.0.dist-info/RECORD +80 -0
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/WHEEL +1 -1
- torchcodec-0.7.0.dist-info/RECORD +0 -67
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -6,8 +6,9 @@
|
|
|
6
6
|
|
|
7
7
|
#pragma once
|
|
8
8
|
|
|
9
|
-
#include
|
|
9
|
+
#include "src/torchcodec/_core/CUDACommon.h"
|
|
10
10
|
#include "src/torchcodec/_core/DeviceInterface.h"
|
|
11
|
+
#include "src/torchcodec/_core/FilterGraph.h"
|
|
11
12
|
|
|
12
13
|
namespace facebook::torchcodec {
|
|
13
14
|
|
|
@@ -19,19 +20,45 @@ class CudaDeviceInterface : public DeviceInterface {
|
|
|
19
20
|
|
|
20
21
|
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
|
|
21
22
|
|
|
22
|
-
void
|
|
23
|
+
void initialize(
|
|
24
|
+
const AVStream* avStream,
|
|
25
|
+
const UniqueDecodingAVFormatContext& avFormatCtx) override;
|
|
23
26
|
|
|
24
|
-
void
|
|
27
|
+
void initializeVideo(
|
|
25
28
|
const VideoStreamOptions& videoStreamOptions,
|
|
26
|
-
const
|
|
29
|
+
[[maybe_unused]] const std::vector<std::unique_ptr<Transform>>&
|
|
30
|
+
transforms,
|
|
31
|
+
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims)
|
|
32
|
+
override;
|
|
33
|
+
|
|
34
|
+
void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override;
|
|
35
|
+
|
|
36
|
+
void convertAVFrameToFrameOutput(
|
|
27
37
|
UniqueAVFrame& avFrame,
|
|
28
38
|
FrameOutput& frameOutput,
|
|
29
39
|
std::optional<torch::Tensor> preAllocatedOutputTensor =
|
|
30
40
|
std::nullopt) override;
|
|
31
41
|
|
|
32
42
|
private:
|
|
33
|
-
|
|
34
|
-
|
|
43
|
+
// Our CUDA decoding code assumes NV12 format. In order to handle other
|
|
44
|
+
// kinds of input, we need to convert them to NV12. Our current implementation
|
|
45
|
+
// does this using filtergraph.
|
|
46
|
+
UniqueAVFrame maybeConvertAVFrameToNV12OrRGB24(UniqueAVFrame& avFrame);
|
|
47
|
+
|
|
48
|
+
// We sometimes encounter frames that cannot be decoded on the CUDA device.
|
|
49
|
+
// Rather than erroring out, we decode them on the CPU.
|
|
50
|
+
std::unique_ptr<DeviceInterface> cpuInterface_;
|
|
51
|
+
|
|
52
|
+
VideoStreamOptions videoStreamOptions_;
|
|
53
|
+
AVRational timeBase_;
|
|
54
|
+
|
|
55
|
+
UniqueAVBufferRef hardwareDeviceCtx_;
|
|
56
|
+
UniqueNppContext nppCtx_;
|
|
57
|
+
|
|
58
|
+
// This filtergraph instance is only used for NV12 format conversion in
|
|
59
|
+
// maybeConvertAVFrameToNV12OrRGB24().
|
|
60
|
+
std::unique_ptr<FiltersContext> nv12ConversionContext_;
|
|
61
|
+
std::unique_ptr<FilterGraph> nv12Conversion_;
|
|
35
62
|
};
|
|
36
63
|
|
|
37
64
|
} // namespace facebook::torchcodec
|
|
@@ -11,7 +11,8 @@
|
|
|
11
11
|
namespace facebook::torchcodec {
|
|
12
12
|
|
|
13
13
|
namespace {
|
|
14
|
-
using DeviceInterfaceMap =
|
|
14
|
+
using DeviceInterfaceMap =
|
|
15
|
+
std::map<DeviceInterfaceKey, CreateDeviceInterfaceFn>;
|
|
15
16
|
static std::mutex g_interface_mutex;
|
|
16
17
|
|
|
17
18
|
DeviceInterfaceMap& getDeviceMap() {
|
|
@@ -30,50 +31,87 @@ std::string getDeviceType(const std::string& device) {
|
|
|
30
31
|
} // namespace
|
|
31
32
|
|
|
32
33
|
bool registerDeviceInterface(
|
|
33
|
-
|
|
34
|
+
const DeviceInterfaceKey& key,
|
|
34
35
|
CreateDeviceInterfaceFn createInterface) {
|
|
35
36
|
std::scoped_lock lock(g_interface_mutex);
|
|
36
37
|
DeviceInterfaceMap& deviceMap = getDeviceMap();
|
|
37
38
|
|
|
38
39
|
TORCH_CHECK(
|
|
39
|
-
deviceMap.find(
|
|
40
|
-
"Device interface already registered for ",
|
|
41
|
-
deviceType
|
|
42
|
-
|
|
40
|
+
deviceMap.find(key) == deviceMap.end(),
|
|
41
|
+
"Device interface already registered for device type ",
|
|
42
|
+
key.deviceType,
|
|
43
|
+
" variant '",
|
|
44
|
+
key.variant,
|
|
45
|
+
"'");
|
|
46
|
+
deviceMap.insert({key, createInterface});
|
|
43
47
|
|
|
44
48
|
return true;
|
|
45
49
|
}
|
|
46
50
|
|
|
47
|
-
|
|
51
|
+
void validateDeviceInterface(
|
|
52
|
+
const std::string device,
|
|
53
|
+
const std::string variant) {
|
|
48
54
|
std::scoped_lock lock(g_interface_mutex);
|
|
49
55
|
std::string deviceType = getDeviceType(device);
|
|
56
|
+
|
|
50
57
|
DeviceInterfaceMap& deviceMap = getDeviceMap();
|
|
51
58
|
|
|
59
|
+
// Find device interface that matches device type and variant
|
|
60
|
+
torch::DeviceType deviceTypeEnum = torch::Device(deviceType).type();
|
|
61
|
+
|
|
52
62
|
auto deviceInterface = std::find_if(
|
|
53
63
|
deviceMap.begin(),
|
|
54
64
|
deviceMap.end(),
|
|
55
|
-
[&](const std::pair<
|
|
56
|
-
return
|
|
57
|
-
|
|
65
|
+
[&](const std::pair<DeviceInterfaceKey, CreateDeviceInterfaceFn>& arg) {
|
|
66
|
+
return arg.first.deviceType == deviceTypeEnum &&
|
|
67
|
+
arg.first.variant == variant;
|
|
58
68
|
});
|
|
59
|
-
TORCH_CHECK(
|
|
60
|
-
deviceInterface != deviceMap.end(), "Unsupported device: ", device);
|
|
61
69
|
|
|
62
|
-
|
|
70
|
+
TORCH_CHECK(
|
|
71
|
+
deviceInterface != deviceMap.end(),
|
|
72
|
+
"Unsupported device: ",
|
|
73
|
+
device,
|
|
74
|
+
" (device type: ",
|
|
75
|
+
deviceType,
|
|
76
|
+
", variant: ",
|
|
77
|
+
variant,
|
|
78
|
+
")");
|
|
63
79
|
}
|
|
64
80
|
|
|
65
81
|
std::unique_ptr<DeviceInterface> createDeviceInterface(
|
|
66
|
-
const torch::Device& device
|
|
67
|
-
|
|
82
|
+
const torch::Device& device,
|
|
83
|
+
const std::string_view variant) {
|
|
84
|
+
DeviceInterfaceKey key(device.type(), variant);
|
|
68
85
|
std::scoped_lock lock(g_interface_mutex);
|
|
69
86
|
DeviceInterfaceMap& deviceMap = getDeviceMap();
|
|
70
87
|
|
|
88
|
+
auto it = deviceMap.find(key);
|
|
89
|
+
if (it != deviceMap.end()) {
|
|
90
|
+
return std::unique_ptr<DeviceInterface>(it->second(device));
|
|
91
|
+
}
|
|
92
|
+
|
|
71
93
|
TORCH_CHECK(
|
|
72
|
-
|
|
73
|
-
"
|
|
74
|
-
device)
|
|
94
|
+
false,
|
|
95
|
+
"No device interface found for device type: ",
|
|
96
|
+
device.type(),
|
|
97
|
+
" variant: '",
|
|
98
|
+
variant,
|
|
99
|
+
"'");
|
|
100
|
+
}
|
|
75
101
|
|
|
76
|
-
|
|
102
|
+
// Wraps the pixel data of an RGB24 AVFrame in a uint8 tensor of shape
// (height, width, 3) without copying the pixels.
//
// The tensor's row stride is the frame's linesize[0]. The tensor is backed
// by a clone of `avFrame` made with av_frame_clone(); the tensor's deleter
// frees that clone, so the caller's frame is not tied to the tensor.
torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame) {
  TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);

  int height = avFrame->height;
  int width = avFrame->width;
  std::vector<int64_t> shape = {height, width, 3};
  std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};

  AVFrame* avFrameClone = av_frame_clone(avFrame.get());
  // av_frame_clone() returns NULL on failure; without this check the
  // dereference below would crash instead of reporting an error.
  TORCH_CHECK(avFrameClone != nullptr, "Failed to clone AVFrame.");

  // Handing the clone to a UniqueAVFrame frees it when the tensor's storage
  // is released.
  auto deleter = [avFrameClone](void*) {
    UniqueAVFrame avFrameToDelete(avFrameClone);
  };
  return torch::from_blob(
      avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8});
}
|
|
78
116
|
|
|
79
117
|
} // namespace facebook::torchcodec
|
|
@@ -14,16 +14,27 @@
|
|
|
14
14
|
#include "FFMPEGCommon.h"
|
|
15
15
|
#include "src/torchcodec/_core/Frame.h"
|
|
16
16
|
#include "src/torchcodec/_core/StreamOptions.h"
|
|
17
|
+
#include "src/torchcodec/_core/Transform.h"
|
|
17
18
|
|
|
18
19
|
namespace facebook::torchcodec {
|
|
19
20
|
|
|
20
|
-
//
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
21
|
+
// Key for device interface registration with device type + variant support
|
|
22
|
+
struct DeviceInterfaceKey {
|
|
23
|
+
torch::DeviceType deviceType;
|
|
24
|
+
std::string_view variant = "default"; // e.g., "default", "beta", etc.
|
|
25
|
+
|
|
26
|
+
bool operator<(const DeviceInterfaceKey& other) const {
|
|
27
|
+
if (deviceType != other.deviceType) {
|
|
28
|
+
return deviceType < other.deviceType;
|
|
29
|
+
}
|
|
30
|
+
return variant < other.variant;
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
explicit DeviceInterfaceKey(torch::DeviceType type) : deviceType(type) {}
|
|
34
|
+
|
|
35
|
+
DeviceInterfaceKey(torch::DeviceType type, const std::string_view& variant)
|
|
36
|
+
: deviceType(type), variant(variant) {}
|
|
37
|
+
};
|
|
27
38
|
|
|
28
39
|
class DeviceInterface {
|
|
29
40
|
public:
|
|
@@ -35,19 +46,84 @@ class DeviceInterface {
|
|
|
35
46
|
return device_;
|
|
36
47
|
};
|
|
37
48
|
|
|
38
|
-
virtual std::optional<const AVCodec*> findCodec(
|
|
49
|
+
virtual std::optional<const AVCodec*> findCodec(
|
|
50
|
+
[[maybe_unused]] const AVCodecID& codecId) {
|
|
51
|
+
return std::nullopt;
|
|
52
|
+
};
|
|
39
53
|
|
|
40
|
-
// Initialize the
|
|
41
|
-
|
|
42
|
-
|
|
54
|
+
// Initialize the device with parameters generic to all kinds of decoding.
|
|
55
|
+
virtual void initialize(
|
|
56
|
+
const AVStream* avStream,
|
|
57
|
+
const UniqueDecodingAVFormatContext& avFormatCtx) = 0;
|
|
58
|
+
|
|
59
|
+
// Initialize the device with parameters specific to video decoding. There is
|
|
60
|
+
// a default empty implementation.
|
|
61
|
+
virtual void initializeVideo(
|
|
62
|
+
[[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
|
|
63
|
+
[[maybe_unused]] const std::vector<std::unique_ptr<Transform>>&
|
|
64
|
+
transforms,
|
|
65
|
+
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {}
|
|
66
|
+
|
|
67
|
+
// In order for decoding to actually happen on an FFmpeg managed hardware
|
|
68
|
+
// device, we need to register the DeviceInterface managed
|
|
69
|
+
// AVHardwareDeviceContext with the AVCodecContext. We don't need to do this
|
|
70
|
+
// on the CPU and if FFmpeg is not managing the hardware device.
|
|
71
|
+
virtual void registerHardwareDeviceWithCodec(
|
|
72
|
+
[[maybe_unused]] AVCodecContext* codecContext) {}
|
|
43
73
|
|
|
44
74
|
virtual void convertAVFrameToFrameOutput(
|
|
45
|
-
const VideoStreamOptions& videoStreamOptions,
|
|
46
|
-
const AVRational& timeBase,
|
|
47
75
|
UniqueAVFrame& avFrame,
|
|
48
76
|
FrameOutput& frameOutput,
|
|
49
77
|
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
|
|
50
78
|
|
|
79
|
+
// ------------------------------------------
|
|
80
|
+
// Extension points for custom decoding paths
|
|
81
|
+
// ------------------------------------------
|
|
82
|
+
|
|
83
|
+
// Override to return true if this device interface can decode packets
|
|
84
|
+
// directly. This means that the following two member functions can both
|
|
85
|
+
// be called:
|
|
86
|
+
//
|
|
87
|
+
// 1. sendPacket()
|
|
88
|
+
// 2. receiveFrame()
|
|
89
|
+
virtual bool canDecodePacketDirectly() const {
|
|
90
|
+
return false;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
// Moral equivalent of avcodec_send_packet()
|
|
94
|
+
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
|
|
95
|
+
// other AVERROR on failure
|
|
96
|
+
virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) {
|
|
97
|
+
TORCH_CHECK(
|
|
98
|
+
false,
|
|
99
|
+
"Send/receive packet decoding not implemented for this device interface");
|
|
100
|
+
return AVERROR(ENOSYS);
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
// Send an EOF packet to flush the decoder
|
|
104
|
+
// Returns AVSUCCESS on success, or other AVERROR on failure
|
|
105
|
+
virtual int sendEOFPacket() {
|
|
106
|
+
TORCH_CHECK(
|
|
107
|
+
false, "Send EOF packet not implemented for this device interface");
|
|
108
|
+
return AVERROR(ENOSYS);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
// Moral equivalent of avcodec_receive_frame()
|
|
112
|
+
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
|
|
113
|
+
// AVERROR_EOF if end of stream, or other AVERROR on failure
|
|
114
|
+
virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
|
|
115
|
+
TORCH_CHECK(
|
|
116
|
+
false,
|
|
117
|
+
"Send/receive packet decoding not implemented for this device interface");
|
|
118
|
+
return AVERROR(ENOSYS);
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
// Flush remaining frames from decoder
|
|
122
|
+
virtual void flush() {
|
|
123
|
+
// Default implementation is no-op for standard decoders
|
|
124
|
+
// Custom decoders can override this method
|
|
125
|
+
}
|
|
126
|
+
|
|
51
127
|
protected:
|
|
52
128
|
torch::Device device_;
|
|
53
129
|
};
|
|
@@ -56,12 +132,17 @@ using CreateDeviceInterfaceFn =
|
|
|
56
132
|
std::function<DeviceInterface*(const torch::Device& device)>;
|
|
57
133
|
|
|
58
134
|
bool registerDeviceInterface(
|
|
59
|
-
|
|
135
|
+
const DeviceInterfaceKey& key,
|
|
60
136
|
const CreateDeviceInterfaceFn createInterface);
|
|
61
137
|
|
|
62
|
-
|
|
138
|
+
void validateDeviceInterface(
|
|
139
|
+
const std::string device,
|
|
140
|
+
const std::string variant);
|
|
63
141
|
|
|
64
142
|
std::unique_ptr<DeviceInterface> createDeviceInterface(
|
|
65
|
-
const torch::Device& device
|
|
143
|
+
const torch::Device& device,
|
|
144
|
+
const std::string_view variant = "default");
|
|
145
|
+
|
|
146
|
+
torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame);
|
|
66
147
|
|
|
67
148
|
} // namespace facebook::torchcodec
|
torchcodec/_core/Encoder.cpp
CHANGED
|
@@ -4,6 +4,10 @@
|
|
|
4
4
|
#include "src/torchcodec/_core/Encoder.h"
|
|
5
5
|
#include "torch/types.h"
|
|
6
6
|
|
|
7
|
+
extern "C" {
|
|
8
|
+
#include <libavutil/pixdesc.h>
|
|
9
|
+
}
|
|
10
|
+
|
|
7
11
|
namespace facebook::torchcodec {
|
|
8
12
|
|
|
9
13
|
namespace {
|
|
@@ -33,21 +37,22 @@ torch::Tensor validateSamples(const torch::Tensor& samples) {
|
|
|
33
37
|
}
|
|
34
38
|
|
|
35
39
|
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
|
|
36
|
-
|
|
40
|
+
const int* supportedSampleRates = getSupportedSampleRates(avCodec);
|
|
41
|
+
if (supportedSampleRates == nullptr) {
|
|
37
42
|
return;
|
|
38
43
|
}
|
|
39
44
|
|
|
40
|
-
for (auto i = 0;
|
|
41
|
-
if (sampleRate ==
|
|
45
|
+
for (auto i = 0; supportedSampleRates[i] != 0; ++i) {
|
|
46
|
+
if (sampleRate == supportedSampleRates[i]) {
|
|
42
47
|
return;
|
|
43
48
|
}
|
|
44
49
|
}
|
|
45
50
|
std::stringstream supportedRates;
|
|
46
|
-
for (auto i = 0;
|
|
51
|
+
for (auto i = 0; supportedSampleRates[i] != 0; ++i) {
|
|
47
52
|
if (i > 0) {
|
|
48
53
|
supportedRates << ", ";
|
|
49
54
|
}
|
|
50
|
-
supportedRates <<
|
|
55
|
+
supportedRates << supportedSampleRates[i];
|
|
51
56
|
}
|
|
52
57
|
|
|
53
58
|
TORCH_CHECK(
|
|
@@ -73,19 +78,22 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
|
|
|
73
78
|
AV_SAMPLE_FMT_U8};
|
|
74
79
|
|
|
75
80
|
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
|
|
81
|
+
const AVSampleFormat* supportedSampleFormats =
|
|
82
|
+
getSupportedOutputSampleFormats(avCodec);
|
|
83
|
+
|
|
76
84
|
// Find a sample format that the encoder supports. We prefer using FLT[P],
|
|
77
85
|
// since this is the format of the input samples. If FLTP isn't supported
|
|
78
86
|
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
|
|
79
87
|
// into the format with the highest resolution.
|
|
80
|
-
if (
|
|
88
|
+
if (supportedSampleFormats == nullptr) {
|
|
81
89
|
// Can't really validate anything in this case, best we can do is hope that
|
|
82
90
|
// FLTP is supported by the encoder. If not, FFmpeg will raise.
|
|
83
91
|
return AV_SAMPLE_FMT_FLTP;
|
|
84
92
|
}
|
|
85
93
|
|
|
86
94
|
for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
|
|
87
|
-
for (int i = 0;
|
|
88
|
-
if (
|
|
95
|
+
for (int i = 0; supportedSampleFormats[i] != -1; ++i) {
|
|
96
|
+
if (supportedSampleFormats[i] == preferredFormat) {
|
|
89
97
|
return preferredFormat;
|
|
90
98
|
}
|
|
91
99
|
}
|
|
@@ -93,7 +101,7 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
|
|
|
93
101
|
// We should always find a match in preferredFormatsOrder, so we should always
|
|
94
102
|
// return earlier. But in the event that a future FFmpeg version defines an
|
|
95
103
|
// additional sample format that isn't in preferredFormatsOrder, we fallback:
|
|
96
|
-
return
|
|
104
|
+
return supportedSampleFormats[0];
|
|
97
105
|
}
|
|
98
106
|
|
|
99
107
|
} // namespace
|
|
@@ -511,4 +519,289 @@ void AudioEncoder::flushBuffers() {
|
|
|
511
519
|
|
|
512
520
|
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
|
|
513
521
|
}
|
|
522
|
+
|
|
523
|
+
namespace {
|
|
524
|
+
|
|
525
|
+
// Validates the frames tensor given to VideoEncoder and returns a contiguous
// version of it.
//
// Requirements, each enforced with TORCH_CHECK: uint8 dtype, exactly 4
// dimensions (N, C, H, W), and 3 channels in dimension 1.
torch::Tensor validateFrames(const torch::Tensor& frames) {
  const auto dtype = frames.dtype();
  TORCH_CHECK(
      dtype == torch::kUInt8, "frames must have uint8 dtype, got ", dtype);

  const auto numDims = frames.dim();
  TORCH_CHECK(
      numDims == 4,
      "frames must have 4 dimensions (N, C, H, W), got ",
      numDims);

  const auto numChannels = frames.sizes()[1];
  TORCH_CHECK(
      numChannels == 3,
      "frame must have 3 channels (R, G, B), got ",
      numChannels);

  // TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
  return frames.contiguous();
}
|
|
541
|
+
|
|
542
|
+
} // namespace
|
|
543
|
+
|
|
544
|
+
// Flushes and closes the output I/O context, if one is open.
VideoEncoder::~VideoEncoder() {
  // Nothing to release when the format context or its I/O context was never
  // set up.
  if (!avFormatContext_ || !avFormatContext_->pb) {
    return;
  }

  AVIOContext* ioContext = avFormatContext_->pb;
  avio_flush(ioContext);
  avio_close(ioContext);
  // Clear the pointer so nothing reuses the closed I/O context.
  avFormatContext_->pb = nullptr;
}
|
|
551
|
+
|
|
552
|
+
// Creates an encoder that will write `frames` to the file `fileName`.
//
// `frames` is validated by validateFrames(). The constructor allocates the
// output format context (FFmpeg picks the container from the file name),
// opens the file for writing, and initializes the codec via
// initializeEncoder(). Any failure raises through TORCH_CHECK.
VideoEncoder::VideoEncoder(
    const torch::Tensor& frames,
    int frameRate,
    std::string_view fileName,
    const VideoStreamOptions& videoStreamOptions)
    : frames_(validateFrames(frames)), inFrameRate_(frameRate) {
  setFFmpegLogLevel();

  // The FFmpeg APIs below take null-terminated C strings, but
  // std::string_view::data() is not guaranteed to be null-terminated. Copy
  // into a std::string so c_str() is always safe to pass.
  const std::string fileNameStr(fileName);

  // Allocate output format context
  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "The destination file is ",
      fileName,
      ", check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  status =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      status >= 0,
      "avio_open failed. The destination file is ",
      fileName,
      ", make sure it's a valid path? ",
      getFFMPEGErrorStringFromErrorCode(status));
  initializeEncoder(videoStreamOptions);
}
|
|
583
|
+
|
|
584
|
+
void VideoEncoder::initializeEncoder(
|
|
585
|
+
const VideoStreamOptions& videoStreamOptions) {
|
|
586
|
+
const AVCodec* avCodec =
|
|
587
|
+
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
|
|
588
|
+
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
|
|
589
|
+
|
|
590
|
+
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
|
|
591
|
+
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
|
|
592
|
+
avCodecContext_.reset(avCodecContext);
|
|
593
|
+
|
|
594
|
+
// Store dimension order and input pixel format
|
|
595
|
+
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
|
|
596
|
+
auto sizes = frames_.sizes();
|
|
597
|
+
inPixelFormat_ = AV_PIX_FMT_GBRP;
|
|
598
|
+
inHeight_ = static_cast<int>(sizes[2]);
|
|
599
|
+
inWidth_ = static_cast<int>(sizes[3]);
|
|
600
|
+
|
|
601
|
+
// Use specified dimensions or input dimensions
|
|
602
|
+
// TODO-VideoEncoder: Allow height and width to be set
|
|
603
|
+
outWidth_ = inWidth_;
|
|
604
|
+
outHeight_ = inHeight_;
|
|
605
|
+
|
|
606
|
+
// TODO-VideoEncoder: Enable other pixel formats
|
|
607
|
+
// Let FFmpeg choose best pixel format to minimize loss
|
|
608
|
+
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
|
|
609
|
+
getSupportedPixelFormats(*avCodec), // List of supported formats
|
|
610
|
+
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
|
|
611
|
+
0, // No alpha channel
|
|
612
|
+
nullptr // Discard conversion loss information
|
|
613
|
+
);
|
|
614
|
+
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
|
|
615
|
+
|
|
616
|
+
// Configure codec parameters
|
|
617
|
+
avCodecContext_->codec_id = avCodec->id;
|
|
618
|
+
avCodecContext_->width = outWidth_;
|
|
619
|
+
avCodecContext_->height = outHeight_;
|
|
620
|
+
avCodecContext_->pix_fmt = outPixelFormat_;
|
|
621
|
+
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
|
|
622
|
+
avCodecContext_->time_base = {1, inFrameRate_};
|
|
623
|
+
avCodecContext_->framerate = {inFrameRate_, 1};
|
|
624
|
+
|
|
625
|
+
// Set flag for containers that require extradata to be in the codec context
|
|
626
|
+
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
|
|
627
|
+
avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
// Apply videoStreamOptions
|
|
631
|
+
AVDictionary* options = nullptr;
|
|
632
|
+
if (videoStreamOptions.crf.has_value()) {
|
|
633
|
+
av_dict_set(
|
|
634
|
+
&options,
|
|
635
|
+
"crf",
|
|
636
|
+
std::to_string(videoStreamOptions.crf.value()).c_str(),
|
|
637
|
+
0);
|
|
638
|
+
}
|
|
639
|
+
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
|
|
640
|
+
av_dict_free(&options);
|
|
641
|
+
|
|
642
|
+
TORCH_CHECK(
|
|
643
|
+
status == AVSUCCESS,
|
|
644
|
+
"avcodec_open2 failed: ",
|
|
645
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
646
|
+
|
|
647
|
+
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
|
|
648
|
+
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");
|
|
649
|
+
|
|
650
|
+
// Set the stream time base to encode correct frame timestamps
|
|
651
|
+
avStream_->time_base = avCodecContext_->time_base;
|
|
652
|
+
status = avcodec_parameters_from_context(
|
|
653
|
+
avStream_->codecpar, avCodecContext_.get());
|
|
654
|
+
TORCH_CHECK(
|
|
655
|
+
status == AVSUCCESS,
|
|
656
|
+
"avcodec_parameters_from_context failed: ",
|
|
657
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
void VideoEncoder::encode() {
|
|
661
|
+
// To be on the safe side we enforce that encode() can only be called once
|
|
662
|
+
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
|
|
663
|
+
encodeWasCalled_ = true;
|
|
664
|
+
|
|
665
|
+
int status = avformat_write_header(avFormatContext_.get(), nullptr);
|
|
666
|
+
TORCH_CHECK(
|
|
667
|
+
status == AVSUCCESS,
|
|
668
|
+
"Error in avformat_write_header: ",
|
|
669
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
670
|
+
|
|
671
|
+
AutoAVPacket autoAVPacket;
|
|
672
|
+
int numFrames = static_cast<int>(frames_.sizes()[0]);
|
|
673
|
+
for (int i = 0; i < numFrames; ++i) {
|
|
674
|
+
torch::Tensor currFrame = frames_[i];
|
|
675
|
+
UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i);
|
|
676
|
+
encodeFrame(autoAVPacket, avFrame);
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
flushBuffers();
|
|
680
|
+
|
|
681
|
+
status = av_write_trailer(avFormatContext_.get());
|
|
682
|
+
TORCH_CHECK(
|
|
683
|
+
status == AVSUCCESS,
|
|
684
|
+
"Error in av_write_trailer: ",
|
|
685
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
// Converts one frame tensor into a newly allocated AVFrame in the encoder's
// output pixel format and size, with pts set to `frameIndex`.
//
// The tensor's memory is read as three consecutive planes of
// inHeight_ * inWidth_ bytes each (R, G, B in that order) and presented to
// swscale as a planar GBRP image. The scaling context is created on first
// use and cached in swsContext_.
UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
    const torch::Tensor& frame,
    int frameIndex) {
  // Initialize and cache scaling context if it does not exist
  if (!swsContext_) {
    SwsContext* newSwsContext = sws_getContext(
        inWidth_,
        inHeight_,
        inPixelFormat_,
        outWidth_,
        outHeight_,
        outPixelFormat_,
        SWS_BICUBIC, // Used by FFmpeg CLI
        nullptr,
        nullptr,
        nullptr);
    swsContext_.reset(newSwsContext);
    TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
  }

  // Destination frame: owns its buffer, described by the output properties.
  UniqueAVFrame outputFrame(av_frame_alloc());
  TORCH_CHECK(outputFrame != nullptr, "Failed to allocate AVFrame");
  outputFrame->format = outPixelFormat_;
  outputFrame->width = outWidth_;
  outputFrame->height = outHeight_;
  outputFrame->pts = frameIndex;

  const int bufferStatus = av_frame_get_buffer(outputFrame.get(), 0);
  TORCH_CHECK(bufferStatus >= 0, "Failed to allocate frame buffer");

  // Source frame: only describes the tensor's memory, it owns no buffer.
  UniqueAVFrame inputFrame(av_frame_alloc());
  TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame");
  inputFrame->format = inPixelFormat_;
  inputFrame->width = inWidth_;
  inputFrame->height = inHeight_;

  uint8_t* tensorData = static_cast<uint8_t*>(frame.data_ptr());

  // TODO-VideoEncoder: Reorder tensor if in NHWC format
  // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
  const int planeSize = inHeight_ * inWidth_;
  // AV_PIX_FMT_GBRP stores its planes as G, B, R while the tensor holds
  // R, G, B. Map each GBRP plane to its tensor channel: G <- 1, B <- 2,
  // R <- 0.
  const int tensorChannelForPlane[3] = {1, 2, 0};
  for (int plane = 0; plane < 3; ++plane) {
    inputFrame->data[plane] =
        tensorData + (tensorChannelForPlane[plane] * planeSize);
    inputFrame->linesize[plane] = inWidth_;
  }

  // The result is checked against the output height: the whole frame must
  // have been converted.
  const int scaledRows = sws_scale(
      swsContext_.get(),
      inputFrame->data,
      inputFrame->linesize,
      0,
      inputFrame->height,
      outputFrame->data,
      outputFrame->linesize);
  TORCH_CHECK(scaledRows == outHeight_, "sws_scale failed");
  return outputFrame;
}
|
|
753
|
+
|
|
754
|
+
// Sends one frame to the encoder and writes every packet the encoder has
// ready to the output file.
//
// flushBuffers() passes a frame wrapping nullptr to signal end of input.
// Returns when the encoder reports it needs more input (EAGAIN) or that it
// is fully drained (EOF); in the EOF case the muxer's buffered packets are
// flushed first. Any other error raises through TORCH_CHECK.
void VideoEncoder::encodeFrame(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  const int sendStatus =
      avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      sendStatus == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(sendStatus));

  // Drain packets until the encoder has nothing more for now. Every path
  // through the body either returns, throws, or completes a successful
  // write, so the loop exits only through the returns below.
  for (;;) {
    ReferenceAVPacket packet(autoAVPacket);
    const int receiveStatus =
        avcodec_receive_packet(avCodecContext_.get(), packet.get());

    if (receiveStatus == AVERROR(EAGAIN)) {
      return;
    }
    if (receiveStatus == AVERROR_EOF) {
      // Flush remaining buffered packets
      const int flushStatus =
          av_interleaved_write_frame(avFormatContext_.get(), nullptr);
      TORCH_CHECK(
          flushStatus == AVSUCCESS,
          "Failed to flush packet: ",
          getFFMPEGErrorStringFromErrorCode(flushStatus));
      return;
    }
    TORCH_CHECK(
        receiveStatus >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(receiveStatus));

    // The code below is borrowed from torchaudio:
    // https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46
    // Setting packet->duration to 1 allows the last frame to be properly
    // encoded, and needs to be set before calling av_packet_rescale_ts.
    if (packet->duration == 0) {
      packet->duration = 1;
    }
    av_packet_rescale_ts(
        packet.get(), avCodecContext_->time_base, avStream_->time_base);
    packet->stream_index = avStream_->index;

    const int writeStatus =
        av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        writeStatus == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(writeStatus));
  }
}
|
|
800
|
+
|
|
801
|
+
void VideoEncoder::flushBuffers() {
|
|
802
|
+
AutoAVPacket autoAVPacket;
|
|
803
|
+
// Send null frame to signal end of input
|
|
804
|
+
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
|
|
805
|
+
}
|
|
806
|
+
|
|
514
807
|
} // namespace facebook::torchcodec
|