torchcodec 0.6.0__cp39-cp39-macosx_11_0_arm64.whl → 0.7.0__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchcodec might be problematic. Click here for more details.

Files changed (53) hide show
  1. torchcodec/.dylibs/libc++.1.0.dylib +0 -0
  2. torchcodec/.dylibs/libpython3.9.dylib +0 -0
  3. torchcodec/_core/AVIOContextHolder.cpp +10 -5
  4. torchcodec/_core/AVIOContextHolder.h +1 -0
  5. torchcodec/_core/AVIOFileLikeContext.cpp +23 -5
  6. torchcodec/_core/AVIOFileLikeContext.h +2 -1
  7. torchcodec/_core/AVIOTensorContext.cpp +4 -2
  8. torchcodec/_core/CMakeLists.txt +57 -18
  9. torchcodec/_core/Cache.h +138 -0
  10. torchcodec/_core/CpuDeviceInterface.cpp +55 -149
  11. torchcodec/_core/CpuDeviceInterface.h +13 -23
  12. torchcodec/_core/CudaDeviceInterface.cpp +310 -78
  13. torchcodec/_core/CudaDeviceInterface.h +3 -1
  14. torchcodec/_core/Encoder.cpp +13 -5
  15. torchcodec/_core/Encoder.h +6 -4
  16. torchcodec/_core/FFMPEGCommon.cpp +9 -1
  17. torchcodec/_core/FFMPEGCommon.h +15 -0
  18. torchcodec/_core/FilterGraph.cpp +142 -0
  19. torchcodec/_core/FilterGraph.h +45 -0
  20. torchcodec/_core/SingleStreamDecoder.cpp +32 -32
  21. torchcodec/_core/ValidationUtils.cpp +35 -0
  22. torchcodec/_core/ValidationUtils.h +21 -0
  23. torchcodec/_core/__init__.py +1 -0
  24. torchcodec/_core/custom_ops.cpp +23 -23
  25. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +81 -7
  26. torchcodec/_core/ops.py +56 -0
  27. torchcodec/_core/pybind_ops.cpp +39 -1
  28. torchcodec/_internally_replaced_utils.py +9 -6
  29. torchcodec/decoders/_audio_decoder.py +3 -1
  30. torchcodec/decoders/_decoder_utils.py +1 -1
  31. torchcodec/decoders/_video_decoder.py +88 -29
  32. torchcodec/encoders/_audio_encoder.py +41 -1
  33. torchcodec/libtorchcodec_core4.dylib +0 -0
  34. torchcodec/libtorchcodec_core5.dylib +0 -0
  35. torchcodec/libtorchcodec_core6.dylib +0 -0
  36. torchcodec/libtorchcodec_core7.dylib +0 -0
  37. torchcodec/libtorchcodec_custom_ops4.dylib +0 -0
  38. torchcodec/libtorchcodec_custom_ops5.dylib +0 -0
  39. torchcodec/libtorchcodec_custom_ops6.dylib +0 -0
  40. torchcodec/libtorchcodec_custom_ops7.dylib +0 -0
  41. torchcodec/libtorchcodec_pybind_ops4.so +0 -0
  42. torchcodec/libtorchcodec_pybind_ops5.so +0 -0
  43. torchcodec/libtorchcodec_pybind_ops6.so +0 -0
  44. torchcodec/libtorchcodec_pybind_ops7.so +0 -0
  45. torchcodec/samplers/_index_based.py +2 -0
  46. torchcodec/samplers/_time_based.py +2 -0
  47. torchcodec/version.py +1 -1
  48. {torchcodec-0.6.0.dist-info → torchcodec-0.7.0.dist-info}/METADATA +8 -35
  49. torchcodec-0.7.0.dist-info/RECORD +69 -0
  50. torchcodec-0.6.0.dist-info/RECORD +0 -64
  51. {torchcodec-0.6.0.dist-info → torchcodec-0.7.0.dist-info}/WHEEL +0 -0
  52. {torchcodec-0.6.0.dist-info → torchcodec-0.7.0.dist-info}/licenses/LICENSE +0 -0
  53. {torchcodec-0.6.0.dist-info → torchcodec-0.7.0.dist-info}/top_level.txt +0 -0
