torchcodec 0.3.0__cp310-cp310-macosx_11_0_arm64.whl → 0.4.0__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchcodec might be problematic. Click here for more details.
- torchcodec/.dylibs/libc++.1.0.dylib +0 -0
- torchcodec/.dylibs/libpython3.10.dylib +0 -0
- torchcodec/_core/AVIOBytesContext.cpp +68 -1
- torchcodec/_core/AVIOBytesContext.h +24 -2
- torchcodec/_core/AVIOContextHolder.cpp +8 -3
- torchcodec/_core/AVIOContextHolder.h +4 -7
- torchcodec/_core/AVIOFileLikeContext.cpp +1 -1
- torchcodec/_core/CMakeLists.txt +12 -3
- torchcodec/_core/CpuDeviceInterface.cpp +364 -0
- torchcodec/_core/CpuDeviceInterface.h +80 -0
- torchcodec/_core/CudaDeviceInterface.cpp +3 -2
- torchcodec/_core/CudaDeviceInterface.h +1 -0
- torchcodec/_core/DeviceInterface.cpp +20 -29
- torchcodec/_core/DeviceInterface.h +1 -0
- torchcodec/_core/Encoder.cpp +88 -41
- torchcodec/_core/Encoder.h +16 -0
- torchcodec/_core/FFMPEGCommon.cpp +98 -26
- torchcodec/_core/FFMPEGCommon.h +32 -9
- torchcodec/_core/Frame.cpp +32 -0
- torchcodec/_core/Frame.h +71 -0
- torchcodec/_core/SingleStreamDecoder.cpp +91 -414
- torchcodec/_core/SingleStreamDecoder.h +0 -107
- torchcodec/_core/StreamOptions.h +1 -0
- torchcodec/_core/__init__.py +2 -2
- torchcodec/_core/_metadata.py +3 -0
- torchcodec/_core/custom_ops.cpp +35 -38
- torchcodec/_core/ops.py +14 -10
- torchcodec/decoders/_audio_decoder.py +8 -2
- torchcodec/libtorchcodec_custom_ops4.dylib +0 -0
- torchcodec/libtorchcodec_custom_ops5.dylib +0 -0
- torchcodec/libtorchcodec_custom_ops6.dylib +0 -0
- torchcodec/libtorchcodec_custom_ops7.dylib +0 -0
- torchcodec/libtorchcodec_decoder4.dylib +0 -0
- torchcodec/libtorchcodec_decoder5.dylib +0 -0
- torchcodec/libtorchcodec_decoder6.dylib +0 -0
- torchcodec/libtorchcodec_decoder7.dylib +0 -0
- torchcodec/libtorchcodec_pybind_ops4.so +0 -0
- torchcodec/libtorchcodec_pybind_ops5.so +0 -0
- torchcodec/libtorchcodec_pybind_ops6.so +0 -0
- torchcodec/libtorchcodec_pybind_ops7.so +0 -0
- torchcodec/version.py +1 -1
- {torchcodec-0.3.0.dist-info → torchcodec-0.4.0.dist-info}/METADATA +4 -17
- torchcodec-0.4.0.dist-info/RECORD +62 -0
- {torchcodec-0.3.0.dist-info → torchcodec-0.4.0.dist-info}/WHEEL +1 -1
- torchcodec-0.3.0.dist-info/RECORD +0 -59
- {torchcodec-0.3.0.dist-info → torchcodec-0.4.0.dist-info/licenses}/LICENSE +0 -0
- {torchcodec-0.3.0.dist-info → torchcodec-0.4.0.dist-info}/top_level.txt +0 -0
|
Binary file
|
|
Binary file
|
|
@@ -13,7 +13,7 @@ AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize)
|
|
|
13
13
|
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
|
|
14
14
|
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
|
|
15
15
|
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
|
|
16
|
-
createAVIOContext(&read, &seek, &dataContext_);
|
|
16
|
+
createAVIOContext(&read, nullptr, &seek, &dataContext_);
|
|
17
17
|
}
|
|
18
18
|
|
|
19
19
|
// The signature of this function is defined by FFMPEG.
|
|
@@ -67,4 +67,71 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
|
|
|
67
67
|
return ret;
|
|
68
68
|
}
|
|
69
69
|
|
|
70
|
+
// Builds an encoding-side AVIO context. The output buffer starts as an
// uninitialized uint8 tensor of INITIAL_TENSOR_SIZE elements with the write
// position at 0.
AVIOToTensorContext::AVIOToTensorContext()
    : dataContext_{
          /*outputTensor=*/torch::empty(
              {INITIAL_TENSOR_SIZE},
              {torch::kUInt8}),
          /*current=*/0} {
  // No read callback is registered for this context: only write and seek.
  createAVIOContext(/*read=*/nullptr, &write, &seek, &dataContext_);
}
|
|
78
|
+
|
|
79
|
+
// The signature of this function is defined by FFMPEG.
|
|
80
|
+
int AVIOToTensorContext::write(void* opaque, const uint8_t* buf, int buf_size) {
|
|
81
|
+
auto dataContext = static_cast<DataContext*>(opaque);
|
|
82
|
+
|
|
83
|
+
int64_t bufSize = static_cast<int64_t>(buf_size);
|
|
84
|
+
if (dataContext->current + bufSize > dataContext->outputTensor.numel()) {
|
|
85
|
+
TORCH_CHECK(
|
|
86
|
+
dataContext->outputTensor.numel() * 2 <=
|
|
87
|
+
AVIOToTensorContext::MAX_TENSOR_SIZE,
|
|
88
|
+
"We tried to allocate an output encoded tensor larger than ",
|
|
89
|
+
AVIOToTensorContext::MAX_TENSOR_SIZE,
|
|
90
|
+
" bytes. If you think this should be supported, please report.");
|
|
91
|
+
|
|
92
|
+
// We double the size of the outpout tensor. Calling cat() may not be the
|
|
93
|
+
// most efficient, but it's simple.
|
|
94
|
+
dataContext->outputTensor =
|
|
95
|
+
torch::cat({dataContext->outputTensor, dataContext->outputTensor});
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
TORCH_CHECK(
|
|
99
|
+
dataContext->current + bufSize <= dataContext->outputTensor.numel(),
|
|
100
|
+
"Re-allocation of the output tensor didn't work. ",
|
|
101
|
+
"This should not happen, please report on TorchCodec bug tracker");
|
|
102
|
+
|
|
103
|
+
uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
|
|
104
|
+
std::memcpy(outputTensorData + dataContext->current, buf, bufSize);
|
|
105
|
+
dataContext->current += bufSize;
|
|
106
|
+
return buf_size;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
// The signature of this function is defined by FFMPEG.
|
|
110
|
+
// Note: This `seek()` implementation is very similar to that of
|
|
111
|
+
// AVIOBytesContext. We could consider merging both classes, or do some kind of
|
|
112
|
+
// refac, but this doesn't seem worth it ATM.
|
|
113
|
+
int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
|
|
114
|
+
auto dataContext = static_cast<DataContext*>(opaque);
|
|
115
|
+
int64_t ret = -1;
|
|
116
|
+
|
|
117
|
+
switch (whence) {
|
|
118
|
+
case AVSEEK_SIZE:
|
|
119
|
+
ret = dataContext->outputTensor.numel();
|
|
120
|
+
break;
|
|
121
|
+
case SEEK_SET:
|
|
122
|
+
dataContext->current = offset;
|
|
123
|
+
ret = offset;
|
|
124
|
+
break;
|
|
125
|
+
default:
|
|
126
|
+
break;
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
return ret;
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
// Returns the output tensor narrowed along dim 0 to its first `current`
// elements, i.e. everything up to the current write position.
torch::Tensor AVIOToTensorContext::getOutputTensor() {
  const int64_t writePosition = dataContext_.current;
  return dataContext_.outputTensor.narrow(
      /*dim=*/0, /*start=*/0, /*length=*/writePosition);
}
|
|
136
|
+
|
|
70
137
|
} // namespace facebook::torchcodec
|
|
@@ -6,12 +6,13 @@
|
|
|
6
6
|
|
|
7
7
|
#pragma once
|
|
8
8
|
|
|
9
|
+
#include <torch/types.h>
|
|
9
10
|
#include "src/torchcodec/_core/AVIOContextHolder.h"
|
|
10
11
|
|
|
11
12
|
namespace facebook::torchcodec {
|
|
12
13
|
|
|
13
|
-
//
|
|
14
|
-
// functions then traverse the bytes in memory.
|
|
14
|
+
// For Decoding: enables users to pass in the entire video or audio as bytes.
|
|
15
|
+
// Our read and seek functions then traverse the bytes in memory.
|
|
15
16
|
class AVIOBytesContext : public AVIOContextHolder {
|
|
16
17
|
public:
|
|
17
18
|
explicit AVIOBytesContext(const void* data, int64_t dataSize);
|
|
@@ -29,4 +30,25 @@ class AVIOBytesContext : public AVIOContextHolder {
|
|
|
29
30
|
DataContext dataContext_;
|
|
30
31
|
};
|
|
31
32
|
|
|
33
|
+
// For Encoding: used to encode into an output uint8 (bytes) tensor.
|
|
34
|
+
class AVIOToTensorContext : public AVIOContextHolder {
|
|
35
|
+
public:
|
|
36
|
+
explicit AVIOToTensorContext();
|
|
37
|
+
torch::Tensor getOutputTensor();
|
|
38
|
+
|
|
39
|
+
private:
|
|
40
|
+
struct DataContext {
|
|
41
|
+
torch::Tensor outputTensor;
|
|
42
|
+
int64_t current;
|
|
43
|
+
};
|
|
44
|
+
|
|
45
|
+
static constexpr int64_t INITIAL_TENSOR_SIZE = 10'000'000; // 10MB
|
|
46
|
+
static constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
|
|
47
|
+
static int write(void* opaque, const uint8_t* buf, int buf_size);
|
|
48
|
+
// We need to expose seek() for some formats like mp3.
|
|
49
|
+
static int64_t seek(void* opaque, int64_t offset, int whence);
|
|
50
|
+
|
|
51
|
+
DataContext dataContext_;
|
|
52
|
+
};
|
|
53
|
+
|
|
32
54
|
} // namespace facebook::torchcodec
|
|
@@ -11,6 +11,7 @@ namespace facebook::torchcodec {
|
|
|
11
11
|
|
|
12
12
|
void AVIOContextHolder::createAVIOContext(
|
|
13
13
|
AVIOReadFunction read,
|
|
14
|
+
AVIOWriteFunction write,
|
|
14
15
|
AVIOSeekFunction seek,
|
|
15
16
|
void* heldData,
|
|
16
17
|
int bufferSize) {
|
|
@@ -22,13 +23,17 @@ void AVIOContextHolder::createAVIOContext(
|
|
|
22
23
|
buffer != nullptr,
|
|
23
24
|
"Failed to allocate buffer of size " + std::to_string(bufferSize));
|
|
24
25
|
|
|
25
|
-
|
|
26
|
+
TORCH_CHECK(
|
|
27
|
+
(seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
|
|
28
|
+
"seek method must be defined, and either write or read must be defined. "
|
|
29
|
+
"But not both!")
|
|
30
|
+
avioContext_.reset(avioAllocContext(
|
|
26
31
|
buffer,
|
|
27
32
|
bufferSize,
|
|
28
|
-
|
|
33
|
+
/*write_flag=*/write != nullptr,
|
|
29
34
|
heldData,
|
|
30
35
|
read,
|
|
31
|
-
|
|
36
|
+
write,
|
|
32
37
|
seek));
|
|
33
38
|
|
|
34
39
|
if (!avioContext_) {
|
|
@@ -19,9 +19,9 @@ namespace facebook::torchcodec {
|
|
|
19
19
|
// freed.
|
|
20
20
|
// 2. It is a base class for AVIOContext specializations. When specializing a
|
|
21
21
|
// AVIOContext, we need to provide four things:
|
|
22
|
-
// 1. A read callback function.
|
|
23
|
-
// 2. A seek callback function.
|
|
24
|
-
// 3. A write callback function
|
|
22
|
+
// 1. A read callback function, for decoding.
|
|
23
|
+
// 2. A seek callback function, for decoding and encoding.
|
|
24
|
+
// 3. A write callback function, for encoding.
|
|
25
25
|
// 4. A pointer to some context object that has the same lifetime as the
|
|
26
26
|
// AVIOContext itself. This context object holds the custom state that
|
|
27
27
|
// tracks the custom behavior of reading, seeking and writing. It is
|
|
@@ -44,13 +44,10 @@ class AVIOContextHolder {
|
|
|
44
44
|
// enforced by having a pure virtual methods, but we don't have any.)
|
|
45
45
|
AVIOContextHolder() = default;
|
|
46
46
|
|
|
47
|
-
// These signatures are defined by FFmpeg.
|
|
48
|
-
using AVIOReadFunction = int (*)(void*, uint8_t*, int);
|
|
49
|
-
using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
|
|
50
|
-
|
|
51
47
|
// Deriving classes should call this function in their constructor.
|
|
52
48
|
void createAVIOContext(
|
|
53
49
|
AVIOReadFunction read,
|
|
50
|
+
AVIOWriteFunction write,
|
|
54
51
|
AVIOSeekFunction seek,
|
|
55
52
|
void* heldData,
|
|
56
53
|
int bufferSize = defaultBufferSize);
|
|
@@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
|
|
|
23
23
|
py::hasattr(fileLike, "seek"),
|
|
24
24
|
"File like object must implement a seek method.");
|
|
25
25
|
}
|
|
26
|
-
createAVIOContext(&read, &seek, &fileLike_);
|
|
26
|
+
createAVIOContext(&read, nullptr, &seek, &fileLike_);
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
|
torchcodec/_core/CMakeLists.txt
CHANGED
|
@@ -8,7 +8,13 @@ find_package(pybind11 REQUIRED)
|
|
|
8
8
|
find_package(Torch REQUIRED)
|
|
9
9
|
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
if(DEFINED TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR AND TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR)
|
|
12
|
+
set(TORCHCODEC_WERROR_OPTION "")
|
|
13
|
+
else()
|
|
14
|
+
set(TORCHCODEC_WERROR_OPTION "-Werror")
|
|
15
|
+
endif()
|
|
16
|
+
|
|
17
|
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
|
|
12
18
|
|
|
13
19
|
function(make_torchcodec_sublibrary
|
|
14
20
|
library_name
|
|
@@ -59,8 +65,11 @@ function(make_torchcodec_libraries
|
|
|
59
65
|
set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}")
|
|
60
66
|
set(decoder_sources
|
|
61
67
|
AVIOContextHolder.cpp
|
|
68
|
+
AVIOBytesContext.cpp
|
|
62
69
|
FFMPEGCommon.cpp
|
|
63
|
-
|
|
70
|
+
Frame.cpp
|
|
71
|
+
DeviceInterface.cpp
|
|
72
|
+
CpuDeviceInterface.cpp
|
|
64
73
|
SingleStreamDecoder.cpp
|
|
65
74
|
# TODO: lib name should probably not be "*_decoder*" now that it also
|
|
66
75
|
# contains an encoder
|
|
@@ -148,7 +157,7 @@ function(make_torchcodec_libraries
|
|
|
148
157
|
target_link_options(
|
|
149
158
|
${pybind_ops_library_name}
|
|
150
159
|
PUBLIC
|
|
151
|
-
"
|
|
160
|
+
"LINKER:-undefined,dynamic_lookup"
|
|
152
161
|
)
|
|
153
162
|
|
|
154
163
|
# Install all libraries.
|
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
// All rights reserved.
|
|
3
|
+
//
|
|
4
|
+
// This source code is licensed under the BSD-style license found in the
|
|
5
|
+
// LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
#include "src/torchcodec/_core/CpuDeviceInterface.h"
|
|
8
|
+
|
|
9
|
+
extern "C" {
|
|
10
|
+
#include <libavfilter/buffersink.h>
|
|
11
|
+
#include <libavfilter/buffersrc.h>
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
namespace facebook::torchcodec {
|
|
15
|
+
namespace {
|
|
16
|
+
|
|
17
|
+
static bool g_cpu = registerDeviceInterface(
|
|
18
|
+
torch::kCPU,
|
|
19
|
+
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
|
|
20
|
+
|
|
21
|
+
} // namespace
|
|
22
|
+
|
|
23
|
+
bool CpuDeviceInterface::DecodedFrameContext::operator==(
|
|
24
|
+
const CpuDeviceInterface::DecodedFrameContext& other) {
|
|
25
|
+
return decodedWidth == other.decodedWidth &&
|
|
26
|
+
decodedHeight == other.decodedHeight &&
|
|
27
|
+
decodedFormat == other.decodedFormat &&
|
|
28
|
+
expectedWidth == other.expectedWidth &&
|
|
29
|
+
expectedHeight == other.expectedHeight;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
bool CpuDeviceInterface::DecodedFrameContext::operator!=(
|
|
33
|
+
const CpuDeviceInterface::DecodedFrameContext& other) {
|
|
34
|
+
return !(*this == other);
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
// Checks that the CPU interface was registered, then rejects any device whose
// type is not torch::kCPU with a std::runtime_error.
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
    : DeviceInterface(device) {
  TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");

  const bool isCpuDevice = (device_.type() == torch::kCPU);
  if (!isCpuDevice) {
    throw std::runtime_error("Unsupported device: " + device_.str());
  }
}
|
|
44
|
+
|
|
45
|
+
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
|
|
46
|
+
// Callers may pass a pre-allocated tensor, where the output.data tensor will
|
|
47
|
+
// be stored. This parameter is honored in any case, but it only leads to a
|
|
48
|
+
// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
|
|
49
|
+
// decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet
|
|
50
|
+
// found a way to do that with filtegraph.
|
|
51
|
+
// TODO: Figure out whether that's possible!
|
|
52
|
+
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
|
|
53
|
+
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
|
|
54
|
+
void CpuDeviceInterface::convertAVFrameToFrameOutput(
|
|
55
|
+
const VideoStreamOptions& videoStreamOptions,
|
|
56
|
+
const AVRational& timeBase,
|
|
57
|
+
UniqueAVFrame& avFrame,
|
|
58
|
+
FrameOutput& frameOutput,
|
|
59
|
+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
|
|
60
|
+
auto frameDims =
|
|
61
|
+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
|
|
62
|
+
int expectedOutputHeight = frameDims.height;
|
|
63
|
+
int expectedOutputWidth = frameDims.width;
|
|
64
|
+
|
|
65
|
+
if (preAllocatedOutputTensor.has_value()) {
|
|
66
|
+
auto shape = preAllocatedOutputTensor.value().sizes();
|
|
67
|
+
TORCH_CHECK(
|
|
68
|
+
(shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
|
|
69
|
+
(shape[1] == expectedOutputWidth) && (shape[2] == 3),
|
|
70
|
+
"Expected pre-allocated tensor of shape ",
|
|
71
|
+
expectedOutputHeight,
|
|
72
|
+
"x",
|
|
73
|
+
expectedOutputWidth,
|
|
74
|
+
"x3, got ",
|
|
75
|
+
shape);
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
torch::Tensor outputTensor;
|
|
79
|
+
// We need to compare the current frame context with our previous frame
|
|
80
|
+
// context. If they are different, then we need to re-create our colorspace
|
|
81
|
+
// conversion objects. We create our colorspace conversion objects late so
|
|
82
|
+
// that we don't have to depend on the unreliable metadata in the header.
|
|
83
|
+
// And we sometimes re-create them because it's possible for frame
|
|
84
|
+
// resolution to change mid-stream. Finally, we want to reuse the colorspace
|
|
85
|
+
// conversion objects as much as possible for performance reasons.
|
|
86
|
+
enum AVPixelFormat frameFormat =
|
|
87
|
+
static_cast<enum AVPixelFormat>(avFrame->format);
|
|
88
|
+
auto frameContext = DecodedFrameContext{
|
|
89
|
+
avFrame->width,
|
|
90
|
+
avFrame->height,
|
|
91
|
+
frameFormat,
|
|
92
|
+
avFrame->sample_aspect_ratio,
|
|
93
|
+
expectedOutputWidth,
|
|
94
|
+
expectedOutputHeight};
|
|
95
|
+
|
|
96
|
+
// By default, we want to use swscale for color conversion because it is
|
|
97
|
+
// faster. However, it has width requirements, so we may need to fall back
|
|
98
|
+
// to filtergraph. We also need to respect what was requested from the
|
|
99
|
+
// options; we respect the options unconditionally, so it's possible for
|
|
100
|
+
// swscale's width requirements to be violated. We don't expose the ability to
|
|
101
|
+
// choose color conversion library publicly; we only use this ability
|
|
102
|
+
// internally.
|
|
103
|
+
|
|
104
|
+
// swscale requires widths to be multiples of 32:
|
|
105
|
+
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
|
|
106
|
+
// so we fall back to filtergraph if the width is not a multiple of 32.
|
|
107
|
+
auto defaultLibrary = (expectedOutputWidth % 32 == 0)
|
|
108
|
+
? ColorConversionLibrary::SWSCALE
|
|
109
|
+
: ColorConversionLibrary::FILTERGRAPH;
|
|
110
|
+
|
|
111
|
+
ColorConversionLibrary colorConversionLibrary =
|
|
112
|
+
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
|
|
113
|
+
|
|
114
|
+
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
|
|
115
|
+
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
|
|
116
|
+
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
|
|
117
|
+
|
|
118
|
+
if (!swsContext_ || prevFrameContext_ != frameContext) {
|
|
119
|
+
createSwsContext(frameContext, avFrame->colorspace);
|
|
120
|
+
prevFrameContext_ = frameContext;
|
|
121
|
+
}
|
|
122
|
+
int resultHeight =
|
|
123
|
+
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
|
|
124
|
+
// If this check failed, it would mean that the frame wasn't reshaped to
|
|
125
|
+
// the expected height.
|
|
126
|
+
// TODO: Can we do the same check for width?
|
|
127
|
+
TORCH_CHECK(
|
|
128
|
+
resultHeight == expectedOutputHeight,
|
|
129
|
+
"resultHeight != expectedOutputHeight: ",
|
|
130
|
+
resultHeight,
|
|
131
|
+
" != ",
|
|
132
|
+
expectedOutputHeight);
|
|
133
|
+
|
|
134
|
+
frameOutput.data = outputTensor;
|
|
135
|
+
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
|
|
136
|
+
if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
|
|
137
|
+
createFilterGraph(frameContext, videoStreamOptions, timeBase);
|
|
138
|
+
prevFrameContext_ = frameContext;
|
|
139
|
+
}
|
|
140
|
+
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
|
|
141
|
+
|
|
142
|
+
// Similarly to above, if this check fails it means the frame wasn't
|
|
143
|
+
// reshaped to its expected dimensions by filtergraph.
|
|
144
|
+
auto shape = outputTensor.sizes();
|
|
145
|
+
TORCH_CHECK(
|
|
146
|
+
(shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
|
|
147
|
+
(shape[1] == expectedOutputWidth) && (shape[2] == 3),
|
|
148
|
+
"Expected output tensor of shape ",
|
|
149
|
+
expectedOutputHeight,
|
|
150
|
+
"x",
|
|
151
|
+
expectedOutputWidth,
|
|
152
|
+
"x3, got ",
|
|
153
|
+
shape);
|
|
154
|
+
|
|
155
|
+
if (preAllocatedOutputTensor.has_value()) {
|
|
156
|
+
// We have already validated that preAllocatedOutputTensor and
|
|
157
|
+
// outputTensor have the same shape.
|
|
158
|
+
preAllocatedOutputTensor.value().copy_(outputTensor);
|
|
159
|
+
frameOutput.data = preAllocatedOutputTensor.value();
|
|
160
|
+
} else {
|
|
161
|
+
frameOutput.data = outputTensor;
|
|
162
|
+
}
|
|
163
|
+
} else {
|
|
164
|
+
throw std::runtime_error(
|
|
165
|
+
"Invalid color conversion library: " +
|
|
166
|
+
std::to_string(static_cast<int>(colorConversionLibrary)));
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
|
|
171
|
+
const UniqueAVFrame& avFrame,
|
|
172
|
+
torch::Tensor& outputTensor) {
|
|
173
|
+
uint8_t* pointers[4] = {
|
|
174
|
+
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
|
|
175
|
+
int expectedOutputWidth = outputTensor.sizes()[1];
|
|
176
|
+
int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
|
|
177
|
+
int resultHeight = sws_scale(
|
|
178
|
+
swsContext_.get(),
|
|
179
|
+
avFrame->data,
|
|
180
|
+
avFrame->linesize,
|
|
181
|
+
0,
|
|
182
|
+
avFrame->height,
|
|
183
|
+
pointers,
|
|
184
|
+
linesizes);
|
|
185
|
+
return resultHeight;
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
// Pushes `avFrame` through the filter graph and wraps the filtered RGB24
// frame in a uint8 HxWx3 tensor without copying the pixel data. The tensor's
// deleter takes over ownership of the filtered AVFrame.
//
// Throws std::runtime_error if the frame cannot be fed to the buffer source
// or pulled from the buffer sink.
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
    const UniqueAVFrame& avFrame) {
  int status = av_buffersrc_write_frame(
      filterGraphContext_.sourceContext, avFrame.get());
  if (status < AVSUCCESS) {
    throw std::runtime_error("Failed to add frame to buffer source context");
  }

  UniqueAVFrame filteredAVFrame(av_frame_alloc());
  TORCH_CHECK(
      filteredAVFrame.get() != nullptr,
      "Failed to allocate AVFrame for the filtered frame");

  status = av_buffersink_get_frame(
      filterGraphContext_.sinkContext, filteredAVFrame.get());
  // Without this check a failed pull would only surface through the format
  // check below, on a frame that was never filled in.
  if (status < AVSUCCESS) {
    throw std::runtime_error(
        "Failed to get frame from buffer sink context: " +
        getFFMPEGErrorStringFromErrorCode(status));
  }
  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
  int height = frameDims.height;
  int width = frameDims.width;
  std::vector<int64_t> shape = {height, width, 3};
  // The row stride comes from the frame's own line size.
  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
  // Hand ownership of the frame to the tensor: the deleter re-wraps the raw
  // pointer in a UniqueAVFrame so the frame is freed together with the tensor.
  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
  auto deleter = [filteredAVFramePtr](void*) {
    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
  };
  return torch::from_blob(
      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}
|
|
213
|
+
|
|
214
|
+
void CpuDeviceInterface::createFilterGraph(
|
|
215
|
+
const DecodedFrameContext& frameContext,
|
|
216
|
+
const VideoStreamOptions& videoStreamOptions,
|
|
217
|
+
const AVRational& timeBase) {
|
|
218
|
+
filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
|
|
219
|
+
TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
|
|
220
|
+
|
|
221
|
+
if (videoStreamOptions.ffmpegThreadCount.has_value()) {
|
|
222
|
+
filterGraphContext_.filterGraph->nb_threads =
|
|
223
|
+
videoStreamOptions.ffmpegThreadCount.value();
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
|
|
227
|
+
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
|
|
228
|
+
|
|
229
|
+
std::stringstream filterArgs;
|
|
230
|
+
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
|
|
231
|
+
<< frameContext.decodedHeight;
|
|
232
|
+
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
|
|
233
|
+
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
|
|
234
|
+
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
|
|
235
|
+
<< frameContext.decodedAspectRatio.den;
|
|
236
|
+
|
|
237
|
+
int status = avfilter_graph_create_filter(
|
|
238
|
+
&filterGraphContext_.sourceContext,
|
|
239
|
+
buffersrc,
|
|
240
|
+
"in",
|
|
241
|
+
filterArgs.str().c_str(),
|
|
242
|
+
nullptr,
|
|
243
|
+
filterGraphContext_.filterGraph.get());
|
|
244
|
+
if (status < 0) {
|
|
245
|
+
throw std::runtime_error(
|
|
246
|
+
std::string("Failed to create filter graph: ") + filterArgs.str() +
|
|
247
|
+
": " + getFFMPEGErrorStringFromErrorCode(status));
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
status = avfilter_graph_create_filter(
|
|
251
|
+
&filterGraphContext_.sinkContext,
|
|
252
|
+
buffersink,
|
|
253
|
+
"out",
|
|
254
|
+
nullptr,
|
|
255
|
+
nullptr,
|
|
256
|
+
filterGraphContext_.filterGraph.get());
|
|
257
|
+
if (status < 0) {
|
|
258
|
+
throw std::runtime_error(
|
|
259
|
+
"Failed to create filter graph: " +
|
|
260
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
|
|
264
|
+
|
|
265
|
+
status = av_opt_set_int_list(
|
|
266
|
+
filterGraphContext_.sinkContext,
|
|
267
|
+
"pix_fmts",
|
|
268
|
+
pix_fmts,
|
|
269
|
+
AV_PIX_FMT_NONE,
|
|
270
|
+
AV_OPT_SEARCH_CHILDREN);
|
|
271
|
+
if (status < 0) {
|
|
272
|
+
throw std::runtime_error(
|
|
273
|
+
"Failed to set output pixel formats: " +
|
|
274
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
|
|
278
|
+
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
|
|
279
|
+
|
|
280
|
+
outputs->name = av_strdup("in");
|
|
281
|
+
outputs->filter_ctx = filterGraphContext_.sourceContext;
|
|
282
|
+
outputs->pad_idx = 0;
|
|
283
|
+
outputs->next = nullptr;
|
|
284
|
+
inputs->name = av_strdup("out");
|
|
285
|
+
inputs->filter_ctx = filterGraphContext_.sinkContext;
|
|
286
|
+
inputs->pad_idx = 0;
|
|
287
|
+
inputs->next = nullptr;
|
|
288
|
+
|
|
289
|
+
std::stringstream description;
|
|
290
|
+
description << "scale=" << frameContext.expectedWidth << ":"
|
|
291
|
+
<< frameContext.expectedHeight;
|
|
292
|
+
description << ":sws_flags=bilinear";
|
|
293
|
+
|
|
294
|
+
AVFilterInOut* outputsTmp = outputs.release();
|
|
295
|
+
AVFilterInOut* inputsTmp = inputs.release();
|
|
296
|
+
status = avfilter_graph_parse_ptr(
|
|
297
|
+
filterGraphContext_.filterGraph.get(),
|
|
298
|
+
description.str().c_str(),
|
|
299
|
+
&inputsTmp,
|
|
300
|
+
&outputsTmp,
|
|
301
|
+
nullptr);
|
|
302
|
+
outputs.reset(outputsTmp);
|
|
303
|
+
inputs.reset(inputsTmp);
|
|
304
|
+
if (status < 0) {
|
|
305
|
+
throw std::runtime_error(
|
|
306
|
+
"Failed to parse filter description: " +
|
|
307
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
status =
|
|
311
|
+
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
|
|
312
|
+
if (status < 0) {
|
|
313
|
+
throw std::runtime_error(
|
|
314
|
+
"Failed to configure filter graph: " +
|
|
315
|
+
getFFMPEGErrorStringFromErrorCode(status));
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
void CpuDeviceInterface::createSwsContext(
|
|
320
|
+
const DecodedFrameContext& frameContext,
|
|
321
|
+
const enum AVColorSpace colorspace) {
|
|
322
|
+
SwsContext* swsContext = sws_getContext(
|
|
323
|
+
frameContext.decodedWidth,
|
|
324
|
+
frameContext.decodedHeight,
|
|
325
|
+
frameContext.decodedFormat,
|
|
326
|
+
frameContext.expectedWidth,
|
|
327
|
+
frameContext.expectedHeight,
|
|
328
|
+
AV_PIX_FMT_RGB24,
|
|
329
|
+
SWS_BILINEAR,
|
|
330
|
+
nullptr,
|
|
331
|
+
nullptr,
|
|
332
|
+
nullptr);
|
|
333
|
+
TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
|
|
334
|
+
|
|
335
|
+
int* invTable = nullptr;
|
|
336
|
+
int* table = nullptr;
|
|
337
|
+
int srcRange, dstRange, brightness, contrast, saturation;
|
|
338
|
+
int ret = sws_getColorspaceDetails(
|
|
339
|
+
swsContext,
|
|
340
|
+
&invTable,
|
|
341
|
+
&srcRange,
|
|
342
|
+
&table,
|
|
343
|
+
&dstRange,
|
|
344
|
+
&brightness,
|
|
345
|
+
&contrast,
|
|
346
|
+
&saturation);
|
|
347
|
+
TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
|
|
348
|
+
|
|
349
|
+
const int* colorspaceTable = sws_getCoefficients(colorspace);
|
|
350
|
+
ret = sws_setColorspaceDetails(
|
|
351
|
+
swsContext,
|
|
352
|
+
colorspaceTable,
|
|
353
|
+
srcRange,
|
|
354
|
+
colorspaceTable,
|
|
355
|
+
dstRange,
|
|
356
|
+
brightness,
|
|
357
|
+
contrast,
|
|
358
|
+
saturation);
|
|
359
|
+
TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
|
|
360
|
+
|
|
361
|
+
swsContext_.reset(swsContext);
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
} // namespace facebook::torchcodec
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
// All rights reserved.
|
|
3
|
+
//
|
|
4
|
+
// This source code is licensed under the BSD-style license found in the
|
|
5
|
+
// LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
#pragma once
|
|
8
|
+
|
|
9
|
+
#include "src/torchcodec/_core/DeviceInterface.h"
|
|
10
|
+
#include "src/torchcodec/_core/FFMPEGCommon.h"
|
|
11
|
+
|
|
12
|
+
namespace facebook::torchcodec {
|
|
13
|
+
|
|
14
|
+
class CpuDeviceInterface : public DeviceInterface {
|
|
15
|
+
public:
|
|
16
|
+
CpuDeviceInterface(const torch::Device& device);
|
|
17
|
+
|
|
18
|
+
virtual ~CpuDeviceInterface() {}
|
|
19
|
+
|
|
20
|
+
std::optional<const AVCodec*> findCodec(
|
|
21
|
+
[[maybe_unused]] const AVCodecID& codecId) override {
|
|
22
|
+
return std::nullopt;
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
void initializeContext(
|
|
26
|
+
[[maybe_unused]] AVCodecContext* codecContext) override {}
|
|
27
|
+
|
|
28
|
+
void convertAVFrameToFrameOutput(
|
|
29
|
+
const VideoStreamOptions& videoStreamOptions,
|
|
30
|
+
const AVRational& timeBase,
|
|
31
|
+
UniqueAVFrame& avFrame,
|
|
32
|
+
FrameOutput& frameOutput,
|
|
33
|
+
std::optional<torch::Tensor> preAllocatedOutputTensor =
|
|
34
|
+
std::nullopt) override;
|
|
35
|
+
|
|
36
|
+
private:
|
|
37
|
+
int convertAVFrameToTensorUsingSwsScale(
|
|
38
|
+
const UniqueAVFrame& avFrame,
|
|
39
|
+
torch::Tensor& outputTensor);
|
|
40
|
+
|
|
41
|
+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
|
|
42
|
+
const UniqueAVFrame& avFrame);
|
|
43
|
+
|
|
44
|
+
struct FilterGraphContext {
|
|
45
|
+
UniqueAVFilterGraph filterGraph;
|
|
46
|
+
AVFilterContext* sourceContext = nullptr;
|
|
47
|
+
AVFilterContext* sinkContext = nullptr;
|
|
48
|
+
};
|
|
49
|
+
|
|
50
|
+
struct DecodedFrameContext {
|
|
51
|
+
int decodedWidth;
|
|
52
|
+
int decodedHeight;
|
|
53
|
+
AVPixelFormat decodedFormat;
|
|
54
|
+
AVRational decodedAspectRatio;
|
|
55
|
+
int expectedWidth;
|
|
56
|
+
int expectedHeight;
|
|
57
|
+
bool operator==(const DecodedFrameContext&);
|
|
58
|
+
bool operator!=(const DecodedFrameContext&);
|
|
59
|
+
};
|
|
60
|
+
|
|
61
|
+
void createSwsContext(
|
|
62
|
+
const DecodedFrameContext& frameContext,
|
|
63
|
+
const enum AVColorSpace colorspace);
|
|
64
|
+
|
|
65
|
+
void createFilterGraph(
|
|
66
|
+
const DecodedFrameContext& frameContext,
|
|
67
|
+
const VideoStreamOptions& videoStreamOptions,
|
|
68
|
+
const AVRational& timeBase);
|
|
69
|
+
|
|
70
|
+
// color-conversion fields. Only one of FilterGraphContext and
|
|
71
|
+
// UniqueSwsContext should be non-null.
|
|
72
|
+
FilterGraphContext filterGraphContext_;
|
|
73
|
+
UniqueSwsContext swsContext_;
|
|
74
|
+
|
|
75
|
+
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
|
|
76
|
+
// be created before decoding a new frame.
|
|
77
|
+
DecodedFrameContext prevFrameContext_;
|
|
78
|
+
};
|
|
79
|
+
|
|
80
|
+
} // namespace facebook::torchcodec
|