torchcodec-0.7.0-cp313-cp313-win_amd64.whl → torchcodec-0.8.1-cp313-cp313-win_amd64.whl
This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of torchcodec might be problematic.
- torchcodec/_core/AVIOTensorContext.cpp +23 -16
- torchcodec/_core/AVIOTensorContext.h +2 -1
- torchcodec/_core/BetaCudaDeviceInterface.cpp +718 -0
- torchcodec/_core/BetaCudaDeviceInterface.h +193 -0
- torchcodec/_core/CMakeLists.txt +18 -3
- torchcodec/_core/CUDACommon.cpp +330 -0
- torchcodec/_core/CUDACommon.h +51 -0
- torchcodec/_core/Cache.h +6 -20
- torchcodec/_core/CpuDeviceInterface.cpp +195 -108
- torchcodec/_core/CpuDeviceInterface.h +84 -19
- torchcodec/_core/CudaDeviceInterface.cpp +227 -376
- torchcodec/_core/CudaDeviceInterface.h +38 -6
- torchcodec/_core/DeviceInterface.cpp +57 -19
- torchcodec/_core/DeviceInterface.h +97 -16
- torchcodec/_core/Encoder.cpp +346 -9
- torchcodec/_core/Encoder.h +62 -1
- torchcodec/_core/FFMPEGCommon.cpp +190 -3
- torchcodec/_core/FFMPEGCommon.h +27 -1
- torchcodec/_core/FilterGraph.cpp +30 -22
- torchcodec/_core/FilterGraph.h +15 -1
- torchcodec/_core/Frame.cpp +22 -7
- torchcodec/_core/Frame.h +15 -61
- torchcodec/_core/Metadata.h +2 -2
- torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
- torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
- torchcodec/_core/NVDECCache.cpp +60 -0
- torchcodec/_core/NVDECCache.h +102 -0
- torchcodec/_core/SingleStreamDecoder.cpp +196 -201
- torchcodec/_core/SingleStreamDecoder.h +42 -15
- torchcodec/_core/StreamOptions.h +16 -6
- torchcodec/_core/Transform.cpp +87 -0
- torchcodec/_core/Transform.h +84 -0
- torchcodec/_core/__init__.py +4 -0
- torchcodec/_core/custom_ops.cpp +257 -32
- torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
- torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
- torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
- torchcodec/_core/ops.py +147 -44
- torchcodec/_core/pybind_ops.cpp +22 -59
- torchcodec/_samplers/video_clip_sampler.py +7 -19
- torchcodec/decoders/__init__.py +1 -0
- torchcodec/decoders/_decoder_utils.py +61 -1
- torchcodec/decoders/_video_decoder.py +46 -20
- torchcodec/libtorchcodec_core4.dll +0 -0
- torchcodec/libtorchcodec_core5.dll +0 -0
- torchcodec/libtorchcodec_core6.dll +0 -0
- torchcodec/libtorchcodec_core7.dll +0 -0
- torchcodec/libtorchcodec_core8.dll +0 -0
- torchcodec/libtorchcodec_custom_ops4.dll +0 -0
- torchcodec/libtorchcodec_custom_ops5.dll +0 -0
- torchcodec/libtorchcodec_custom_ops6.dll +0 -0
- torchcodec/libtorchcodec_custom_ops7.dll +0 -0
- torchcodec/libtorchcodec_custom_ops8.dll +0 -0
- torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
- torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
- torchcodec/samplers/_time_based.py +8 -0
- torchcodec/version.py +1 -1
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/METADATA +29 -16
- torchcodec-0.8.1.dist-info/RECORD +82 -0
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/WHEEL +1 -1
- torchcodec-0.7.0.dist-info/RECORD +0 -67
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/top_level.txt +0 -0
torchcodec/_core/BetaCudaDeviceInterface.h
ADDED
@@ -0,0 +1,193 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+// BETA CUDA device interface that provides direct control over NVDEC
+// while keeping FFmpeg for demuxing. A lot of the logic, particularly the use
+// of a cache for the decoders, is inspired by DALI's implementation which is
+// APACHE 2.0:
+// https://github.com/NVIDIA/DALI/blob/c7539676a24a8e9e99a6e8665e277363c5445259/dali/operators/video/frames_decoder_gpu.cc#L1
+//
+// NVDEC / NVCUVID docs:
+// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#using-nvidia-video-decoder-nvdecode-api
+
+#pragma once
+
+#include "src/torchcodec/_core/CUDACommon.h"
+#include "src/torchcodec/_core/Cache.h"
+#include "src/torchcodec/_core/DeviceInterface.h"
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/NVDECCache.h"
+
+#include <map>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <unordered_map>
+#include <vector>
+
+#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
+#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
+
+namespace facebook::torchcodec {
+
+class BetaCudaDeviceInterface : public DeviceInterface {
+ public:
+  explicit BetaCudaDeviceInterface(const torch::Device& device);
+  virtual ~BetaCudaDeviceInterface();
+
+  void initialize(
+      const AVStream* avStream,
+      const UniqueDecodingAVFormatContext& avFormatCtx,
+      const SharedAVCodecContext& codecContext) override;
+
+  void convertAVFrameToFrameOutput(
+      UniqueAVFrame& avFrame,
+      FrameOutput& frameOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor =
+          std::nullopt) override;
+
+  int sendPacket(ReferenceAVPacket& packet) override;
+  int sendEOFPacket() override;
+  int receiveFrame(UniqueAVFrame& avFrame) override;
+  void flush() override;
+
+  // NVDEC callback functions (must be public for C callbacks)
+  int streamPropertyChange(CUVIDEOFORMAT* videoFormat);
+  int frameReadyForDecoding(CUVIDPICPARAMS* picParams);
+  int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo);
+
+  std::string getDetails() override;
+
+ private:
+  int sendCuvidPacket(CUVIDSOURCEDATAPACKET& cuvidPacket);
+
+  void initializeBSF(
+      const AVCodecParameters* codecPar,
+      const UniqueDecodingAVFormatContext& avFormatCtx);
+  // Apply bitstream filter, returns filtered packet or original if no filter
+  // needed.
+  ReferenceAVPacket& applyBSF(
+      ReferenceAVPacket& packet,
+      ReferenceAVPacket& filteredPacket);
+
+  CUdeviceptr previouslyMappedFrame_ = 0;
+  void unmapPreviousFrame();
+
+  UniqueAVFrame convertCudaFrameToAVFrame(
+      CUdeviceptr framePtr,
+      unsigned int pitch,
+      const CUVIDPARSERDISPINFO& dispInfo);
+
+  CUvideoparser videoParser_ = nullptr;
+  UniqueCUvideodecoder decoder_;
+  CUVIDEOFORMAT videoFormat_ = {};
+
+  std::queue<CUVIDPARSERDISPINFO> readyFrames_;
+
+  bool eofSent_ = false;
+
+  AVRational timeBase_ = {0, 1};
+  AVRational frameRateAvgFromFFmpeg_ = {0, 1};
+
+  UniqueAVBSFContext bitstreamFilter_;
+
+  // NPP context for color conversion
+  UniqueNppContext nppCtx_;
+
+  std::unique_ptr<DeviceInterface> cpuFallback_;
+  bool nvcuvidAvailable_ = false;
+};
+
+} // namespace facebook::torchcodec
+
+/* clang-format off */
+// Note: [General design, sendPacket, receiveFrame, frame ordering and NVCUVID callbacks]
+//
+// At a high level, this decoding interface mimics the FFmpeg send/receive
+// architecture:
+// - sendPacket(AVPacket) sends an AVPacket from the FFmpeg demuxer to the
+//   NVCUVID parser.
+// - receiveFrame(AVFrame) is a non-blocking call:
+//   - if a frame is ready **in display order**, it must return it. By display
+//     order, we mean that receiveFrame() must return frames with increasing pts
+//     values when called successively.
+//   - if no frame is ready, it must return AVERROR(EAGAIN) to indicate the
+//     caller should send more packets.
+//
+// The rest of this note assumes you have a reasonable level of familiarity with
+// the sendPacket/receiveFrame calling pattern. If you don't, look up the core
+// decoding loop in SingleVideoDecoder.
+//
+// The frame re-ordering problem:
+// Depending on the codec and on the encoding parameters, a packet from a video
+// stream may contain exactly one frame, more than one frame, or a fraction of a
+// frame. And, there may be non-linear frame dependencies because of B-frames,
+// which need both past *and* future frames to be decoded. Consider the
+// following stream, with frames presented in display order: I0 B1 P2 B3 P4 ...
+// - I0 is an I-frame (also key frame, can be decoded independently)
+// - B1 is a B-frame (bi-directional) which needs both I0 and P2 to be decoded
+// - P2 is a P-frame (predicted frame) which only needs I0 to be decoded.
+//
+// Because B1 needs both I0 and P2 to be properly decoded, the decode order
+// (packet order), defined by the encoder, must be: I0 P2 B1 P4 B3 ... which is
+// different from the display order.
+//
+// sendPacket(AVPacket)'s job is just to pass down the packet to the NVCUVID
+// parser by calling cuvidParseVideoData(packet). When
+// cuvidParseVideoData(packet) is called, it may trigger callbacks,
+// particularly:
+// - streamPropertyChange(videoFormat): triggered once at the start of the
+//   stream, and possibly later if the stream properties change (e.g.
+//   resolution).
+// - frameReadyForDecoding(picParams): triggered **in decode order** when the
+//   parser has accumulated enough data to decode a frame. We send that frame to
+//   the NVDEC hardware for **async** decoding.
+// - frameReadyInDisplayOrder(dispInfo): triggered **in display order** when a
+//   frame is ready to be "displayed" (returned). At that point, the parser also
+//   gives us the pts of that frame. We store (a reference to) that frame in a
+//   FIFO queue: readyFrames_.
+//
+// When receiveFrame(AVFrame) is called, if readyFrames_ is not empty, we pop
+// the front of the queue, which is the next frame in display order, and map it
+// to an AVFrame by calling cuvidMapVideoFrame(). If readyFrames_ is empty we
+// return EAGAIN to indicate the caller should send more packets.
+//
+// There is potentially a small inefficiency due to the callback design: in
+// order for us to know that a frame is ready in display order, we need the
+// frameReadyInDisplayOrder callback to be triggered. This can only happen
+// within cuvidParseVideoData(packet) in sendPacket(). This means there may be
+// the following sequence of calls:
+//
+// sendPacket(relevantAVPacket)
+//   cuvidParseVideoData(relevantAVPacket)
+//     frameReadyForDecoding()
+//       cuvidDecodePicture()      Send frame to NVDEC for async decoding
+//
+// receiveFrame() -> EAGAIN        Frame is potentially already decoded
+//                                 and could be returned, but we don't
+//                                 know because frameReadyInDisplayOrder
+//                                 hasn't been triggered yet. We'll only
+//                                 know after sending another,
+//                                 potentially irrelevant packet.
+//
+// sendPacket(irrelevantAVPacket)
+//   cuvidParseVideoData(irrelevantAVPacket)
+//     frameReadyInDisplayOrder()  Only now do we know that our target
+//                                 frame is ready.
+//
+// receiveFrame()                  return target frame
+//
+// How much this matters in practice is unclear, but probably negligible in
+// general. Particularly when frames are decoded consecutively anyway, the
+// "irrelevantPacket" is actually relevant for a future target frame.
+//
+// Note that the alternative is to *not* rely on the frameReadyInDisplayOrder
+// callback. It's technically possible, but it would mean we now have to solve
+// two hard, *codec-dependent* problems that the callback was solving for us:
+// - we have to guess the frame's pts ourselves
+// - we have to re-order the frames ourselves to preserve display order.
+//
+/* clang-format on */
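The send/receive contract described in the note above is the same one FFmpeg's own decoders expose. As a hedged illustration (not part of the diff), here is what the caller-side loop looks like when written against FFmpeg's public API; torchcodec's SingleStreamDecoder drives sendPacket()/receiveFrame() with the same shape, with readyFrames_ doing the display-order work behind the scenes:

```cpp
// Minimal sketch of the send/receive decode loop, using FFmpeg's public API
// for self-containedness. All calls below are real libavcodec/libavformat
// functions; the loop structure mirrors the sendPacket()/receiveFrame()
// contract in the header above.
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

// Pops every frame currently available, in display order. Returns
// AVERROR(EAGAIN) when the caller must send more packets, AVERROR_EOF at end.
static int drainFrames(AVCodecContext* codecCtx, AVFrame* frame) {
  int ret;
  while ((ret = avcodec_receive_frame(codecCtx, frame)) == 0) {
    // frame->pts increases across successful calls: frames come out in
    // display order even though packets arrive in decode order.
    av_frame_unref(frame);
  }
  return ret;
}

static void decodeLoop(
    AVFormatContext* fmtCtx,
    AVCodecContext* codecCtx,
    int streamIndex) {
  AVPacket* packet = av_packet_alloc();
  AVFrame* frame = av_frame_alloc();
  while (av_read_frame(fmtCtx, packet) >= 0) {
    if (packet->stream_index == streamIndex) {
      avcodec_send_packet(codecCtx, packet); // like sendPacket()
      drainFrames(codecCtx, frame); // like receiveFrame() until EAGAIN
    }
    av_packet_unref(packet);
  }
  avcodec_send_packet(codecCtx, nullptr); // like sendEOFPacket()
  drainFrames(codecCtx, frame); // flush remaining re-ordered frames
  av_frame_free(&frame);
  av_packet_free(&packet);
}
```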
torchcodec/_core/CMakeLists.txt
CHANGED
@@ -95,10 +95,11 @@ function(make_torchcodec_libraries
       SingleStreamDecoder.cpp
       Encoder.cpp
       ValidationUtils.cpp
+      Transform.cpp
   )
 
   if(ENABLE_CUDA)
-    list(APPEND core_sources CudaDeviceInterface.cpp)
+    list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp)
   endif()
 
   set(core_library_dependencies
@@ -175,12 +176,23 @@ function(make_torchcodec_libraries
   # stray initialization of py::objects. The rest of the object code must
   # match. See:
   # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes
-
+  # We have to do this for both pybind_ops and custom_ops because they include
+  # some of the same headers.
+  #
+  # Note that this is Linux only. It's not necessary on Windows, and on Mac
+  # hiding visibility can actually break dynamic casts across shared libraries
+  # because the type infos don't get exported.
+  if(LINUX)
    target_compile_options(
      ${pybind_ops_library_name}
      PUBLIC
      "-fvisibility=hidden"
    )
+    target_compile_options(
+      ${custom_ops_library_name}
+      PUBLIC
+      "-fvisibility=hidden"
+    )
   endif()
 
   # The value we use here must match the value we return from
@@ -233,11 +245,12 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3})
     you still need a different FFmpeg to be installed for run time!"
  )
 
-  # This will expose the ffmpeg4, ffmpeg5, ffmpeg6, and
+  # This will expose the ffmpeg4, ffmpeg5, ffmpeg6, ffmpeg7, and ffmpeg8 targets
  include(
    ${CMAKE_CURRENT_SOURCE_DIR}/fetch_and_expose_non_gpl_ffmpeg_libs.cmake
  )
 
+  make_torchcodec_libraries(8 ffmpeg8)
  make_torchcodec_libraries(7 ffmpeg7)
  make_torchcodec_libraries(6 ffmpeg6)
  make_torchcodec_libraries(4 ffmpeg4)
@@ -274,6 +287,8 @@ else()
    set(ffmpeg_major_version "6")
  elseif (${libavcodec_major_version} STREQUAL "61")
    set(ffmpeg_major_version "7")
+  elseif (${libavcodec_major_version} STREQUAL "62")
+    set(ffmpeg_major_version "8")
  else()
    message(
      FATAL_ERROR
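The last hunk maps the detected libavcodec major version to an FFmpeg major version (61 → FFmpeg 7, 62 → FFmpeg 8). For reference, the same check can be done at run time; a standalone hedged sketch using only the real avcodec_version() API:

```cpp
// Standalone sketch: query the libavcodec major version at run time and map
// it to the FFmpeg major version, mirroring the CMake branch above.
#include <cstdio>
extern "C" {
#include <libavcodec/avcodec.h>
}

int main() {
  // avcodec_version() packs the version as (major << 16) | (minor << 8) | micro.
  unsigned major = avcodec_version() >> 16;
  std::printf("libavcodec major version: %u\n", major);
  if (major == 61) {
    std::printf("FFmpeg 7\n");
  } else if (major == 62) {
    std::printf("FFmpeg 8\n");
  }
  return 0;
}
```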
torchcodec/_core/CUDACommon.cpp
ADDED
@@ -0,0 +1,330 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "src/torchcodec/_core/CUDACommon.h"
+#include "src/torchcodec/_core/Cache.h" // for PerGpuCache
+
+namespace facebook::torchcodec {
+
+namespace {
+
+// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
+// Set to a positive number to have a cache of that size.
+const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
+
+PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
+    MAX_CUDA_GPUS,
+    MAX_CONTEXTS_PER_GPU_IN_CACHE);
+
+} // namespace
+
+void initializeCudaContextWithPytorch(const torch::Device& device) {
+  // It is important for pytorch itself to create the cuda context. If ffmpeg
+  // creates the context it may not be compatible with pytorch.
+  // This is a dummy tensor to initialize the cuda context.
+  torch::Tensor dummyTensorForCudaInitialization = torch::zeros(
+      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
+}
+
+/* clang-format off */
+// Note: [YUV -> RGB Color Conversion, color space and color range]
+//
+// The frames we get from the decoder (FFmpeg decoder, or NVCUVID) are in YUV
+// format. We need to convert them to RGB. This note attempts to describe this
+// process. There may be some inaccuracies and approximations that experts will
+// notice, but our goal is only to provide a good enough understanding of the
+// process for torchcodec developers to implement and maintain it.
+// On CPU, filtergraph and swscale handle everything for us. With CUDA, we have
+// to do a lot of the heavy lifting ourselves.
+//
+// Color space and color range
+// ---------------------------
+// Two main characteristics of a frame will affect the conversion process:
+// 1. Color space: This basically defines what YUV values correspond to which
+//    physical wavelength. No need to go into details here, the point is that
+//    videos can come in different color spaces, the most common ones being
+//    BT.601 and BT.709, but there are others.
+//    In FFmpeg this is represented with AVColorSpace:
+//    https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#aff71a069509a1ad3ff54d53a1c894c85
+// 2. Color range: This defines the range of YUV values. There is:
+//    - full range, also called PC range: AVCOL_RANGE_JPEG
+//    - and the "limited" range, also called studio or TV range: AVCOL_RANGE_MPEG
+//    https://ffmpeg.org/doxygen/4.0/pixfmt_8h.html#a3da0bf691418bc22c4bcbe6583ad589a
+//
+// Color space and color range are independent concepts, so we can have a BT.709
+// with full range, and another one with limited range. Same for BT.601.
+//
+// In the first version of this note we'll focus on the full color range. It
+// will later be updated to account for the limited range.
+//
+// Color conversion matrix
+// -----------------------
+// YUV -> RGB conversion is defined as the reverse process of the RGB -> YUV,
+// so this is where we'll start.
+// At the core of a RGB -> YUV conversion are the "luma coefficients", which are
+// specific to a given color space and defined by the color space standard. In
+// FFmpeg they can be found here:
+// https://github.com/FFmpeg/FFmpeg/blob/7d606ef0ccf2946a4a21ab1ec23486cadc21864b/libavutil/csp.c#L46-L56
+//
+// For example, the BT.709 coefficients are: kr=0.2126, kg=0.7152, kb=0.0722
+// Coefficients must sum to 1.
+//
+// Conventionally Y is in [0, 1] range, and U and V are in [-0.5, 0.5] range
+// (that's mathematically, in practice they are represented in integer range).
+// The conversion is defined as:
+// https://en.wikipedia.org/wiki/YCbCr#R'G'B'_to_Y%E2%80%B2PbPr
+// Y = kr*R + kg*G + kb*B
+// U = (B - Y) * 0.5 / (1 - kb) = (B - Y) / u_scale where u_scale = 2 * (1 - kb)
+// V = (R - Y) * 0.5 / (1 - kr) = (R - Y) / v_scale where v_scale = 2 * (1 - kr)
+//
+// Putting all this into matrix form, we get:
+// [Y]   [kr              kg           kb            ] [R]
+// [U] = [-kr/u_scale     -kg/u_scale  (1-kb)/u_scale] [G]
+// [V]   [(1-kr)/v_scale  -kg/v_scale  -kb/v_scale   ] [B]
+//
+//
+// Now, to convert YUV to RGB, we just need to invert this matrix:
+// ```py
+// import torch
+// kr, kg, kb = 0.2126, 0.7152, 0.0722  # BT.709 luma coefficients
+// u_scale = 2 * (1 - kb)
+// v_scale = 2 * (1 - kr)
+//
+// rgb_to_yuv = torch.tensor([
+//     [kr, kg, kb],
+//     [-kr/u_scale, -kg/u_scale, (1-kb)/u_scale],
+//     [(1-kr)/v_scale, -kg/v_scale, -kb/v_scale]
+// ])
+//
+// yuv_to_rgb_full = torch.linalg.inv(rgb_to_yuv)
+// print("YUV->RGB matrix (Full Range):")
+// print(yuv_to_rgb_full)
+// ```
+// And we get:
+// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00],
+//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01],
+//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09]])
+//
+// Which matches https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
+//
+// Color conversion in NPP
+// -----------------------
+// https://docs.nvidia.com/cuda/npp/image_color_conversion.html.
+//
+// NPP provides different ways to convert YUV to RGB:
+// - pre-defined color conversion functions like
+//   nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx and nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx
+//   which are for BT.709 limited and full range, respectively.
+// - generic color conversion functions that accept a custom color conversion
+//   matrix, called ColorTwist, like nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx
+//
+// We use the pre-defined functions or the color twist functions depending on
+// which one we find to be closer to the CPU results.
+//
+// The color twist functionality is *partially* described in a section named
+// "YUVToRGBColorTwist". Importantly:
+//
+// - The `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` function takes the YUV data
+//   and the color-conversion matrix as input. The function itself and the
+//   matrix assume different ranges for YUV values:
+// - The **matrix coefficients** must assume that Y is in [0, 1] and U,V are in
+//   [-0.5, 0.5]. That's how we defined our matrix above.
+// - The function `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` however expects all
+//   of the input Y, U, V to be in [0, 255]. That's how the data comes out of
+//   the decoder.
+// - But *internally*, `nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx` needs U and V to
+//   be centered around 0, i.e. in [-128, 127]. So we need to apply a -128
+//   offset to U and V. Y doesn't need to be offset. The offset can be applied
+//   by adding a 4th column to the matrix.
+//
+//
+// So our conversion matrix becomes the following, with new offset column:
+// tensor([[ 1.0000e+00, -3.3142e-09,  1.5748e+00,  0  ]
+//         [ 1.0000e+00, -1.8732e-01, -4.6812e-01, -128]
+//         [ 1.0000e+00,  1.8556e+00,  4.6231e-09, -128]])
+//
+// And that's what we need to pass for BT.709, full range.
+/* clang-format on */
+
+// BT.709 full range color conversion matrix for YUV to RGB conversion.
+// See Note [YUV -> RGB Color Conversion, color space and color range]
+const Npp32f bt709FullRangeColorTwist[3][4] = {
+    {1.0f, 0.0f, 1.5748f, 0.0f},
+    {1.0f, -0.187324273f, -0.468124273f, -128.0f},
+    {1.0f, 1.8556f, 0.0f, -128.0f}};
+
+torch::Tensor convertNV12FrameToRGB(
+    UniqueAVFrame& avFrame,
+    const torch::Device& device,
+    const UniqueNppContext& nppCtx,
+    at::cuda::CUDAStream nvdecStream,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto frameDims = FrameDims(avFrame->height, avFrame->width);
+  torch::Tensor dst;
+  if (preAllocatedOutputTensor.has_value()) {
+    dst = preAllocatedOutputTensor.value();
+  } else {
+    dst = allocateEmptyHWCTensor(frameDims, device);
+  }
+
+  // We need to make sure NVDEC has finished decoding a frame before
+  // color-converting it with NPP.
+  // So we make the NPP stream wait for NVDEC to finish.
+  at::cuda::CUDAStream nppStream =
+      at::cuda::getCurrentCUDAStream(device.index());
+  at::cuda::CUDAEvent nvdecDoneEvent;
+  nvdecDoneEvent.record(nvdecStream);
+  nvdecDoneEvent.block(nppStream);
+
+  nppCtx->hStream = nppStream.stream();
+  cudaError_t err = cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaStreamGetFlags failed: ",
+      cudaGetErrorString(err));
+
+  NppiSize oSizeROI = {frameDims.width, frameDims.height};
+  Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]};
+
+  NppStatus status;
+
+  // For background, see
+  // Note [YUV -> RGB Color Conversion, color space and color range]
+  if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
+    if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) {
+      // NPP provides a pre-defined color conversion function for BT.709 full
+      // range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it's not closely
+      // matching the results we have on CPU. So we're using a custom color
+      // conversion matrix, which provides more accurate results. See the note
+      // mentioned above for details, and headaches.
+
+      int srcStep[2] = {avFrame->linesize[0], avFrame->linesize[1]};
+
+      status = nppiNV12ToRGB_8u_ColorTwist32f_P2C3R_Ctx(
+          yuvData,
+          srcStep,
+          static_cast<Npp8u*>(dst.data_ptr()),
+          dst.stride(0),
+          oSizeROI,
+          bt709FullRangeColorTwist,
+          *nppCtx);
+    } else {
+      // If not full range, we assume studio limited range.
+      // The color conversion matrix for BT.709 limited range should be:
+      // static const Npp32f bt709LimitedRangeColorTwist[3][4] = {
+      //   {1.16438356f, 0.0f, 1.79274107f, -16.0f},
+      //   {1.16438356f, -0.213248614f, -0.5329093290f, -128.0f},
+      //   {1.16438356f, 2.11240179f, 0.0f, -128.0f}
+      // };
+      // We get very close results to CPU with that, but using the pre-defined
+      // nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx seems to be even more accurate.
+      status = nppiNV12ToRGB_709CSC_8u_P2C3R_Ctx(
+          yuvData,
+          avFrame->linesize[0],
+          static_cast<Npp8u*>(dst.data_ptr()),
+          dst.stride(0),
+          oSizeROI,
+          *nppCtx);
+    }
+  } else {
+    // TODO we're assuming BT.601 color space (and probably limited range) by
+    // calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range,
+    // and other color spaces like BT.2020.
+    status = nppiNV12ToRGB_8u_P2C3R_Ctx(
+        yuvData,
+        avFrame->linesize[0],
+        static_cast<Npp8u*>(dst.data_ptr()),
+        dst.stride(0),
+        oSizeROI,
+        *nppCtx);
+  }
+  TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
+
+  return dst;
+}
+
+UniqueNppContext getNppStreamContext(const torch::Device& device) {
+  int deviceIndex = getDeviceIndex(device);
+
+  UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device);
+  if (nppCtx) {
+    return nppCtx;
+  }
+
+  // From 12.9, NPP recommends using a user-created NppStreamContext and using
+  // the `_Ctx()` calls:
+  // https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#npp-release-12-9-update-1
+  // And the nppGetStreamContext() helper is deprecated. We are explicitly
+  // supposed to create the NppStreamContext manually from the CUDA device
+  // properties:
+  // https://github.com/NVIDIA/CUDALibrarySamples/blob/d97803a40fab83c058bb3d68b6c38bd6eebfff43/NPP/README.md?plain=1#L54-L72
+
+  nppCtx = std::make_unique<NppStreamContext>();
+  cudaDeviceProp prop{};
+  cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "cudaGetDeviceProperties failed: ",
+      cudaGetErrorString(err));
+
+  nppCtx->nCudaDeviceId = deviceIndex;
+  nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
+  nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
+  nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
+  nppCtx->nSharedMemPerBlock = prop.sharedMemPerBlock;
+  nppCtx->nCudaDevAttrComputeCapabilityMajor = prop.major;
+  nppCtx->nCudaDevAttrComputeCapabilityMinor = prop.minor;
+
+  return nppCtx;
+}
+
+void returnNppStreamContextToCache(
+    const torch::Device& device,
+    UniqueNppContext nppCtx) {
+  if (nppCtx) {
+    g_cached_npp_ctxs.addIfCacheHasCapacity(device, std::move(nppCtx));
+  }
+}
+
+void validatePreAllocatedTensorShape(
+    const std::optional<torch::Tensor>& preAllocatedOutputTensor,
+    const UniqueAVFrame& avFrame) {
+  // Note that CUDA does not yet support transforms, so the only possible
+  // frame dimensions are the raw decoded frame's dimensions.
+  auto frameDims = FrameDims(avFrame->height, avFrame->width);
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == frameDims.height) &&
+            (shape[1] == frameDims.width) && (shape[2] == 3),
+        "Expected tensor of shape ",
+        frameDims.height,
+        "x",
+        frameDims.width,
+        "x3, got ",
+        shape);
+  }
+}
+
+int getDeviceIndex(const torch::Device& device) {
+  // PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA
+  // libraries use int. So we use int, too.
+  int deviceIndex = static_cast<int>(device.index());
+  TORCH_CHECK(
+      deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS,
+      "Invalid device index = ",
+      deviceIndex);
+
+  if (deviceIndex == -1) {
+    TORCH_CHECK(
+        cudaGetDevice(&deviceIndex) == cudaSuccess,
+        "Failed to get current CUDA device.");
+  }
+  return deviceIndex;
+}
+
+} // namespace facebook::torchcodec
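The matrix derivation in the color-conversion note above can be checked numerically without torch: build the forward RGB→YUV map from the BT.709 luma coefficients, apply the inverse coefficients used in bt709FullRangeColorTwist (minus the -128 offsets, since we stay in the mathematical [0,1]/[-0.5,0.5] ranges), and confirm the round trip. A standalone hedged sketch:

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>

// BT.709 luma coefficients, as in the note above.
constexpr double kr = 0.2126, kg = 0.7152, kb = 0.0722;
constexpr double uScale = 2.0 * (1.0 - kb); // 1.8556
constexpr double vScale = 2.0 * (1.0 - kr); // 1.5748

// Forward RGB -> YUV (full range): Y in [0,1], U/V in [-0.5,0.5].
std::array<double, 3> rgbToYuv(double r, double g, double b) {
  double y = kr * r + kg * g + kb * b;
  return {y, (b - y) / uScale, (r - y) / vScale};
}

// Inverse YUV -> RGB, same coefficients as bt709FullRangeColorTwist above:
// kb*uScale/kg == 0.187324..., kr*vScale/kg == 0.468124...
std::array<double, 3> yuvToRgb(double y, double u, double v) {
  return {
      y + vScale * v, // 1.5748 * V
      y - (kb * uScale / kg) * u - (kr * vScale / kg) * v,
      y + uScale * u}; // 1.8556 * U
}

int main() {
  const std::array<double, 3> samples[] = {{1, 0, 0}, {0.2, 0.5, 0.8}, {1, 1, 1}};
  for (const auto& rgb : samples) {
    auto [y, u, v] = rgbToYuv(rgb[0], rgb[1], rgb[2]);
    auto [r, g, b] = yuvToRgb(y, u, v);
    // The inverse is exact up to floating-point rounding.
    assert(std::abs(rgb[0] - r) < 1e-9);
    assert(std::abs(rgb[1] - g) < 1e-9);
    assert(std::abs(rgb[2] - b) < 1e-9);
  }
  std::printf("BT.709 full-range YUV<->RGB round trip OK\n");
  return 0;
}
```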
torchcodec/_core/CUDACommon.h
ADDED
@@ -0,0 +1,51 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <ATen/cuda/CUDAEvent.h>
+#include <c10/cuda/CUDAStream.h>
+#include <npp.h>
+#include <torch/types.h>
+
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/Frame.h"
+
+extern "C" {
+#include <libavutil/hwcontext_cuda.h>
+#include <libavutil/pixdesc.h>
+}
+
+namespace facebook::torchcodec {
+
+// Pytorch can only handle up to 128 GPUs.
+// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
+constexpr int MAX_CUDA_GPUS = 128;
+
+void initializeCudaContextWithPytorch(const torch::Device& device);
+
+// Unique pointer type for NPP stream context
+using UniqueNppContext = std::unique_ptr<NppStreamContext>;
+
+torch::Tensor convertNV12FrameToRGB(
+    UniqueAVFrame& avFrame,
+    const torch::Device& device,
+    const UniqueNppContext& nppCtx,
+    at::cuda::CUDAStream nvdecStream,
+    std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+UniqueNppContext getNppStreamContext(const torch::Device& device);
+void returnNppStreamContextToCache(
+    const torch::Device& device,
+    UniqueNppContext nppCtx);
+
+void validatePreAllocatedTensorShape(
+    const std::optional<torch::Tensor>& preAllocatedOutputTensor,
+    const UniqueAVFrame& avFrame);
+
+int getDeviceIndex(const torch::Device& device);
+
+} // namespace facebook::torchcodec
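Taken together, the declarations above form a small lifecycle: initialize the CUDA context through PyTorch, check an NppStreamContext out of the per-GPU cache, run the NV12→RGB conversion, and return the context. A hedged sketch of that flow (the wrapper function and its name are illustrative, not from the diff):

```cpp
#include "src/torchcodec/_core/CUDACommon.h"

namespace facebook::torchcodec {

// Hypothetical helper showing the intended call sequence of the API above.
torch::Tensor convertOneFrame(
    UniqueAVFrame& nv12Frame, // NV12 frame produced by NVDEC
    const torch::Device& device,
    at::cuda::CUDAStream nvdecStream) {
  // Let PyTorch create/own the CUDA context before any NPP work.
  initializeCudaContextWithPytorch(device);

  // Contexts are cached per GPU; creating an NppStreamContext from device
  // properties is not free, hence the check-out / return protocol.
  UniqueNppContext nppCtx = getNppStreamContext(device);

  torch::Tensor rgb =
      convertNV12FrameToRGB(nv12Frame, device, nppCtx, nvdecStream);

  returnNppStreamContextToCache(device, std::move(nppCtx));
  return rgb;
}

} // namespace facebook::torchcodec
```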
torchcodec/_core/Cache.h
CHANGED
@@ -95,30 +95,16 @@ class PerGpuCache {
   std::vector<std::unique_ptr<Cache<T, D>>> cache_;
 };
 
-//
-//
-//
-
-// annoying for such a small amount of code, so we just inline it. If this file
-// grows, and there are more such functions, we should break them out into a
-// .cpp file.
-inline torch::DeviceIndex getNonNegativeDeviceIndex(
-    const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = device.index();
-  // For single GPU machines libtorch returns -1 for the device index. So for
-  // that case we set the device index to 0. That's used in per-gpu cache
-  // implementation and during initialization of CUDA and FFmpeg contexts
-  // which require non negative indices.
-  deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
-  TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
-  return deviceIndex;
-}
+// Forward declaration of getDeviceIndex which exists in CUDACommon.h
+// This avoids circular dependency between Cache.h and CUDACommon.cpp which also
+// needs to include Cache.h
+int getDeviceIndex(const torch::Device& device);
 
 template <typename T, typename D>
 bool PerGpuCache<T, D>::addIfCacheHasCapacity(
     const torch::Device& device,
     element_type&& obj) {
-  torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
+  int deviceIndex = getDeviceIndex(device);
   TORCH_CHECK(
       static_cast<size_t>(deviceIndex) < cache_.size(),
       "Device index out of range");
@@ -128,7 +114,7 @@ bool PerGpuCache<T, D>::addIfCacheHasCapacity(
 template <typename T, typename D>
 typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
     const torch::Device& device) {
-  torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
+  int deviceIndex = getDeviceIndex(device);
   TORCH_CHECK(
       static_cast<size_t>(deviceIndex) < cache_.size(),
       "Device index out of range");
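The added comment in Cache.h describes a standard trick: declaring a free function in a header instead of including the header that provides it, to break an include cycle. A minimal self-contained illustration of the pattern (all names here are invented for the example, not from torchcodec):

```cpp
// A compact, hypothetical version of the pattern Cache.h uses above: the
// template needs a free function at instantiation time, so a declaration is
// enough; the definition lives in another translation unit.
#include <cassert>

namespace example {

// Forward declaration only; no #include of the "common" header needed here.
int normalizeIndex(int rawIndex);

template <typename T, int N>
class PerSlotCache {
 public:
  void put(int rawIndex, T value) {
    slots_[normalizeIndex(rawIndex)] = value; // resolved at link time
  }
  T get(int rawIndex) const {
    return slots_[normalizeIndex(rawIndex)];
  }

 private:
  T slots_[N] = {};
};

// In real code this would live in the other translation unit (the analogue
// of CUDACommon.cpp); -1 means "current device", as in getDeviceIndex above.
int normalizeIndex(int rawIndex) {
  return rawIndex < 0 ? 0 : rawIndex;
}

} // namespace example

int main() {
  example::PerSlotCache<int, 4> cache;
  cache.put(-1, 42); // -1 normalized to slot 0
  assert(cache.get(0) == 42);
  return 0;
}
```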