Binary file
Binary file
@@ -14,6 +14,7 @@ void AVIOContextHolder::createAVIOContext(
14
14
  AVIOWriteFunction write,
15
15
  AVIOSeekFunction seek,
16
16
  void* heldData,
17
+ bool isForWriting,
17
18
  int bufferSize) {
18
19
  TORCH_CHECK(
19
20
  bufferSize > 0,
@@ -23,14 +24,18 @@ void AVIOContextHolder::createAVIOContext(
23
24
  buffer != nullptr,
24
25
  "Failed to allocate buffer of size " + std::to_string(bufferSize));
25
26
 
26
- TORCH_CHECK(
27
- (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
28
- "seek method must be defined, and either write or read must be defined. "
29
- "But not both!")
27
+ TORCH_CHECK(seek != nullptr, "seek method must be defined");
28
+
29
+ if (isForWriting) {
30
+ TORCH_CHECK(write != nullptr, "write method must be defined for writing");
31
+ } else {
32
+ TORCH_CHECK(read != nullptr, "read method must be defined for reading");
33
+ }
34
+
30
35
  avioContext_.reset(avioAllocContext(
31
36
  buffer,
32
37
  bufferSize,
33
- /*write_flag=*/write != nullptr,
38
+ /*write_flag=*/isForWriting,
34
39
  heldData,
35
40
  read,
36
41
  write,
@@ -51,6 +51,7 @@ class AVIOContextHolder {
51
51
  AVIOWriteFunction write,
52
52
  AVIOSeekFunction seek,
53
53
  void* heldData,
54
+ bool isForWriting,
54
55
  int bufferSize = defaultBufferSize);
55
56
 
56
57
  private:
@@ -9,21 +9,31 @@
9
9
 
10
10
  namespace facebook::torchcodec {
11
11
 
12
- AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
12
+ AVIOFileLikeContext::AVIOFileLikeContext(
13
+ const py::object& fileLike,
14
+ bool isForWriting)
13
15
  : fileLike_{UniquePyObject(new py::object(fileLike))} {
14
16
  {
15
17
  // TODO: Is it necessary to acquire the GIL here? Is it maybe even
16
18
  // harmful? At the moment, this is only called from within a pybind
17
19
  // function, and pybind guarantees we have the GIL.
18
20
  py::gil_scoped_acquire gil;
19
- TORCH_CHECK(
20
- py::hasattr(fileLike, "read"),
21
- "File like object must implement a read method.");
21
+
22
+ if (isForWriting) {
23
+ TORCH_CHECK(
24
+ py::hasattr(fileLike, "write"),
25
+ "File like object must implement a write method for writing.");
26
+ } else {
27
+ TORCH_CHECK(
28
+ py::hasattr(fileLike, "read"),
29
+ "File like object must implement a read method for reading.");
30
+ }
31
+
22
32
  TORCH_CHECK(
23
33
  py::hasattr(fileLike, "seek"),
24
34
  "File like object must implement a seek method.");
25
35
  }
26
- createAVIOContext(&read, nullptr, &seek, &fileLike_);
36
+ createAVIOContext(&read, &write, &seek, &fileLike_, isForWriting);
27
37
  }
28
38
 
29
39
  int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
@@ -77,4 +87,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) {
77
87
  return py::cast<int64_t>((*fileLike)->attr("seek")(offset, whence));
78
88
  }
79
89
 
90
+ int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) {
91
+ auto fileLike = static_cast<UniquePyObject*>(opaque);
92
+ py::gil_scoped_acquire gil;
93
+ py::bytes bytes_obj(reinterpret_cast<const char*>(buf), buf_size);
94
+
95
+ return py::cast<int>((*fileLike)->attr("write")(bytes_obj));
96
+ }
97
+
80
98
  } // namespace facebook::torchcodec
@@ -19,11 +19,12 @@ namespace facebook::torchcodec {
19
19
  // and seek calls back up to the methods on the Python object.
20
20
  class AVIOFileLikeContext : public AVIOContextHolder {
21
21
  public:
22
- explicit AVIOFileLikeContext(py::object fileLike);
22
+ explicit AVIOFileLikeContext(const py::object& fileLike, bool isForWriting);
23
23
 
24
24
  private:
25
25
  static int read(void* opaque, uint8_t* buf, int buf_size);
26
26
  static int64_t seek(void* opaque, int64_t offset, int whence);
27
+ static int write(void* opaque, const uint8_t* buf, int buf_size);
27
28
 
28
29
  // Note that we dynamically allocate the Python object because we need to
29
30
  // strictly control when its destructor is called. We must hold the GIL
@@ -105,12 +105,14 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
105
105
  TORCH_CHECK(data.numel() > 0, "data must not be empty");
106
106
  TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
107
107
  TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
108
- createAVIOContext(&read, nullptr, &seek, &tensorContext_);
108
+ createAVIOContext(
109
+ &read, nullptr, &seek, &tensorContext_, /*isForWriting=*/false);
109
110
  }
110
111
 
111
112
  AVIOToTensorContext::AVIOToTensorContext()
112
113
  : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
113
- createAVIOContext(nullptr, &write, &seek, &tensorContext_);
114
+ createAVIOContext(
115
+ nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
114
116
  }
115
117
 
116
118
  torch::Tensor AVIOToTensorContext::getOutputTensor() {
@@ -11,10 +11,29 @@ find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
11
11
  if(DEFINED TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR AND TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR)
12
12
  set(TORCHCODEC_WERROR_OPTION "")
13
13
  else()
14
- set(TORCHCODEC_WERROR_OPTION "-Werror")
14
+ if (WIN32)
15
+ # TODO set warnings as errors on Windows as well.
16
+ # set(TORCHCODEC_WERROR_OPTION "/WX")
17
+ else()
18
+ set(TORCHCODEC_WERROR_OPTION "-Werror")
19
+ endif()
20
+ endif()
21
+
22
+ if (WIN32)
23
+ # Avoid warnings about non-ASCII characters in source files.
24
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819")
25
+ # Important for when we add Windows CUDA: exporting all symbols is limited to
26
+ # 65535 symbols, which (apparently) will not work for CUDA.
27
+ # https://github.com/pytorch/pytorch/pull/3650
28
+ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
29
+ endif()
30
+
31
+ if (WIN32)
32
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4 ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
33
+ else()
34
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
15
35
  endif()
16
36
 
17
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
18
37
 
19
38
  function(make_torchcodec_sublibrary
20
39
  library_name
@@ -39,6 +58,7 @@ function(make_torchcodec_sublibrary
39
58
  PUBLIC
40
59
  ${library_dependencies}
41
60
  )
61
+
42
62
  endfunction()
43
63
 
44
64
  function(make_torchcodec_libraries
@@ -50,16 +70,17 @@ function(make_torchcodec_libraries
50
70
  #
51
71
  # 1. libtorchcodec_coreN.{ext}: Base library which contains the
52
72
  # implementation of VideoDecoder and everything VideoDecoder needs. On
53
- # Linux, {ext} is so. On Mac, it is dylib.
73
+ # Linux, {ext} is so. On Mac, it is dylib. On Windows it's dll.
54
74
  #
55
75
  # 2. libtorchcodec_custom_opsN.{ext}: Implementation of the PyTorch custom
56
76
  # ops. Depends on libtorchcodec_coreN.{ext}. On Linux, {ext} is so.
57
- # On Mac, it is dylib.
77
+ # On Mac, it is dylib. On Windows it's dll.
58
78
  #
59
79
  # 3. libtorchcodec_pybind_opsN.{ext}: Implementation of the pybind11 ops. We
60
80
  # keep these separate from the PyTorch custom ops because we have to
61
81
  # load these libraries separately on the Python side. Depends on
62
- # libtorchcodec_coreN.{ext}. On BOTH Linux and Mac {ext} is so.
82
+ # libtorchcodec_coreN.{ext}. On BOTH Linux and Mac {ext} is so. On
83
+ # Windows, it's pyd.
63
84
 
64
85
  # 1. Create libtorchcodec_coreN.{ext}.
65
86
  set(core_library_name "libtorchcodec_core${ffmpeg_major_version}")
@@ -67,11 +88,13 @@ function(make_torchcodec_libraries
67
88
  AVIOContextHolder.cpp
68
89
  AVIOTensorContext.cpp
69
90
  FFMPEGCommon.cpp
91
+ FilterGraph.cpp
70
92
  Frame.cpp
71
93
  DeviceInterface.cpp
72
94
  CpuDeviceInterface.cpp
73
95
  SingleStreamDecoder.cpp
74
96
  Encoder.cpp
97
+ ValidationUtils.cpp
75
98
  )
76
99
 
77
100
  if(ENABLE_CUDA)
@@ -140,15 +163,26 @@ function(make_torchcodec_libraries
140
163
  "${pybind_ops_sources}"
141
164
  "${pybind_ops_dependencies}"
142
165
  )
166
+
167
+ if(WIN32)
168
+ # On Windows, we need to set the suffix to .pyd so that Python can
169
+ # import the shared library as a module. Just setting the MODULE type
170
+ # isn't enough.
171
+ set_target_properties(${pybind_ops_library_name} PROPERTIES SUFFIX ".pyd")
172
+ endif()
173
+
143
174
  # pybind11 limits the visibility of symbols in the shared library to prevent
144
175
  # stray initialization of py::objects. The rest of the object code must
145
176
  # match. See:
146
177
  # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes
147
- target_compile_options(
148
- ${pybind_ops_library_name}
149
- PUBLIC
150
- "-fvisibility=hidden"
151
- )
178
+ if(NOT WIN32)
179
+ target_compile_options(
180
+ ${pybind_ops_library_name}
181
+ PUBLIC
182
+ "-fvisibility=hidden"
183
+ )
184
+ endif()
185
+
152
186
  # The value we use here must match the value we return from
153
187
  # _get_pybind_ops_module_name() on the Python side. If the values do not
154
188
  # match, then we will be unable to import the C++ shared library as a
@@ -158,14 +192,17 @@ function(make_torchcodec_libraries
158
192
  PRIVATE
159
193
  PYBIND_OPS_MODULE_NAME=core_pybind_ops
160
194
  )
161
- # If we don't make sure this flag is set, we run into segfauls at import
162
- # time on Mac. See:
163
- # https://github.com/pybind/pybind11/issues/3907#issuecomment-1170412764
164
- target_link_options(
165
- ${pybind_ops_library_name}
166
- PUBLIC
167
- "LINKER:-undefined,dynamic_lookup"
168
- )
195
+
196
+ if(APPLE)
197
+ # If we don't make sure this flag is set, we run into segfaults at import
198
+ # time on Mac. See:
199
+ # https://github.com/pybind/pybind11/issues/3907#issuecomment-1170412764
200
+ target_link_options(
201
+ ${pybind_ops_library_name}
202
+ PUBLIC
203
+ "LINKER:-undefined,dynamic_lookup"
204
+ )
205
+ endif()
169
206
 
170
207
  # Install all libraries.
171
208
  set(
@@ -183,7 +220,9 @@ function(make_torchcodec_libraries
183
220
  install(
184
221
  TARGETS ${all_libraries}
185
222
  LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}
223
+ RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX} # For Windows
186
224
  )
225
+
187
226
  endfunction()
188
227
 
189
228
  if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3})
@@ -0,0 +1,138 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <torch/types.h>
10
+ #include <memory>
11
+ #include <mutex>
12
+
13
+ namespace facebook::torchcodec {
14
+
15
+ // This header defines simple cache class primitives to store reusable objects
16
+ // across TorchCodec stream instances. Intended usage is to store hardware
17
+ // contexts whose creation is expensive. The cache mechanism is as follows:
18
+ // 1. 'PerGpuCache' provides a dynamic cache with the specified maximum capacity
19
+ // for the given number of GPUs.
20
+ // 2. When a stream object (e.g. SingleStreamDecoder) is destroyed, the cacheable object
21
+ // must be released to the cache. Cache will accept the object if it is not
22
+ // full.
23
+ // 3. When a stream object (e.g. SingleStreamDecoder) is created, the cacheable object
24
+ // must be first queried from the cache. If the cache is empty then new
25
+ // object must be created.
26
+
27
+ template <typename T, typename D = std::default_delete<T>>
28
+ class Cache {
29
+ public:
30
+ using element_type = std::unique_ptr<T, D>;
31
+
32
+ explicit Cache(int capacity) : capacity_(capacity) {}
33
+
34
+ // Adds an object to the cache if the cache has capacity. Returns true
35
+ // if object was added and false otherwise.
36
+ bool addIfCacheHasCapacity(element_type&& obj);
37
+
38
+ // Returns an object from the cache. Cache does not hold a reference
39
+ // to the object after this call.
40
+ element_type get();
41
+
42
+ private:
43
+ int capacity_;
44
+ std::mutex mutex_;
45
+ std::vector<element_type> cache_;
46
+ };
47
+
48
+ template <typename T, typename D>
49
+ bool Cache<T, D>::addIfCacheHasCapacity(element_type&& obj) {
50
+ std::scoped_lock lock(mutex_);
51
+ if (capacity_ >= 0 && cache_.size() >= static_cast<size_t>(capacity_)) {
52
+ return false;
53
+ }
54
+ cache_.push_back(std::move(obj));
55
+ return true;
56
+ }
57
+
58
+ template <typename T, typename D>
59
+ typename Cache<T, D>::element_type Cache<T, D>::get() {
60
+ std::scoped_lock lock(mutex_);
61
+ if (cache_.empty()) {
62
+ return nullptr;
63
+ }
64
+
65
+ element_type obj = std::move(cache_.back());
66
+ cache_.pop_back();
67
+ return obj;
68
+ }
69
+
70
+ template <typename T, typename D = std::default_delete<T>>
71
+ class PerGpuCache {
72
+ public:
73
+ using element_type = typename Cache<T, D>::element_type;
74
+
75
+ // Initializes 'maxGpus' number of caches. Each cache can hold no
76
+ // more than 'capacity' items. If 'capacity' < 0, the cache size is unlimited.
77
+ PerGpuCache(int maxGpus, int capacity) {
78
+ TORCH_CHECK(maxGpus > 0, "maxGpus for PerGpuCache must be >0");
79
+ for (int i = 0; i < maxGpus; ++i) {
80
+ cache_.emplace_back(std::make_unique<Cache<T, D>>(capacity));
81
+ }
82
+ }
83
+
84
+ // Adds an object to the specified device cache if the cache has
85
+ // capacity. Returns true if object was added and false otherwise.
86
+ bool addIfCacheHasCapacity(const torch::Device& device, element_type&& obj);
87
+
88
+ // Returns an object from the cache of the specified device. Cache
89
+ // does not hold a reference to the object after this call.
90
+ element_type get(const torch::Device& device);
91
+
92
+ private:
93
+ // 'Cache' class implementation contains mutex which makes it non-movable
94
+ // and non-copyable, so we need to wrap it in std::unique_ptr.
95
+ std::vector<std::unique_ptr<Cache<T, D>>> cache_;
96
+ };
97
+
98
+ // Note: this function is inline for convenience, not performance. Because the
99
+ // rest of this file is template functions, they must all be defined in this
100
+ // header. This function is not a template function, and should, in principle,
101
+ // be defined in a .cpp file to preserve the One Definition Rule. That's
102
+ // annoying for such a small amount of code, so we just inline it. If this file
103
+ // grows, and there are more such functions, we should break them out into a
104
+ // .cpp file.
105
+ inline torch::DeviceIndex getNonNegativeDeviceIndex(
106
+ const torch::Device& device) {
107
+ torch::DeviceIndex deviceIndex = device.index();
108
+ // For single GPU machines libtorch returns -1 for the device index. So for
109
+ // that case we set the device index to 0. That's used in per-gpu cache
110
+ // implementation and during initialization of CUDA and FFmpeg contexts
111
+ // which require non negative indices.
112
+ deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
113
+ TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
114
+ return deviceIndex;
115
+ }
116
+
117
+ template <typename T, typename D>
118
+ bool PerGpuCache<T, D>::addIfCacheHasCapacity(
119
+ const torch::Device& device,
120
+ element_type&& obj) {
121
+ torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
122
+ TORCH_CHECK(
123
+ static_cast<size_t>(deviceIndex) < cache_.size(),
124
+ "Device index out of range");
125
+ return cache_[deviceIndex]->addIfCacheHasCapacity(std::move(obj));
126
+ }
127
+
128
+ template <typename T, typename D>
129
+ typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
130
+ const torch::Device& device) {
131
+ torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
132
+ TORCH_CHECK(
133
+ static_cast<size_t>(deviceIndex) < cache_.size(),
134
+ "Device index out of range");
135
+ return cache_[deviceIndex]->get();
136
+ }
137
+
138
+ } // namespace facebook::torchcodec
@@ -6,11 +6,6 @@
6
6
 
7
7
  #include "src/torchcodec/_core/CpuDeviceInterface.h"
8
8
 
9
- extern "C" {
10
- #include <libavfilter/buffersink.h>
11
- #include <libavfilter/buffersrc.h>
12
- }
13
-
14
9
  namespace facebook::torchcodec {
15
10
  namespace {
16
11
 
@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(
20
15
 
21
16
  } // namespace
22
17
 
23
- bool CpuDeviceInterface::DecodedFrameContext::operator==(
24
- const CpuDeviceInterface::DecodedFrameContext& other) {
25
- return decodedWidth == other.decodedWidth &&
26
- decodedHeight == other.decodedHeight &&
27
- decodedFormat == other.decodedFormat &&
28
- expectedWidth == other.expectedWidth &&
29
- expectedHeight == other.expectedHeight;
18
+ bool CpuDeviceInterface::SwsFrameContext::operator==(
19
+ const CpuDeviceInterface::SwsFrameContext& other) const {
20
+ return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
21
+ inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
22
+ outputHeight == other.outputHeight;
30
23
  }
31
24
 
32
- bool CpuDeviceInterface::DecodedFrameContext::operator!=(
33
- const CpuDeviceInterface::DecodedFrameContext& other) {
25
+ bool CpuDeviceInterface::SwsFrameContext::operator!=(
26
+ const CpuDeviceInterface::SwsFrameContext& other) const {
34
27
  return !(*this == other);
35
28
  }
36
29
 
@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
75
68
  }
76
69
 
77
70
  torch::Tensor outputTensor;
78
- // We need to compare the current frame context with our previous frame
79
- // context. If they are different, then we need to re-create our colorspace
80
- // conversion objects. We create our colorspace conversion objects late so
81
- // that we don't have to depend on the unreliable metadata in the header.
82
- // And we sometimes re-create them because it's possible for frame
83
- // resolution to change mid-stream. Finally, we want to reuse the colorspace
84
- // conversion objects as much as possible for performance reasons.
85
71
  enum AVPixelFormat frameFormat =
86
72
  static_cast<enum AVPixelFormat>(avFrame->format);
87
- auto frameContext = DecodedFrameContext{
88
- avFrame->width,
89
- avFrame->height,
90
- frameFormat,
91
- avFrame->sample_aspect_ratio,
92
- expectedOutputWidth,
93
- expectedOutputHeight};
94
73
 
95
74
  // By default, we want to use swscale for color conversion because it is
96
75
  // faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
111
90
  videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
112
91
 
113
92
  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
93
+ // We need to compare the current frame context with our previous frame
94
+ // context. If they are different, then we need to re-create our colorspace
95
+ // conversion objects. We create our colorspace conversion objects late so
96
+ // that we don't have to depend on the unreliable metadata in the header.
97
+ // And we sometimes re-create them because it's possible for frame
98
+ // resolution to change mid-stream. Finally, we want to reuse the colorspace
99
+ // conversion objects as much as possible for performance reasons.
100
+ SwsFrameContext swsFrameContext;
101
+
102
+ swsFrameContext.inputWidth = avFrame->width;
103
+ swsFrameContext.inputHeight = avFrame->height;
104
+ swsFrameContext.inputFormat = frameFormat;
105
+ swsFrameContext.outputWidth = expectedOutputWidth;
106
+ swsFrameContext.outputHeight = expectedOutputHeight;
107
+
114
108
  outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
115
109
  expectedOutputHeight, expectedOutputWidth, torch::kCPU));
116
110
 
117
- if (!swsContext_ || prevFrameContext_ != frameContext) {
118
- createSwsContext(frameContext, avFrame->colorspace);
119
- prevFrameContext_ = frameContext;
111
+ if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
112
+ createSwsContext(swsFrameContext, avFrame->colorspace);
113
+ prevSwsFrameContext_ = swsFrameContext;
120
114
  }
121
115
  int resultHeight =
122
116
  convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -132,9 +126,29 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
132
126
 
133
127
  frameOutput.data = outputTensor;
134
128
  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
135
- if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
136
- createFilterGraph(frameContext, videoStreamOptions, timeBase);
137
- prevFrameContext_ = frameContext;
129
+ // See comment above in swscale branch about the filterGraphContext_
130
+ // creation.
131
+ FiltersContext filtersContext;
132
+
133
+ filtersContext.inputWidth = avFrame->width;
134
+ filtersContext.inputHeight = avFrame->height;
135
+ filtersContext.inputFormat = frameFormat;
136
+ filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
137
+ filtersContext.outputWidth = expectedOutputWidth;
138
+ filtersContext.outputHeight = expectedOutputHeight;
139
+ filtersContext.outputFormat = AV_PIX_FMT_RGB24;
140
+ filtersContext.timeBase = timeBase;
141
+
142
+ std::stringstream filters;
143
+ filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
144
+ filters << ":sws_flags=bilinear";
145
+
146
+ filtersContext.filtergraphStr = filters.str();
147
+
148
+ if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
149
+ filterGraphContext_ =
150
+ std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
151
+ prevFiltersContext_ = std::move(filtersContext);
138
152
  }
139
153
  outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
140
154
 
@@ -187,14 +201,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
187
201
 
188
202
  torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
189
203
  const UniqueAVFrame& avFrame) {
190
- int status = av_buffersrc_write_frame(
191
- filterGraphContext_.sourceContext, avFrame.get());
192
- TORCH_CHECK(
193
- status >= AVSUCCESS, "Failed to add frame to buffer source context");
204
+ UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
194
205
 
195
- UniqueAVFrame filteredAVFrame(av_frame_alloc());
196
- status = av_buffersink_get_frame(
197
- filterGraphContext_.sinkContext, filteredAVFrame.get());
198
206
  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
199
207
 
200
208
  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,117 +218,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
210
218
  filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
211
219
  }
212
220
 
213
- void CpuDeviceInterface::createFilterGraph(
214
- const DecodedFrameContext& frameContext,
215
- const VideoStreamOptions& videoStreamOptions,
216
- const AVRational& timeBase) {
217
- filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
218
- TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
219
-
220
- if (videoStreamOptions.ffmpegThreadCount.has_value()) {
221
- filterGraphContext_.filterGraph->nb_threads =
222
- videoStreamOptions.ffmpegThreadCount.value();
223
- }
224
-
225
- const AVFilter* buffersrc = avfilter_get_by_name("buffer");
226
- const AVFilter* buffersink = avfilter_get_by_name("buffersink");
227
-
228
- std::stringstream filterArgs;
229
- filterArgs << "video_size=" << frameContext.decodedWidth << "x"
230
- << frameContext.decodedHeight;
231
- filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
232
- filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
233
- filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
234
- << frameContext.decodedAspectRatio.den;
235
-
236
- int status = avfilter_graph_create_filter(
237
- &filterGraphContext_.sourceContext,
238
- buffersrc,
239
- "in",
240
- filterArgs.str().c_str(),
241
- nullptr,
242
- filterGraphContext_.filterGraph.get());
243
- TORCH_CHECK(
244
- status >= 0,
245
- "Failed to create filter graph: ",
246
- filterArgs.str(),
247
- ": ",
248
- getFFMPEGErrorStringFromErrorCode(status));
249
-
250
- status = avfilter_graph_create_filter(
251
- &filterGraphContext_.sinkContext,
252
- buffersink,
253
- "out",
254
- nullptr,
255
- nullptr,
256
- filterGraphContext_.filterGraph.get());
257
- TORCH_CHECK(
258
- status >= 0,
259
- "Failed to create filter graph: ",
260
- getFFMPEGErrorStringFromErrorCode(status));
261
-
262
- enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
263
-
264
- status = av_opt_set_int_list(
265
- filterGraphContext_.sinkContext,
266
- "pix_fmts",
267
- pix_fmts,
268
- AV_PIX_FMT_NONE,
269
- AV_OPT_SEARCH_CHILDREN);
270
- TORCH_CHECK(
271
- status >= 0,
272
- "Failed to set output pixel formats: ",
273
- getFFMPEGErrorStringFromErrorCode(status));
274
-
275
- UniqueAVFilterInOut outputs(avfilter_inout_alloc());
276
- UniqueAVFilterInOut inputs(avfilter_inout_alloc());
277
-
278
- outputs->name = av_strdup("in");
279
- outputs->filter_ctx = filterGraphContext_.sourceContext;
280
- outputs->pad_idx = 0;
281
- outputs->next = nullptr;
282
- inputs->name = av_strdup("out");
283
- inputs->filter_ctx = filterGraphContext_.sinkContext;
284
- inputs->pad_idx = 0;
285
- inputs->next = nullptr;
286
-
287
- std::stringstream description;
288
- description << "scale=" << frameContext.expectedWidth << ":"
289
- << frameContext.expectedHeight;
290
- description << ":sws_flags=bilinear";
291
-
292
- AVFilterInOut* outputsTmp = outputs.release();
293
- AVFilterInOut* inputsTmp = inputs.release();
294
- status = avfilter_graph_parse_ptr(
295
- filterGraphContext_.filterGraph.get(),
296
- description.str().c_str(),
297
- &inputsTmp,
298
- &outputsTmp,
299
- nullptr);
300
- outputs.reset(outputsTmp);
301
- inputs.reset(inputsTmp);
302
- TORCH_CHECK(
303
- status >= 0,
304
- "Failed to parse filter description: ",
305
- getFFMPEGErrorStringFromErrorCode(status));
306
-
307
- status =
308
- avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
309
- TORCH_CHECK(
310
- status >= 0,
311
- "Failed to configure filter graph: ",
312
- getFFMPEGErrorStringFromErrorCode(status));
313
- }
314
-
315
221
  void CpuDeviceInterface::createSwsContext(
316
- const DecodedFrameContext& frameContext,
222
+ const SwsFrameContext& swsFrameContext,
317
223
  const enum AVColorSpace colorspace) {
318
224
  SwsContext* swsContext = sws_getContext(
319
- frameContext.decodedWidth,
320
- frameContext.decodedHeight,
321
- frameContext.decodedFormat,
322
- frameContext.expectedWidth,
323
- frameContext.expectedHeight,
225
+ swsFrameContext.inputWidth,
226
+ swsFrameContext.inputHeight,
227
+ swsFrameContext.inputFormat,
228
+ swsFrameContext.outputWidth,
229
+ swsFrameContext.outputHeight,
324
230
  AV_PIX_FMT_RGB24,
325
231
  SWS_BILINEAR,
326
232
  nullptr,