torchcodec 0.10.0__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. torchcodec/__init__.py +27 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +130 -0
  7. torchcodec/_core/AVIOTensorContext.h +44 -0
  8. torchcodec/_core/BetaCudaDeviceInterface.cpp +849 -0
  9. torchcodec/_core/BetaCudaDeviceInterface.h +196 -0
  10. torchcodec/_core/CMakeLists.txt +295 -0
  11. torchcodec/_core/CUDACommon.cpp +330 -0
  12. torchcodec/_core/CUDACommon.h +51 -0
  13. torchcodec/_core/Cache.h +124 -0
  14. torchcodec/_core/CpuDeviceInterface.cpp +509 -0
  15. torchcodec/_core/CpuDeviceInterface.h +141 -0
  16. torchcodec/_core/CudaDeviceInterface.cpp +602 -0
  17. torchcodec/_core/CudaDeviceInterface.h +79 -0
  18. torchcodec/_core/DeviceInterface.cpp +117 -0
  19. torchcodec/_core/DeviceInterface.h +191 -0
  20. torchcodec/_core/Encoder.cpp +1054 -0
  21. torchcodec/_core/Encoder.h +192 -0
  22. torchcodec/_core/FFMPEGCommon.cpp +684 -0
  23. torchcodec/_core/FFMPEGCommon.h +314 -0
  24. torchcodec/_core/FilterGraph.cpp +159 -0
  25. torchcodec/_core/FilterGraph.h +59 -0
  26. torchcodec/_core/Frame.cpp +47 -0
  27. torchcodec/_core/Frame.h +72 -0
  28. torchcodec/_core/Metadata.cpp +124 -0
  29. torchcodec/_core/Metadata.h +92 -0
  30. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  31. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  32. torchcodec/_core/NVDECCache.cpp +60 -0
  33. torchcodec/_core/NVDECCache.h +102 -0
  34. torchcodec/_core/SingleStreamDecoder.cpp +1586 -0
  35. torchcodec/_core/SingleStreamDecoder.h +391 -0
  36. torchcodec/_core/StreamOptions.h +70 -0
  37. torchcodec/_core/Transform.cpp +128 -0
  38. torchcodec/_core/Transform.h +86 -0
  39. torchcodec/_core/ValidationUtils.cpp +35 -0
  40. torchcodec/_core/ValidationUtils.h +21 -0
  41. torchcodec/_core/__init__.py +46 -0
  42. torchcodec/_core/_metadata.py +262 -0
  43. torchcodec/_core/custom_ops.cpp +1090 -0
  44. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +169 -0
  45. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  46. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  47. torchcodec/_core/ops.py +605 -0
  48. torchcodec/_core/pybind_ops.cpp +50 -0
  49. torchcodec/_frame.py +146 -0
  50. torchcodec/_internally_replaced_utils.py +68 -0
  51. torchcodec/_samplers/__init__.py +7 -0
  52. torchcodec/_samplers/video_clip_sampler.py +419 -0
  53. torchcodec/decoders/__init__.py +12 -0
  54. torchcodec/decoders/_audio_decoder.py +185 -0
  55. torchcodec/decoders/_decoder_utils.py +113 -0
  56. torchcodec/decoders/_video_decoder.py +601 -0
  57. torchcodec/encoders/__init__.py +2 -0
  58. torchcodec/encoders/_audio_encoder.py +149 -0
  59. torchcodec/encoders/_video_encoder.py +196 -0
  60. torchcodec/libtorchcodec_core4.so +0 -0
  61. torchcodec/libtorchcodec_core5.so +0 -0
  62. torchcodec/libtorchcodec_core6.so +0 -0
  63. torchcodec/libtorchcodec_core7.so +0 -0
  64. torchcodec/libtorchcodec_core8.so +0 -0
  65. torchcodec/libtorchcodec_custom_ops4.so +0 -0
  66. torchcodec/libtorchcodec_custom_ops5.so +0 -0
  67. torchcodec/libtorchcodec_custom_ops6.so +0 -0
  68. torchcodec/libtorchcodec_custom_ops7.so +0 -0
  69. torchcodec/libtorchcodec_custom_ops8.so +0 -0
  70. torchcodec/libtorchcodec_pybind_ops4.so +0 -0
  71. torchcodec/libtorchcodec_pybind_ops5.so +0 -0
  72. torchcodec/libtorchcodec_pybind_ops6.so +0 -0
  73. torchcodec/libtorchcodec_pybind_ops7.so +0 -0
  74. torchcodec/libtorchcodec_pybind_ops8.so +0 -0
  75. torchcodec/samplers/__init__.py +2 -0
  76. torchcodec/samplers/_common.py +84 -0
  77. torchcodec/samplers/_index_based.py +287 -0
  78. torchcodec/samplers/_time_based.py +358 -0
  79. torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake +76 -0
  80. torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake +122 -0
  81. torchcodec/transforms/__init__.py +12 -0
  82. torchcodec/transforms/_decoder_transforms.py +375 -0
  83. torchcodec/version.py +2 -0
  84. torchcodec-0.10.0.dist-info/METADATA +286 -0
  85. torchcodec-0.10.0.dist-info/RECORD +88 -0
  86. torchcodec-0.10.0.dist-info/WHEEL +5 -0
  87. torchcodec-0.10.0.dist-info/licenses/LICENSE +28 -0
  88. torchcodec-0.10.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1586 @@
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
+ // All rights reserved.
+ //
+ // This source code is licensed under the BSD-style license found in the
+ // LICENSE file in the root directory of this source tree.
+
+ #include "SingleStreamDecoder.h"
+ #include <cstdint>
+ #include <cstdio>
+ #include <iostream>
+ #include <limits>
+ #include <sstream>
+ #include <stdexcept>
+ #include <string_view>
+ #include "Metadata.h"
+ #include "torch/types.h"
+
+ namespace facebook::torchcodec {
+ namespace {
+
+ // Some videos aren't properly encoded and do not specify pts values for
+ // packets, and thus for frames. Unset values correspond to INT64_MIN. When
+ // that happens, we fall back to the dts value, which hopefully exists and is
+ // correct. Accessing AVFrame and AVPacket pts values should **always** go
+ // through the helpers below. The "pts" fields in our structs like
+ // FrameInfo.pts should then be interpreted as "pts if it exists, dts
+ // otherwise".
+ int64_t getPtsOrDts(ReferenceAVPacket& packet) {
+   return packet->pts == INT64_MIN ? packet->dts : packet->pts;
+ }
+
+ int64_t getPtsOrDts(const UniqueAVFrame& avFrame) {
+   return avFrame->pts == INT64_MIN ? avFrame->pkt_dts : avFrame->pts;
+ }
+
+ } // namespace
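
For illustration, a minimal standalone sketch of the fallback rule these helpers implement (hypothetical FakePacket stand-in, not torchcodec code; FFmpeg marks unset timestamps with AV_NOPTS_VALUE, which equals INT64_MIN):

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for an FFmpeg packet.
struct FakePacket {
  int64_t pts;
  int64_t dts;
};

// Same rule as the helpers above: prefer pts, fall back to dts when unset.
int64_t getPtsOrDts(const FakePacket& p) {
  return p.pts == INT64_MIN ? p.dts : p.pts;
}

int main() {
  FakePacket wellFormed{512, 256};
  FakePacket missingPts{INT64_MIN, 256}; // muxer never set pts
  std::cout << getPtsOrDts(wellFormed) << "\n"; // 512: pts wins when present
  std::cout << getPtsOrDts(missingPts) << "\n"; // 256: dts fallback
}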
+
+ // --------------------------------------------------------------------------
+ // CONSTRUCTORS, INITIALIZATION, DESTRUCTORS
+ // --------------------------------------------------------------------------
+
+ SingleStreamDecoder::SingleStreamDecoder(
+     const std::string& videoFilePath,
+     SeekMode seekMode)
+     : seekMode_(seekMode) {
+   setFFmpegLogLevel();
+
+   AVFormatContext* rawContext = nullptr;
+   int status =
+       avformat_open_input(&rawContext, videoFilePath.c_str(), nullptr, nullptr);
+   TORCH_CHECK(
+       status == 0,
+       "Could not open input file: " + videoFilePath + " " +
+           getFFMPEGErrorStringFromErrorCode(status));
+   TORCH_CHECK(rawContext != nullptr);
+   formatContext_.reset(rawContext);
+
+   initializeDecoder();
+ }
+
+ SingleStreamDecoder::SingleStreamDecoder(
+     std::unique_ptr<AVIOContextHolder> context,
+     SeekMode seekMode)
+     : seekMode_(seekMode), avioContextHolder_(std::move(context)) {
+   setFFmpegLogLevel();
+
+   TORCH_CHECK(avioContextHolder_, "Context holder cannot be null");
+
+   // Because FFmpeg requires a reference to a pointer in the call to open, we
+   // can't use a unique pointer here. Note that this means we must free the
+   // context ourselves if open fails.
+   AVFormatContext* rawContext = avformat_alloc_context();
+   TORCH_CHECK(rawContext != nullptr, "Unable to alloc avformat context");
+
+   rawContext->pb = avioContextHolder_->getAVIOContext();
+   int status = avformat_open_input(&rawContext, nullptr, nullptr, nullptr);
+   if (status != 0) {
+     avformat_free_context(rawContext);
+     TORCH_CHECK(
+         false,
+         "Failed to open input buffer: " +
+             getFFMPEGErrorStringFromErrorCode(status));
+   }
+
+   formatContext_.reset(rawContext);
+
+   initializeDecoder();
+ }
+
+ void SingleStreamDecoder::initializeDecoder() {
+   TORCH_CHECK(!initialized_, "Attempted double initialization.");
+
+   // In principle, the AVFormatContext should be filled in by the call to
+   // avformat_open_input() which reads the header. However, some formats do
+   // not store enough info in the header, so we call
+   // avformat_find_stream_info() which decodes a few frames to get missing
+   // info. For more, see:
+   // https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html
+   int status = avformat_find_stream_info(formatContext_.get(), nullptr);
+   TORCH_CHECK(
+       status >= 0,
+       "Failed to find stream info: ",
+       getFFMPEGErrorStringFromErrorCode(status));
+
+   if (formatContext_->duration > 0) {
+     AVRational defaultTimeBase{1, AV_TIME_BASE};
+     containerMetadata_.durationSecondsFromHeader =
+         ptsToSeconds(formatContext_->duration, defaultTimeBase);
+   }
+
+   if (formatContext_->bit_rate > 0) {
+     containerMetadata_.bitRate = formatContext_->bit_rate;
+   }
+
+   int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
+   if (bestVideoStream >= 0) {
+     containerMetadata_.bestVideoStreamIndex = bestVideoStream;
+   }
+
+   int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
+   if (bestAudioStream >= 0) {
+     containerMetadata_.bestAudioStreamIndex = bestAudioStream;
+   }
+
+   for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
+     AVStream* avStream = formatContext_->streams[i];
+     StreamMetadata streamMetadata;
+
+     TORCH_CHECK(
+         static_cast<int>(i) == avStream->index,
+         "Our stream index, " + std::to_string(i) +
+             ", does not match AVStream's index, " +
+             std::to_string(avStream->index) + ".");
+     streamMetadata.streamIndex = i;
+     streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
+     streamMetadata.mediaType = avStream->codecpar->codec_type;
+     streamMetadata.bitRate = avStream->codecpar->bit_rate;
+
+     int64_t frameCount = avStream->nb_frames;
+     if (frameCount > 0) {
+       streamMetadata.numFramesFromHeader = frameCount;
+     }
+
+     if (avStream->duration > 0 && avStream->time_base.den > 0) {
+       streamMetadata.durationSecondsFromHeader =
+           ptsToSeconds(avStream->duration, avStream->time_base);
+     }
+     if (avStream->start_time != AV_NOPTS_VALUE) {
+       streamMetadata.beginStreamSecondsFromHeader =
+           ptsToSeconds(avStream->start_time, avStream->time_base);
+     }
+
+     if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
+       double fps = av_q2d(avStream->r_frame_rate);
+       if (fps > 0) {
+         streamMetadata.averageFpsFromHeader = fps;
+       }
+       streamMetadata.width = avStream->codecpar->width;
+       streamMetadata.height = avStream->codecpar->height;
+       streamMetadata.sampleAspectRatio =
+           avStream->codecpar->sample_aspect_ratio;
+       containerMetadata_.numVideoStreams++;
+     } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
+       AVSampleFormat format =
+           static_cast<AVSampleFormat>(avStream->codecpar->format);
+       streamMetadata.sampleRate =
+           static_cast<int64_t>(avStream->codecpar->sample_rate);
+       streamMetadata.numChannels =
+           static_cast<int64_t>(getNumChannels(avStream->codecpar));
+
+       // If the AVSampleFormat is not recognized, we get back nullptr. We have
+       // to make sure we don't initialize a std::string with nullptr. There's
+       // nothing to do on the else branch because we're already using an
+       // optional; it'll just remain empty.
+       const char* rawSampleFormat = av_get_sample_fmt_name(format);
+       if (rawSampleFormat != nullptr) {
+         streamMetadata.sampleFormat = std::string(rawSampleFormat);
+       }
+       containerMetadata_.numAudioStreams++;
+     }
+
+     streamMetadata.durationSecondsFromContainer =
+         containerMetadata_.durationSecondsFromHeader;
+
+     containerMetadata_.allStreamMetadata.push_back(streamMetadata);
+   }
+
+   if (seekMode_ == SeekMode::exact) {
+     scanFileAndUpdateMetadataAndIndex();
+   }
+
+   initialized_ = true;
+ }
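
For reference, the same open-then-probe sequence in a minimal standalone form against the raw FFmpeg API (a hypothetical demo, not torchcodec code; error text via av_strerror):

extern "C" {
#include <libavformat/avformat.h>
}
#include <cstdio>

int main(int argc, char** argv) {
  if (argc < 2) return 1;
  AVFormatContext* ctx = nullptr;
  int status = avformat_open_input(&ctx, argv[1], nullptr, nullptr);
  if (status != 0) {
    char buf[AV_ERROR_MAX_STRING_SIZE];
    av_strerror(status, buf, sizeof(buf));
    std::fprintf(stderr, "open failed: %s\n", buf);
    return 1;
  }
  // Headers alone may be incomplete; this decodes a few frames to fill gaps.
  status = avformat_find_stream_info(ctx, nullptr);
  if (status < 0) {
    avformat_close_input(&ctx);
    return 1;
  }
  std::printf("streams: %u, duration (us): %lld\n",
              ctx->nb_streams, (long long)ctx->duration);
  avformat_close_input(&ctx);
  return 0;
}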
+
+ int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) {
+   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
+   int streamIndex =
+       av_find_best_stream(formatContext_.get(), mediaType, -1, -1, &avCodec, 0);
+   return streamIndex;
+ }
+
+ // --------------------------------------------------------------------------
+ // VIDEO METADATA QUERY API
+ // --------------------------------------------------------------------------
+
+ void SingleStreamDecoder::sortAllFrames() {
+   // Sorts the allFrames and keyFrames vecs in each stream, and also sets
+   // additional fields of the FrameInfo entries, like nextPts and frameIndex.
+   // This is called at the end of a scan, or when setting a user-defined frame
+   // mapping.
+   for (auto& [streamIndex, streamInfo] : streamInfos_) {
+     std::sort(
+         streamInfo.keyFrames.begin(),
+         streamInfo.keyFrames.end(),
+         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
+           return frameInfo1.pts < frameInfo2.pts;
+         });
+     std::sort(
+         streamInfo.allFrames.begin(),
+         streamInfo.allFrames.end(),
+         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
+           return frameInfo1.pts < frameInfo2.pts;
+         });
+
+     size_t keyFrameIndex = 0;
+     for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
+       streamInfo.allFrames[i].frameIndex = i;
+       if (streamInfo.allFrames[i].isKeyFrame) {
+         TORCH_CHECK(
+             keyFrameIndex < streamInfo.keyFrames.size(),
+             "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec.");
+         streamInfo.keyFrames[keyFrameIndex].frameIndex = i;
+         ++keyFrameIndex;
+       }
+       if (i + 1 < streamInfo.allFrames.size()) {
+         streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
+       }
+     }
+     TORCH_CHECK(
+         keyFrameIndex == streamInfo.keyFrames.size(),
+         "The allFrames vec claims it has FEWER keyFrames than the keyFrames vec. There's a bug in torchcodec.");
+   }
+ }
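
A pared-down sketch of the sort-then-link pass (hypothetical standalone FrameInfo; torchcodec's real struct has more fields):

#include <algorithm>
#include <cstdint>
#include <vector>

struct FrameInfo {
  int64_t pts = 0;
  bool isKeyFrame = false;
  int64_t frameIndex = -1;
  int64_t nextPts = INT64_MAX;
};

// Sort frames by pts, then assign indices and nextPts links, mirroring
// sortAllFrames() above for a single stream's allFrames vector.
void sortAndLink(std::vector<FrameInfo>& frames) {
  std::sort(frames.begin(), frames.end(),
            [](const FrameInfo& a, const FrameInfo& b) { return a.pts < b.pts; });
  for (size_t i = 0; i < frames.size(); ++i) {
    frames[i].frameIndex = static_cast<int64_t>(i);
    if (i + 1 < frames.size()) {
      frames[i].nextPts = frames[i + 1].pts; // frame i is shown until this pts
    }
  }
}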
+
+ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
+   if (scannedAllStreams_) {
+     return;
+   }
+
+   AutoAVPacket autoAVPacket;
+   while (true) {
+     ReferenceAVPacket packet(autoAVPacket);
+
+     // av_read_frame is a misleading name: it gets the next **packet**.
+     int status = av_read_frame(formatContext_.get(), packet.get());
+
+     if (status == AVERROR_EOF) {
+       break;
+     }
+
+     TORCH_CHECK(
+         status == AVSUCCESS,
+         "Failed to read frame from input file: ",
+         getFFMPEGErrorStringFromErrorCode(status));
+
+     if (packet->flags & AV_PKT_FLAG_DISCARD) {
+       continue;
+     }
+
+     // We got a valid packet. Let's figure out what stream it belongs to and
+     // record its relevant metadata.
+     int streamIndex = packet->stream_index;
+     auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
+     streamMetadata.beginStreamPtsFromContent = std::min(
+         streamMetadata.beginStreamPtsFromContent.value_or(INT64_MAX),
+         getPtsOrDts(packet));
+     streamMetadata.endStreamPtsFromContent = std::max(
+         streamMetadata.endStreamPtsFromContent.value_or(INT64_MIN),
+         getPtsOrDts(packet) + packet->duration);
+     streamMetadata.numFramesFromContent =
+         streamMetadata.numFramesFromContent.value_or(0) + 1;
+
+     // Note that we set the other value in this struct, nextPts, only after
+     // we have scanned all packets and sorted by pts.
+     FrameInfo frameInfo = {getPtsOrDts(packet)};
+     if (packet->flags & AV_PKT_FLAG_KEY) {
+       frameInfo.isKeyFrame = true;
+       streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
+     }
+     streamInfos_[streamIndex].allFrames.push_back(frameInfo);
+   }
+
+   // Set all per-stream metadata that requires knowing the content of all
+   // packets.
+   for (size_t streamIndex = 0;
+        streamIndex < containerMetadata_.allStreamMetadata.size();
+        ++streamIndex) {
+     auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
+     auto avStream = formatContext_->streams[streamIndex];
+
+     streamMetadata.numFramesFromContent =
+         streamInfos_[streamIndex].allFrames.size();
+
+     // This ensures that we are robust in handling cases where
+     // we are decoding in exact mode and numFrames is 0. The current metadata
+     // validation logic assumes that these values should not be None.
+     if (streamMetadata.numFramesFromContent.value() == 0) {
+       streamMetadata.beginStreamPtsFromContent = 0;
+       streamMetadata.endStreamPtsFromContent = 0;
+     }
+
+     if (streamMetadata.beginStreamPtsFromContent.has_value()) {
+       streamMetadata.beginStreamPtsSecondsFromContent = ptsToSeconds(
+           *streamMetadata.beginStreamPtsFromContent, avStream->time_base);
+     }
+     if (streamMetadata.endStreamPtsFromContent.has_value()) {
+       streamMetadata.endStreamPtsSecondsFromContent = ptsToSeconds(
+           *streamMetadata.endStreamPtsFromContent, avStream->time_base);
+     }
+   }
+
+   // Reset the seek-cursor back to the beginning.
+   int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
+   TORCH_CHECK(
+       status >= 0,
+       "Could not seek file to pts=0: ",
+       getFFMPEGErrorStringFromErrorCode(status));
+
+   // Sort all frames by their pts.
+   sortAllFrames();
+   scannedAllStreams_ = true;
+ }
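
The scan's running bounds use std::optional::value_or seeded with opposite extremes so the first observed packet always wins; a compact sketch of just that fold (hypothetical Bounds struct):

#include <algorithm>
#include <cstdint>
#include <optional>

// Fold a stream of packet pts values into [begin, end) bounds, as the scan
// loop above does with beginStreamPtsFromContent/endStreamPtsFromContent.
struct Bounds {
  std::optional<int64_t> begin;
  std::optional<int64_t> end;
  void observe(int64_t pts, int64_t duration) {
    begin = std::min(begin.value_or(INT64_MAX), pts);
    end = std::max(end.value_or(INT64_MIN), pts + duration);
  }
};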
+
+ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
+     int streamIndex,
+     FrameMappings customFrameMappings) {
+   TORCH_CHECK(
+       customFrameMappings.all_frames.dtype() == torch::kLong &&
+           customFrameMappings.is_key_frame.dtype() == torch::kBool &&
+           customFrameMappings.duration.dtype() == torch::kLong,
+       "all_frames and duration tensors must be int64 dtype, and is_key_frame tensor must be bool dtype.");
+   const torch::Tensor& all_frames =
+       customFrameMappings.all_frames.to(torch::kLong);
+   const torch::Tensor& is_key_frame =
+       customFrameMappings.is_key_frame.to(torch::kBool);
+   const torch::Tensor& duration = customFrameMappings.duration.to(torch::kLong);
+   TORCH_CHECK(
+       all_frames.size(0) == is_key_frame.size(0) &&
+           is_key_frame.size(0) == duration.size(0),
+       "all_frames, is_key_frame, and duration from custom_frame_mappings were not the same size.");
+
+   // Reserve the vectors using numFrames to reduce reallocations.
+   int64_t numFrames = all_frames.size(0);
+   streamInfos_[streamIndex].allFrames.reserve(numFrames);
+   streamInfos_[streamIndex].keyFrames.reserve(numFrames);
+   // Use accessors to efficiently read tensor elements.
+   auto pts_data = all_frames.accessor<int64_t, 1>();
+   auto is_key_frame_data = is_key_frame.accessor<bool, 1>();
+   auto duration_data = duration.accessor<int64_t, 1>();
+
+   auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
+
+   streamMetadata.beginStreamPtsFromContent = pts_data[0];
+   streamMetadata.endStreamPtsFromContent =
+       pts_data[numFrames - 1] + duration_data[numFrames - 1];
+
+   auto avStream = formatContext_->streams[streamIndex];
+   streamMetadata.beginStreamPtsSecondsFromContent = ptsToSeconds(
+       *streamMetadata.beginStreamPtsFromContent, avStream->time_base);
+
+   streamMetadata.endStreamPtsSecondsFromContent = ptsToSeconds(
+       *streamMetadata.endStreamPtsFromContent, avStream->time_base);
+
+   streamMetadata.numFramesFromContent = numFrames;
+   for (int64_t i = 0; i < numFrames; ++i) {
+     FrameInfo frameInfo;
+     frameInfo.pts = pts_data[i];
+     frameInfo.isKeyFrame = is_key_frame_data[i];
+     streamInfos_[streamIndex].allFrames.push_back(frameInfo);
+     if (frameInfo.isKeyFrame) {
+       streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
+     }
+   }
+   sortAllFrames();
+ }
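
The accessor<>() calls above avoid per-element dispatch when reading tensor data; a minimal standalone sketch of the same libtorch pattern (hypothetical sumPts helper):

#include <torch/torch.h>

// Read int64 values out of a 1-D CPU tensor through a typed accessor,
// as readCustomFrameMappingsUpdateMetadataAndIndex() does above.
int64_t sumPts(const torch::Tensor& allFrames) {
  auto acc = allFrames.accessor<int64_t, 1>();
  int64_t total = 0;
  for (int64_t i = 0; i < acc.size(0); ++i) {
    total += acc[i];
  }
  return total;
}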
+
+ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
+   return containerMetadata_;
+ }
+
+ SeekMode SingleStreamDecoder::getSeekMode() const {
+   return seekMode_;
+ }
+
+ int SingleStreamDecoder::getActiveStreamIndex() const {
+   return activeStreamIndex_;
+ }
+
+ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+   validateScannedAllStreams("getKeyFrameIndices");
+
+   const std::vector<FrameInfo>& keyFrames =
+       streamInfos_[activeStreamIndex_].keyFrames;
+   torch::Tensor keyFrameIndices =
+       torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
+   for (size_t i = 0; i < keyFrames.size(); ++i) {
+     keyFrameIndices[i] = keyFrames[i].frameIndex;
+   }
+
+   return keyFrameIndices;
+ }
+
+ // --------------------------------------------------------------------------
+ // ADDING STREAMS API
+ // --------------------------------------------------------------------------
+
+ void SingleStreamDecoder::addStream(
+     int streamIndex,
+     AVMediaType mediaType,
+     const torch::Device& device,
+     const std::string_view deviceVariant,
+     std::optional<int> ffmpegThreadCount) {
+   TORCH_CHECK(
+       activeStreamIndex_ == NO_ACTIVE_STREAM,
+       "Can only add a single stream.");
+   TORCH_CHECK(
+       mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO,
+       "Can only add video or audio streams.");
+   TORCH_CHECK(formatContext_.get() != nullptr);
+
+   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
+
+   activeStreamIndex_ = av_find_best_stream(
+       formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0);
+
+   if (activeStreamIndex_ < 0) {
+     throw std::invalid_argument(
+         "No valid stream found in input file. Is stream index " +
+         std::to_string(streamIndex) + " of the desired media type?");
+   }
+
+   TORCH_CHECK(avCodec != nullptr);
+
+   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+   streamInfo.streamIndex = activeStreamIndex_;
+   streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base;
+   streamInfo.stream = formatContext_->streams[activeStreamIndex_];
+   streamInfo.avMediaType = mediaType;
+
+   // This should never happen, checking just to be safe.
+   TORCH_CHECK(
+       streamInfo.stream->codecpar->codec_type == mediaType,
+       "FFmpeg found stream with index ",
+       activeStreamIndex_,
+       " which is of the wrong media type.");
+
+   deviceInterface_ = createDeviceInterface(device, deviceVariant);
+   TORCH_CHECK(
+       deviceInterface_ != nullptr,
+       "Failed to create device interface. This should never happen, please report.");
+
+   // TODO_CODE_QUALITY it's pretty meh to have video-specific logic within
+   // addStream(), which is supposed to be generic.
+   if (mediaType == AVMEDIA_TYPE_VIDEO) {
+     avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
+         deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
+             .value_or(avCodec));
+   }
+
+   AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
+   TORCH_CHECK(codecContext != nullptr);
+   streamInfo.codecContext = makeSharedAVCodecContext(codecContext);
+
+   int retVal = avcodec_parameters_to_context(
+       streamInfo.codecContext.get(), streamInfo.stream->codecpar);
+   TORCH_CHECK_EQ(retVal, AVSUCCESS);
+
+   streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
+   streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;
+
+   // Note that we must make sure to register the hardware device context
+   // with the codec context before calling avcodec_open2(). Otherwise, decoding
+   // will happen on the CPU and not the hardware device.
+   deviceInterface_->registerHardwareDeviceWithCodec(
+       streamInfo.codecContext.get());
+   retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
+   TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
+
+   streamInfo.codecContext->time_base = streamInfo.stream->time_base;
+
+   // Initialize the device interface with the codec context.
+   deviceInterface_->initialize(
+       streamInfo.stream, formatContext_, streamInfo.codecContext);
+
+   containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
+       std::string(avcodec_get_name(streamInfo.codecContext->codec_id));
+
+   // We will only need packets from the active stream, so we tell FFmpeg to
+   // discard packets from the other streams. Note that av_read_frame() may
+   // still return some of those undesired packets under some conditions, so
+   // it's still important to discard/demux correctly in the inner decoding
+   // loop.
+   for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) {
+     if (i != static_cast<unsigned int>(activeStreamIndex_)) {
+       formatContext_->streams[i]->discard = AVDISCARD_ALL;
+     }
+   }
+ }
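
For reference, a condensed sketch of the same find/copy-params/open sequence against the raw FFmpeg API (no torchcodec wrappers; hypothetical openDecoder helper):

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

// Open a decoder for stream `idx` of an already-opened AVFormatContext,
// mirroring addStream(): find codec, copy parameters, then avcodec_open2().
AVCodecContext* openDecoder(AVFormatContext* fmt, int idx) {
  AVStream* stream = fmt->streams[idx];
  const AVCodec* codec = avcodec_find_decoder(stream->codecpar->codec_id);
  if (codec == nullptr) return nullptr;
  AVCodecContext* ctx = avcodec_alloc_context3(codec);
  if (ctx == nullptr) return nullptr;
  if (avcodec_parameters_to_context(ctx, stream->codecpar) < 0) {
    avcodec_free_context(&ctx);
    return nullptr;
  }
  ctx->pkt_timebase = stream->time_base; // set before open, as above
  if (avcodec_open2(ctx, codec, nullptr) < 0) {
    avcodec_free_context(&ctx);
    return nullptr;
  }
  return ctx;
}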
+
+ void SingleStreamDecoder::addVideoStream(
+     int streamIndex,
+     std::vector<Transform*>& transforms,
+     const VideoStreamOptions& videoStreamOptions,
+     std::optional<FrameMappings> customFrameMappings) {
+   TORCH_CHECK(
+       transforms.empty() || videoStreamOptions.device == torch::kCPU,
+       "Transforms are only supported for CPU devices.");
+
+   addStream(
+       streamIndex,
+       AVMEDIA_TYPE_VIDEO,
+       videoStreamOptions.device,
+       videoStreamOptions.deviceVariant,
+       videoStreamOptions.ffmpegThreadCount);
+
+   auto& streamMetadata =
+       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+
+   if (seekMode_ == SeekMode::approximate) {
+     TORCH_CHECK(
+         streamMetadata.averageFpsFromHeader.has_value(),
+         "Seek mode is approximate, but stream ",
+         std::to_string(activeStreamIndex_),
+         " does not have an average fps in its metadata.");
+   }
+
+   auto& streamInfo = streamInfos_[activeStreamIndex_];
+   streamInfo.videoStreamOptions = videoStreamOptions;
+
+   if (seekMode_ == SeekMode::custom_frame_mappings) {
+     TORCH_CHECK(
+         customFrameMappings.has_value(),
+         "Missing frame mappings when custom_frame_mappings seek mode is set.");
+     readCustomFrameMappingsUpdateMetadataAndIndex(
+         activeStreamIndex_, customFrameMappings.value());
+   }
+
+   metadataDims_ =
+       FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
+   FrameDims currInputDims = metadataDims_;
+   for (auto& transform : transforms) {
+     TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
+     if (transform->getOutputFrameDims().has_value()) {
+       resizedOutputDims_ = transform->getOutputFrameDims().value();
+     }
+     transform->validate(currInputDims);
+     currInputDims = resizedOutputDims_.value_or(metadataDims_);
+
+     // Note that we are claiming ownership of the transform objects passed in
+     // to us.
+     transforms_.push_back(std::unique_ptr<Transform>(transform));
+   }
+
+   deviceInterface_->initializeVideo(
+       videoStreamOptions, transforms_, resizedOutputDims_);
+ }
+
+ void SingleStreamDecoder::addAudioStream(
+     int streamIndex,
+     const AudioStreamOptions& audioStreamOptions) {
+   TORCH_CHECK(
+       seekMode_ == SeekMode::approximate,
+       "seek_mode must be 'approximate' for audio streams.");
+   if (audioStreamOptions.numChannels.has_value()) {
+     TORCH_CHECK(
+         *audioStreamOptions.numChannels > 0 &&
+             *audioStreamOptions.numChannels <= AV_NUM_DATA_POINTERS,
+         "num_channels must be > 0 and <= AV_NUM_DATA_POINTERS (usually 8). Got: ",
+         *audioStreamOptions.numChannels);
+   }
+
+   addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
+
+   auto& streamInfo = streamInfos_[activeStreamIndex_];
+   streamInfo.audioStreamOptions = audioStreamOptions;
+
+   // FFmpeg docs say that the decoder will try to decode natively in this
+   // format, if it can. Docs don't say what the decoder does when it doesn't
+   // support that format, but it looks like it does nothing, so this probably
+   // doesn't hurt.
+   streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP;
+
+   // Initialize device interface for audio.
+   deviceInterface_->initializeAudio(audioStreamOptions);
+ }
+
+ // --------------------------------------------------------------------------
+ // HIGH-LEVEL DECODING ENTRY-POINTS
+ // --------------------------------------------------------------------------
+
+ FrameOutput SingleStreamDecoder::getNextFrame() {
+   auto output = getNextFrameInternal();
+   if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
+     output.data = maybePermuteHWC2CHW(output.data);
+   }
+   return output;
+ }
+
+ FrameOutput SingleStreamDecoder::getNextFrameInternal(
+     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+   validateActiveStream();
+   UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
+     return getPtsOrDts(avFrame) >= cursor_;
+   });
+   return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
+ }
+
+ FrameOutput SingleStreamDecoder::getFrameAtIndex(int64_t frameIndex) {
+   auto frameOutput = getFrameAtIndexInternal(frameIndex);
+   frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
+   return frameOutput;
+ }
+
+ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
+     int64_t frameIndex,
+     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+
+   const auto& streamInfo = streamInfos_[activeStreamIndex_];
+   const auto& streamMetadata =
+       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+
+   std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
+   if (numFrames.has_value()) {
+     // If the frameIndex is negative, we convert it to a positive index.
+     frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
+   }
+   validateFrameIndex(streamMetadata, frameIndex);
+
+   // Only set cursor if we're not decoding sequentially: when decoding
+   // sequentially, we don't need to seek anywhere, so by *not* setting the
+   // cursor we allow canWeAvoidSeeking() to return true early.
+   if (frameIndex != lastDecodedFrameIndex_ + 1) {
+     int64_t pts = getPts(frameIndex);
+     setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
+   }
+
+   auto result = getNextFrameInternal(preAllocatedOutputTensor);
+   lastDecodedFrameIndex_ = frameIndex;
+   return result;
+ }
+
+ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
+     const torch::Tensor& frameIndices) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+
+   auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
+
+   bool indicesAreSorted = true;
+   for (int64_t i = 1; i < frameIndices.numel(); ++i) {
+     if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) {
+       indicesAreSorted = false;
+       break;
+     }
+   }
+
+   std::vector<size_t> argsort;
+   if (!indicesAreSorted) {
+     // if frameIndices is [13, 10, 12, 11]
+     // when sorted, it's  [10, 11, 12, 13] <-- this is the sorted order we
+     //                                         want to use to decode the frames
+     // and argsort is     [ 1,  3,  2,  0]
+     argsort.resize(frameIndices.numel());
+     for (size_t i = 0; i < argsort.size(); ++i) {
+       argsort[i] = i;
+     }
+     std::sort(
+         argsort.begin(),
+         argsort.end(),
+         [&frameIndicesAccessor](size_t a, size_t b) {
+           return frameIndicesAccessor[a] < frameIndicesAccessor[b];
+         });
+   }
+
+   const auto& streamInfo = streamInfos_[activeStreamIndex_];
+   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
+   FrameBatchOutput frameBatchOutput(
+       frameIndices.numel(),
+       resizedOutputDims_.value_or(metadataDims_),
+       videoStreamOptions.device);
+
+   auto previousIndexInVideo = -1;
+   for (int64_t f = 0; f < frameIndices.numel(); ++f) {
+     auto indexInOutput = indicesAreSorted ? f : argsort[f];
+     auto indexInVideo = frameIndicesAccessor[indexInOutput];
+
+     if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
+       // Avoid decoding the same frame twice.
+       auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
+       frameBatchOutput.data[indexInOutput].copy_(
+           frameBatchOutput.data[previousIndexInOutput]);
+       frameBatchOutput.ptsSeconds[indexInOutput] =
+           frameBatchOutput.ptsSeconds[previousIndexInOutput];
+       frameBatchOutput.durationSeconds[indexInOutput] =
+           frameBatchOutput.durationSeconds[previousIndexInOutput];
+     } else {
+       FrameOutput frameOutput = getFrameAtIndexInternal(
+           indexInVideo, frameBatchOutput.data[indexInOutput]);
+       frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds;
+       frameBatchOutput.durationSeconds[indexInOutput] =
+           frameOutput.durationSeconds;
+     }
+     previousIndexInVideo = indexInVideo;
+   }
+   frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
+   return frameBatchOutput;
+ }
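
A standalone sketch of the argsort construction above, over a plain vector instead of a tensor accessor:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Return the indices that would sort `values`, as getFramesAtIndices()
// builds for out-of-order requests, e.g. {13, 10, 12, 11} -> {1, 3, 2, 0}.
std::vector<size_t> argsort(const std::vector<int64_t>& values) {
  std::vector<size_t> order(values.size());
  std::iota(order.begin(), order.end(), 0); // 0, 1, 2, ...
  std::sort(order.begin(), order.end(),
            [&values](size_t a, size_t b) { return values[a] < values[b]; });
  return order;
}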
+
+ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
+     int64_t start,
+     int64_t stop,
+     int64_t step) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+
+   const auto& streamMetadata =
+       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+   const auto& streamInfo = streamInfos_[activeStreamIndex_];
+   TORCH_CHECK(
+       start >= 0, "Range start, " + std::to_string(start) + ", is less than 0.");
+   TORCH_CHECK(
+       step > 0, "Step must be greater than 0; is " + std::to_string(step));
+
+   // Note that if we do not have the number of frames available in our
+   // metadata, then we assume that the upper part of the range is valid.
+   std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
+   if (numFrames.has_value()) {
+     TORCH_CHECK(
+         stop <= numFrames.value(),
+         "Range stop, " + std::to_string(stop) +
+             ", is more than the number of frames, " +
+             std::to_string(numFrames.value()));
+   }
+
+   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
+   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
+   FrameBatchOutput frameBatchOutput(
+       numOutputFrames,
+       resizedOutputDims_.value_or(metadataDims_),
+       videoStreamOptions.device);
+
+   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
+     FrameOutput frameOutput =
+         getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
+     frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds;
+     frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds;
+   }
+   frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
+   return frameBatchOutput;
+ }
+
+ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+   double lastDecodedStartTime =
+       ptsToSeconds(lastDecodedAvFramePts_, streamInfo.timeBase);
+   double lastDecodedEndTime = ptsToSeconds(
+       lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_,
+       streamInfo.timeBase);
+   if (seconds >= lastDecodedStartTime && seconds < lastDecodedEndTime) {
+     // We are in the same frame as the one we just returned. However, since we
+     // don't cache it locally, we have to rewind back.
+     seconds = lastDecodedStartTime;
+   }
+
+   setCursorPtsInSeconds(seconds);
+   UniqueAVFrame avFrame =
+       decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
+         StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+         double frameStartTime =
+             ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase);
+         double frameEndTime = ptsToSeconds(
+             getPtsOrDts(avFrame) + getDuration(avFrame), streamInfo.timeBase);
+         if (frameStartTime > seconds) {
+           // FFMPEG seeked past the frame we are looking for even though we
+           // set max_ts to be our needed timestamp in avformat_seek_file()
+           // in maybeSeekToBeforeDesiredPts(). This could be a bug in FFMPEG:
+           // https://trac.ffmpeg.org/ticket/11137. In this case we return the
+           // very next frame instead of throwing an exception.
+           // TODO: Maybe log to stderr for Debug builds?
+           return true;
+         }
+         return seconds >= frameStartTime && seconds < frameEndTime;
+       });
+
+   // Convert the frame to tensor.
+   FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
+   frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
+   return frameOutput;
+ }
+
+ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
+     const torch::Tensor& timestamps) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+
+   const auto& streamMetadata =
+       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+
+   double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
+   std::optional<double> maxSeconds =
+       streamMetadata.getEndStreamSeconds(seekMode_);
+
+   // The frame played at timestamp t and the one played at timestamp `t +
+   // eps` are probably the same frame, with the same index. The easiest way to
+   // avoid decoding that unique frame twice is to convert the input timestamps
+   // to indices, and leverage the de-duplication logic of getFramesAtIndices.
+
+   torch::Tensor frameIndices =
+       torch::empty({timestamps.numel()}, torch::kInt64);
+   auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
+   auto timestampsAccessor = timestamps.accessor<double, 1>();
+
+   for (int64_t i = 0; i < timestamps.numel(); ++i) {
+     auto frameSeconds = timestampsAccessor[i];
+     TORCH_CHECK(
+         frameSeconds >= minSeconds,
+         "frame pts is " + std::to_string(frameSeconds) +
+             "; must be greater than or equal to " + std::to_string(minSeconds) +
+             ".");
+
+     // Note that if we can't determine the maximum number of seconds from the
+     // metadata, then we assume the frame's pts is valid.
+     if (maxSeconds.has_value()) {
+       TORCH_CHECK(
+           frameSeconds < maxSeconds.value(),
+           "frame pts is " + std::to_string(frameSeconds) +
+               "; must be less than " + std::to_string(maxSeconds.value()) +
+               ".");
+     }
+
+     frameIndicesAccessor[i] = secondsToIndexLowerBound(frameSeconds);
+   }
+
+   return getFramesAtIndices(frameIndices);
+ }
+
+ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
+     double startSeconds,
+     double stopSeconds) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+   const auto& streamMetadata =
+       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+   TORCH_CHECK(
+       startSeconds <= stopSeconds,
+       "Start seconds (" + std::to_string(startSeconds) +
+           ") must be less than or equal to stop seconds (" +
+           std::to_string(stopSeconds) + ").");
+
+   const auto& streamInfo = streamInfos_[activeStreamIndex_];
+   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
+
+   // Special case needed to implement a half-open range. At first glance, this
+   // may seem unnecessary, as our search for stopFrame can return the end, and
+   // we don't include stopFrameIndex in our output. However, consider the
+   // following scenario:
+   //
+   //   frame=0, pts=0.0
+   //   frame=1, pts=0.3
+   //
+   //   interval A: [0.2, 0.2)
+   //   interval B: [0.2, 0.25)
+   //
+   // Both intervals take place between the pts values for frame 0 and frame 1,
+   // which, by our abstract player, means that both intervals map to frame 0.
+   // By the definition of a half-open interval, interval A should return no
+   // frames. Interval B should return frame 0. However, for both A and B, the
+   // individual values of the intervals will map to the same frame indices
+   // below. Hence, we need this special case below.
+   if (startSeconds == stopSeconds) {
+     FrameBatchOutput frameBatchOutput(
+         0,
+         resizedOutputDims_.value_or(metadataDims_),
+         videoStreamOptions.device);
+     frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
+     return frameBatchOutput;
+   }
+
+   double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
+   TORCH_CHECK(
+       startSeconds >= minSeconds,
+       "Start seconds is " + std::to_string(startSeconds) +
+           "; must be greater than or equal to " + std::to_string(minSeconds) +
+           ".");
+
+   // Note that if we can't determine the maximum seconds from the metadata,
+   // then we assume the upper range is valid.
+   std::optional<double> maxSeconds =
+       streamMetadata.getEndStreamSeconds(seekMode_);
+   if (maxSeconds.has_value()) {
+     TORCH_CHECK(
+         startSeconds < maxSeconds.value(),
+         "Start seconds is " + std::to_string(startSeconds) +
+             "; must be less than " + std::to_string(maxSeconds.value()) + ".");
+     TORCH_CHECK(
+         stopSeconds <= maxSeconds.value(),
+         "Stop seconds (" + std::to_string(stopSeconds) +
+             ") must be less than or equal to " +
+             std::to_string(maxSeconds.value()) + ".");
+   }
+
+   // Note that we look at nextPts for a frame, and not its pts or duration.
+   // Our abstract player displays frames starting at the pts for that frame
+   // until the pts for the next frame. There are two consequences:
+   //
+   //   1. We ignore the duration for a frame. A frame is played until the
+   //      next frame replaces it. This model is robust to durations being 0 or
+   //      incorrect; our source of truth is the pts for frames. If duration is
+   //      accurate, the nextPts for a frame would be equivalent to pts +
+   //      duration.
+   //   2. In order to establish if the start of an interval maps to a
+   //      particular frame, we need to figure out if it is ordered after the
+   //      frame's pts, but before the next frame's pts.
+
+   int64_t startFrameIndex = secondsToIndexLowerBound(startSeconds);
+   int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
+   int64_t numFrames = stopFrameIndex - startFrameIndex;
+
+   FrameBatchOutput frameBatchOutput(
+       numFrames,
+       resizedOutputDims_.value_or(metadataDims_),
+       videoStreamOptions.device);
+   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
+     FrameOutput frameOutput =
+         getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
+     frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds;
+     frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds;
+   }
+   frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
+
+   return frameBatchOutput;
+ }
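
To make the "abstract player" concrete: a toy sketch (hypothetical helper; integer pts already in the stream timebase) of how a timestamp maps to a frame index under the "shown until the next pts" model:

#include <algorithm>
#include <cstdint>
#include <vector>

// Index of the frame being played at pts t: the last frame whose start pts
// is <= t, because a frame stays on screen until the next frame's pts.
int64_t playedFrameIndex(const std::vector<int64_t>& sortedPts, int64_t t) {
  auto it = std::upper_bound(sortedPts.begin(), sortedPts.end(), t);
  return static_cast<int64_t>(it - sortedPts.begin()) - 1; // -1 if t < first pts
}

With pts {0, 30} (0.0s and 0.3s in a 1/100 timebase), t=20 maps to frame 0, matching the interval discussion in the comment above.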
+
+ // Note [Audio Decoding Design]
+ // This note explains why audio decoding is implemented the way it is, and why
+ // it inherently differs from video decoding.
+ //
+ // Like for video, FFmpeg exposes the concept of a frame for audio streams. An
+ // audio frame is a contiguous sequence of samples, where a sample consists of
+ // `numChannels` values. An audio frame, or a sequence thereof, is always
+ // converted into a tensor of shape `(numChannels, numSamplesPerChannel)`.
+ //
+ // The notion of 'frame' in audio isn't what users want to interact with.
+ // Users want to interact with samples. The C++ and core APIs return frames,
+ // because we want those to be close to FFmpeg concepts, but the higher-level
+ // public APIs expose samples. As a result:
+ // - We don't expose index-based APIs for audio, because that would mean
+ //   exposing the concept of audio frame. For now, we think exposing
+ //   time-based APIs is more natural.
+ // - We never perform a scan for audio streams. We don't need to, since we
+ //   won't be converting timestamps to indices. That's why we enforce the
+ //   seek_mode to be "approximate" (which is slightly misleading, because
+ //   technically the output samples will be at their exact positions. But
+ //   this incongruence is only exposed at the C++/core private levels).
+ //
+ // Audio frames are of variable dimensions: in the same stream, a frame can
+ // contain 1024 samples and the next one may contain 512 [1]. This makes it
+ // impossible to stack audio frames in the same way we can stack video frames.
+ // This is one of the main reasons we cannot reuse the same pre-allocation
+ // logic we have for videos in getFramesPlayedInRange(): pre-allocating a
+ // batch requires constant (and known) frame dimensions. That's also why
+ // audio frames are *concatenated* along the samples dimension, not stacked.
+ //
+ // [IMPORTANT!] There is one key invariant that we must respect when decoding
+ // audio frames:
+ //
+ //   BEFORE DECODING FRAME i, WE MUST DECODE ALL FRAMES j < i.
+ //
+ // Always. Why? We don't know. What we know is that if we don't, we get
+ // clipped, incorrect audio as output [2]. All other (correct) libraries like
+ // TorchAudio or Decord do something similar, whether it was intended or not.
+ // This has a few implications:
+ // - The **only** place we're allowed to seek to in an audio stream is the
+ //   stream's beginning. This ensures that if we need a frame, we'll have
+ //   decoded all previous frames.
+ // - Because of that, we don't allow the public APIs to seek. Public APIs can
+ //   call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually
+ //   seek.
+ // - We try not to seek, when we can avoid it. Typically if the next frame we
+ //   need is in the future, we don't seek back to the beginning, we just
+ //   decode all the frames in-between.
+ //
+ // [2] If you're brave and curious, you can read the long "Seek offset for
+ // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which
+ // sums up past (and failed) attempts at working around this issue.
+ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
+     double startSeconds,
+     std::optional<double> stopSecondsOptional) {
+   validateActiveStream(AVMEDIA_TYPE_AUDIO);
+
+   if (stopSecondsOptional.has_value()) {
+     TORCH_CHECK(
+         startSeconds <= *stopSecondsOptional,
+         "Start seconds (" + std::to_string(startSeconds) +
+             ") must be less than or equal to stop seconds (" +
+             std::to_string(*stopSecondsOptional) + ").");
+   }
+
+   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+
+   if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) {
+     // For consistency with video.
+     int numChannels = getNumChannels(streamInfo.codecContext);
+     return AudioFramesOutput{torch::empty({numChannels, 0}), 0.0};
+   }
+
+   auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
+   if (startPts < lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_) {
+     // If we need to seek backwards, then we have to seek back to the
+     // beginning of the stream. See [Audio Decoding Design].
+     setCursor(INT64_MIN);
+   }
+
+   // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
+   // cat(). This would save a copy. We know the duration of the output and the
+   // sample rate, so in theory we know the number of output samples.
+   std::vector<torch::Tensor> frames;
+
+   std::optional<double> firstFramePtsSeconds = std::nullopt;
+   auto stopPts = stopSecondsOptional.has_value()
+       ? secondsToClosestPts(*stopSecondsOptional, streamInfo.timeBase)
+       : INT64_MAX;
+   auto finished = false;
+   while (!finished) {
+     try {
+       UniqueAVFrame avFrame =
+           decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) {
+             return startPts < getPtsOrDts(avFrame) + getDuration(avFrame) &&
+                 stopPts > getPtsOrDts(avFrame);
+           });
+       auto frameOutput = convertAVFrameToFrameOutput(avFrame);
+       if (!firstFramePtsSeconds.has_value()) {
+         firstFramePtsSeconds = frameOutput.ptsSeconds;
+       }
+       frames.push_back(frameOutput.data);
+     } catch (const EndOfFileException&) {
+       finished = true;
+     }
+
+     // If stopSeconds is in [begin, end] of the last decoded frame, we should
+     // stop decoding more frames. Note that if we were to use [begin, end),
+     // which may seem more natural, then we would decode the frame starting at
+     // stopSeconds, which isn't what we want!
+     auto lastDecodedAvFrameEnd =
+         lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_;
+     finished |= (lastDecodedAvFramePts_) <= stopPts &&
+         (stopPts <= lastDecodedAvFrameEnd);
+   }
+
+   auto lastSamples = deviceInterface_->maybeFlushAudioBuffers();
+   if (lastSamples.has_value()) {
+     frames.push_back(*lastSamples);
+   }
+
+   TORCH_CHECK(
+       frames.size() > 0 && firstFramePtsSeconds.has_value(),
+       "No audio frames were decoded. ",
+       "This is probably because start_seconds is too high (",
+       startSeconds,
+       "), ",
+       "or because stop_seconds (",
+       stopSecondsOptional,
+       ") is too low.");
+
+   return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
+ }
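
A minimal libtorch sketch of why variable-length audio chunks are concatenated along dim 1 (stacking would require equal sample counts across frames):

#include <torch/torch.h>
#include <vector>

// Variable-length (numChannels, numSamples) chunks cannot be stacked into
// one batch dimension, but they can be concatenated along samples:
// e.g. (2, 1024) and (2, 512) -> (2, 1536).
torch::Tensor concatAudioFrames(const std::vector<torch::Tensor>& frames) {
  return torch::cat(frames, /*dim=*/1);
}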
+
+ // --------------------------------------------------------------------------
+ // SEEKING APIs
+ // --------------------------------------------------------------------------
+
+ void SingleStreamDecoder::setCursorPtsInSeconds(double seconds) {
+   // We don't allow public audio decoding APIs to seek, see [Audio Decoding
+   // Design].
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+   setCursor(
+       secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase));
+ }
+
+ void SingleStreamDecoder::setCursor(int64_t pts) {
+   cursorWasJustSet_ = true;
+   cursor_ = pts;
+ }
+
+ bool SingleStreamDecoder::canWeAvoidSeeking() const {
+   // Returns true if we can avoid seeking in the AVFormatContext based on
+   // heuristics that rely on the target cursor_ and the last decoded frame.
+   // Seeking is expensive, so we try to avoid it when possible.
+   const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
+   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
+     // For audio, we only need to seek if a backwards seek was requested
+     // within getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was
+     // called. For more context, see [Audio Decoding Design].
+     return !cursorWasJustSet_;
+   } else if (!cursorWasJustSet_) {
+     // For videos, when decoding consecutive frames, we don't need to seek.
+     return true;
+   }
+
+   if (cursor_ < lastDecodedAvFramePts_) {
+     // We can never skip a seek if we are seeking backwards.
+     return false;
+   }
+   if (lastDecodedAvFramePts_ == cursor_) {
+     // We are seeking to the exact same frame as we are currently at. Without
+     // caching we have to rewind back and decode the frame again.
+     // TODO: https://github.com/pytorch/torchcodec/issues/84 we could
+     // implement caching.
+     return false;
+   }
+   // We are seeking forwards. We can skip a seek if both the last decoded frame
+   // and cursor_ share the same keyframe:
+   // Videos have I frames and non-I frames (P and B frames). Non-I frames need
+   // data from the previous I frame to be decoded.
+   //
+   // Imagine the cursor is at a random frame with PTS=lastDecodedAvFramePts (x
+   // for brevity) and we wish to seek to a user-specified PTS=y.
+   //
+   // If y < x, we don't have a choice but to seek backwards to the highest I
+   // frame before y.
+   //
+   // If y > x, we have two choices:
+   //
+   // 1. We could keep decoding forward until we hit y. Illustrated below:
+   //
+   //   I  P  P  P  I  P  P  P  I  P  P  I  P
+   //      x                 y
+   //
+   // 2. We could try to jump to an I frame between x and y (indicated by j
+   // below). And then start decoding until we encounter y. Illustrated below:
+   //
+   //   I  P  P  P  I  P  P  P  I  P  P  I  P
+   //      x        j        y
+   //
+   // (2) is only more efficient than (1) if there is an I frame between x and y.
+   int lastKeyFrame = getKeyFrameIdentifier(lastDecodedAvFramePts_);
+   int targetKeyFrame = getKeyFrameIdentifier(cursor_);
+   return lastKeyFrame >= 0 && targetKeyFrame >= 0 &&
+       lastKeyFrame == targetKeyFrame;
+ }
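
A compact sketch of the shared-keyframe test above: with a sorted keyframe pts index, a forward seek can be skipped iff both pts values fall under the same keyframe (hypothetical helpers over a plain vector, standing in for the scanned keyFrames index):

#include <algorithm>
#include <cstdint>
#include <vector>

// Index of the latest keyframe at or before `pts` (-1 if none), analogous
// to what getKeyFrameIdentifier() computes from the scanned index.
int64_t keyFrameBefore(const std::vector<int64_t>& keyFramePts, int64_t pts) {
  auto it = std::upper_bound(keyFramePts.begin(), keyFramePts.end(), pts);
  return static_cast<int64_t>(it - keyFramePts.begin()) - 1;
}

bool canSkipForwardSeek(const std::vector<int64_t>& keyFramePts,
                        int64_t lastDecodedPts, int64_t targetPts) {
  int64_t a = keyFrameBefore(keyFramePts, lastDecodedPts);
  int64_t b = keyFrameBefore(keyFramePts, targetPts);
  return a >= 0 && b >= 0 && a == b; // same keyframe: just decode forward
}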
+
+ // This method looks at currentPts and desiredPts and seeks in the
+ // AVFormatContext if it is needed. We can skip seeking in certain cases. See
+ // the comment of canWeAvoidSeeking() for details.
+ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
+   validateActiveStream();
+   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+
+   decodeStats_.numSeeksAttempted++;
+   if (canWeAvoidSeeking()) {
+     decodeStats_.numSeeksSkipped++;
+     return;
+   }
+
+   int64_t desiredPts = cursor_;
+
+   // For some encodings like H265, FFMPEG sometimes seeks past the point we
+   // set as the max_ts. So we use our own index to give it the exact pts of
+   // the key frame that we want to seek to.
+   // See https://github.com/pytorch/torchcodec/issues/179 for more details.
+   // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
+   if (!streamInfo.keyFrames.empty()) {
+     int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
+         streamInfo.keyFrames, desiredPts);
+     desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
+     desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
+   }
+
+   int status = avformat_seek_file(
+       formatContext_.get(),
+       streamInfo.streamIndex,
+       INT64_MIN,
+       desiredPts,
+       desiredPts,
+       0);
+   TORCH_CHECK(
+       status >= 0,
+       "Could not seek file to pts=",
+       std::to_string(desiredPts),
+       ": ",
+       getFFMPEGErrorStringFromErrorCode(status));
+
+   decodeStats_.numFlushes++;
+   deviceInterface_->flush();
+ }
1193
+
+ // --------------------------------------------------------------------------
+ // LOW-LEVEL DECODING
+ // --------------------------------------------------------------------------
+
+ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
+     std::function<bool(const UniqueAVFrame&)> filterFunction) {
+   validateActiveStream();
+
+   resetDecodeStats();
+
+   maybeSeekToBeforeDesiredPts();
+   cursorWasJustSet_ = false;
+
+   UniqueAVFrame avFrame(av_frame_alloc());
+   AutoAVPacket autoAVPacket;
+   int status = AVSUCCESS;
+   bool reachedEOF = false;
+
+   // The default implementation uses avcodec_receive_frame and
+   // avcodec_send_packet, while specialized interfaces can override them for
+   // hardware-specific optimizations.
+   while (true) {
+     status = deviceInterface_->receiveFrame(avFrame);
+
+     if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
+       // Non-retriable error
+       break;
+     }
+
+     decodeStats_.numFramesReceivedByDecoder++;
+     // Is this the kind of frame we're looking for?
+     if (status == AVSUCCESS && filterFunction(avFrame)) {
+       // Yes, this is the frame we'll return; break out of the decoding loop.
+       break;
+     } else if (status == AVSUCCESS) {
+       // No, but we received a valid frame - just not the kind we're looking
+       // for. The logic below will read packets and send them to the decoder.
+       // But since we did just receive a frame, we should skip reading more
+       // packets and sending them to the decoder and just try to receive more
+       // frames from the decoder.
+       continue;
+     }
+
+     if (reachedEOF) {
+       // We don't have any more packets to send to the decoder. So keep on
+       // pulling frames from the decoder's internal buffers.
+       continue;
+     }
+
+     // We still haven't found the frame we're looking for. So let's read more
+     // packets and send them to the decoder.
+     ReferenceAVPacket packet(autoAVPacket);
+     do {
+       status = av_read_frame(formatContext_.get(), packet.get());
+       decodeStats_.numPacketsRead++;
+
+       if (status == AVERROR_EOF) {
+         // End of file reached. We must drain the decoder.
+         status = deviceInterface_->sendEOFPacket();
+         TORCH_CHECK(
+             status >= AVSUCCESS,
+             "Could not flush decoder: ",
+             getFFMPEGErrorStringFromErrorCode(status));
+
+         reachedEOF = true;
+         break;
+       }
+
+       TORCH_CHECK(
+           status >= AVSUCCESS,
+           "Could not read frame from input file: ",
+           getFFMPEGErrorStringFromErrorCode(status));
+
+     } while (packet->stream_index != activeStreamIndex_);
+
+     if (reachedEOF) {
+       // We don't have any more packets to send to the decoder. So keep on
+       // pulling frames from its internal buffers.
+       continue;
+     }
+
+     // We got a valid packet. Send it to the decoder, and we'll receive it in
+     // the next iteration.
+     status = deviceInterface_->sendPacket(packet);
+     TORCH_CHECK(
+         status >= AVSUCCESS,
+         "Could not push packet to decoder: ",
+         getFFMPEGErrorStringFromErrorCode(status));
+
+     decodeStats_.numPacketsSentToDecoder++;
+   }
+
+   if (status < AVSUCCESS) {
+     if (reachedEOF || status == AVERROR_EOF) {
+       throw SingleStreamDecoder::EndOfFileException(
+           "Requested next frame while there are no more frames left to "
+           "decode.");
+     }
+     TORCH_CHECK(
+         false,
+         "Could not receive frame from decoder: ",
+         getFFMPEGErrorStringFromErrorCode(status));
+   }
+
+   // Note that we don't flush the decoder when we reach EOF (even though
+   // that's mentioned in
+   // https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). This is
+   // because we may have packets internally in the decoder that we haven't
+   // received as frames. Eventually we will either hit AVERROR_EOF from
+   // avcodec_receive_frame() or the user will have seeked to a different
+   // location in the file and that will flush the decoder.
+   lastDecodedAvFramePts_ = getPtsOrDts(avFrame);
+   lastDecodedAvFrameDuration_ = getDuration(avFrame);
+
+   return avFrame;
+ }
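
For readers unfamiliar with FFmpeg's decode API: the loop above generalizes the canonical send/receive pattern behind deviceInterface_. A minimal sketch of the underlying raw-libavcodec pattern, which (per the comment above) is what the default interface wraps; the function name and structure here are illustrative, not torchcodec code:

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

// Returns 0 when `frame` holds a decoded frame, AVERROR_EOF once the decoder
// is fully drained, or a negative error code.
int decodeNextFrame(
    AVFormatContext* fmt,
    AVCodecContext* codec,
    int streamIndex,
    AVFrame* frame,
    AVPacket* packet) {
  while (true) {
    int status = avcodec_receive_frame(codec, frame);
    if (status != AVERROR(EAGAIN)) {
      return status;  // Success, EOF after draining, or a real error.
    }
    // The decoder needs more input: read the next packet for our stream.
    do {
      status = av_read_frame(fmt, packet);
      if (status == AVERROR_EOF) {
        avcodec_send_packet(codec, nullptr);  // Enter drain mode.
        break;
      }
      if (status < 0) {
        return status;
      }
    } while (packet->stream_index != streamIndex);
    if (status >= 0) {
      status = avcodec_send_packet(codec, packet);
      av_packet_unref(packet);
      if (status < 0) {
        return status;
      }
    }
  }
}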
1310
+
+ // --------------------------------------------------------------------------
+ // AVFRAME <-> FRAME OUTPUT CONVERSION
+ // --------------------------------------------------------------------------
+
+ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
+     UniqueAVFrame& avFrame,
+     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+   // Convert the frame to a tensor.
+   FrameOutput frameOutput;
+   frameOutput.ptsSeconds = ptsToSeconds(
+       getPtsOrDts(avFrame),
+       formatContext_->streams[activeStreamIndex_]->time_base);
+   frameOutput.durationSeconds = ptsToSeconds(
+       getDuration(avFrame),
+       formatContext_->streams[activeStreamIndex_]->time_base);
+   deviceInterface_->convertAVFrameToFrameOutput(
+       avFrame, frameOutput, std::move(preAllocatedOutputTensor));
+   return frameOutput;
+ }
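
The pts and duration conversions above boil down to multiplying integer ticks by the stream's rational time_base. A minimal sketch of that arithmetic; torchcodec's actual ptsToSeconds is defined elsewhere in this package:

extern "C" {
#include <libavutil/rational.h>
}
#include <cstdint>

double ptsToSecondsSketch(int64_t pts, AVRational timeBase) {
  // E.g. with time_base = 1/90000 (common in MPEG-TS), pts = 180000
  // converts to 2.0 seconds.
  return static_cast<double>(pts) * av_q2d(timeBase);
}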
1330
+
+ // --------------------------------------------------------------------------
+ // OUTPUT ALLOCATION AND SHAPE CONVERSION
+ // --------------------------------------------------------------------------
+
+ // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require
+ // it. The [N] leading batch-dimension is optional, i.e. the input tensor can
+ // be 3D or 4D. Calling permute() is guaranteed to return a view as per the
+ // docs: https://pytorch.org/docs/stable/generated/torch.permute.html
+ torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW(
+     torch::Tensor& hwcTensor) {
+   if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
+       "NHWC") {
+     return hwcTensor;
+   }
+   auto numDimensions = hwcTensor.dim();
+   auto shape = hwcTensor.sizes();
+   if (numDimensions == 3) {
+     TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
+     return hwcTensor.permute({2, 0, 1});
+   } else if (numDimensions == 4) {
+     TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
+     return hwcTensor.permute({0, 3, 1, 2});
+   } else {
+     TORCH_CHECK(
+         false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
+   }
+ }
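
The view guarantee mentioned in the comment above is easy to verify: permute() only rewrites strides, so input and output share the same storage. A minimal sketch:

#include <torch/torch.h>

void permuteIsAView() {
  torch::Tensor hwc = torch::rand({480, 640, 3});  // HWC
  torch::Tensor chw = hwc.permute({2, 0, 1});      // CHW view of the same data
  // Same underlying buffer: no copy has happened yet.
  TORCH_CHECK(chw.data_ptr() == hwc.data_ptr());
  // A copy only occurs if a contiguous layout is explicitly requested later.
  torch::Tensor packed = chw.contiguous();
  TORCH_CHECK(packed.data_ptr() != hwc.data_ptr());
}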
1358
+
+ // --------------------------------------------------------------------------
+ // PTS <-> INDEX CONVERSIONS
+ // --------------------------------------------------------------------------
+
+ int SingleStreamDecoder::getKeyFrameIdentifier(int64_t pts) const {
+   // This function "identifies" a key frame for a given pts value.
+   // We use the term "identifier" rather than "index" because the nature of
+   // the value that is returned depends on various factors:
+   // - If seek_mode is exact, we return the index of the key frame in the
+   // scanned key-frame vector (streamInfo.keyFrames). So the returned value
+   // is in [0, num_key_frames).
+   // - If seek_mode is approximate, we use av_index_search_timestamp() which
+   // may return a value in [0, num_key_frames) like for mkv, but also a value
+   // in [0, num_frames) like for mp4. It really depends on the container.
+   //
+   // The range of the "identifier" doesn't matter that much; for now we only
+   // use it to uniquely identify a key frame in canWeAvoidSeeking().
+   const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
+   if (streamInfo.keyFrames.empty()) {
+     return av_index_search_timestamp(
+         streamInfo.stream, pts, AVSEEK_FLAG_BACKWARD);
+   } else {
+     return getKeyFrameIndexForPtsUsingScannedIndex(streamInfo.keyFrames, pts);
+   }
+ }
+
+ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
+     const std::vector<SingleStreamDecoder::FrameInfo>& keyFrames,
+     int64_t pts) const {
+   auto upperBound = std::upper_bound(
+       keyFrames.begin(),
+       keyFrames.end(),
+       pts,
+       [](int64_t pts, const SingleStreamDecoder::FrameInfo& frameInfo) {
+         return pts < frameInfo.pts;
+       });
+   if (upperBound == keyFrames.begin()) {
+     return -1;
+   }
+   return upperBound - 1 - keyFrames.begin();
+ }
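
In other words, getKeyFrameIndexForPtsUsingScannedIndex returns the index of the last key frame whose pts is <= the requested pts, or -1 when the request precedes the first key frame. The same upper_bound idiom over a plain vector of pts values, as a standalone sketch:

#include <algorithm>
#include <cstdint>
#include <vector>

int lastIndexAtOrBefore(const std::vector<int64_t>& keyFramePts, int64_t pts) {
  // First element strictly greater than pts...
  auto it = std::upper_bound(keyFramePts.begin(), keyFramePts.end(), pts);
  if (it == keyFramePts.begin()) {
    return -1;  // pts is earlier than every key frame.
  }
  // ...so the element just before it is the last one <= pts.
  return static_cast<int>(it - 1 - keyFramePts.begin());
}

// E.g. with keyFramePts = {0, 250, 500}:
//   pts = 300 -> 1, pts = 500 -> 2, pts = -10 -> -1.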
1400
+
+ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
+   auto& streamInfo = streamInfos_[activeStreamIndex_];
+   switch (seekMode_) {
+     case SeekMode::custom_frame_mappings:
+     case SeekMode::exact: {
+       auto frame = std::lower_bound(
+           streamInfo.allFrames.begin(),
+           streamInfo.allFrames.end(),
+           seconds,
+           [&streamInfo](const FrameInfo& info, double start) {
+             return ptsToSeconds(info.nextPts, streamInfo.timeBase) <= start;
+           });
+
+       return frame - streamInfo.allFrames.begin();
+     }
+     case SeekMode::approximate: {
+       auto& streamMetadata =
+           containerMetadata_.allStreamMetadata[activeStreamIndex_];
+       TORCH_CHECK(
+           streamMetadata.averageFpsFromHeader.has_value(),
+           "Cannot use approximate mode since we couldn't find the average fps from the metadata.");
+       return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
+     }
+     default:
+       TORCH_CHECK(false, "Unknown SeekMode");
+   }
+ }
+
+ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
+   auto& streamInfo = streamInfos_[activeStreamIndex_];
+   switch (seekMode_) {
+     case SeekMode::custom_frame_mappings:
+     case SeekMode::exact: {
+       auto frame = std::upper_bound(
+           streamInfo.allFrames.begin(),
+           streamInfo.allFrames.end(),
+           seconds,
+           [&streamInfo](double stop, const FrameInfo& info) {
+             return stop <= ptsToSeconds(info.pts, streamInfo.timeBase);
+           });
+
+       return frame - streamInfo.allFrames.begin();
+     }
+     case SeekMode::approximate: {
+       auto& streamMetadata =
+           containerMetadata_.allStreamMetadata[activeStreamIndex_];
+       TORCH_CHECK(
+           streamMetadata.averageFpsFromHeader.has_value(),
+           "Cannot use approximate mode since we couldn't find the average fps from the metadata.");
+       return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
+     }
+     default:
+       TORCH_CHECK(false, "Unknown SeekMode");
+   }
+ }
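
In approximate mode, the two functions above bracket a time range by frame index using only the header fps: floor() gives the inclusive start index, ceil() the exclusive stop index. A small worked example, assuming a hypothetical 25 fps stream:

#include <cmath>
#include <cstdint>

void approximateIndexBounds() {
  const double fps = 25.0;
  // Frames covering [1.02s, 1.98s):
  int64_t start = static_cast<int64_t>(std::floor(1.02 * fps));  // 25.5 -> 25
  int64_t stop = static_cast<int64_t>(std::ceil(1.98 * fps));    // 49.5 -> 50
  // Indices 25 (inclusive) through 50 (exclusive) are selected.
  (void)start;
  (void)stop;
}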
1456
+
+ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
+   auto& streamInfo = streamInfos_[activeStreamIndex_];
+   switch (seekMode_) {
+     case SeekMode::custom_frame_mappings:
+     case SeekMode::exact:
+       return streamInfo.allFrames[frameIndex].pts;
+     case SeekMode::approximate: {
+       auto& streamMetadata =
+           containerMetadata_.allStreamMetadata[activeStreamIndex_];
+       TORCH_CHECK(
+           streamMetadata.averageFpsFromHeader.has_value(),
+           "Cannot use approximate mode since we couldn't find the average fps from the metadata.");
+       return secondsToClosestPts(
+           frameIndex / streamMetadata.averageFpsFromHeader.value(),
+           streamInfo.timeBase);
+     }
+     default:
+       TORCH_CHECK(false, "Unknown SeekMode");
+   }
+ }
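
The approximate branch above maps a frame index back to a pts via seconds = frameIndex / fps, then rounds to the nearest tick of the stream's time_base. A minimal sketch of that rounding step; secondsToClosestPts itself is defined elsewhere in this package:

extern "C" {
#include <libavutil/rational.h>
}
#include <cmath>
#include <cstdint>

int64_t secondsToClosestPtsSketch(double seconds, AVRational timeBase) {
  // E.g. frameIndex = 30 at 29.97 fps gives ~1.001 s; with
  // time_base = 1/30000 the closest tick is pts = 30030.
  return static_cast<int64_t>(std::round(seconds / av_q2d(timeBase)));
}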
1477
+
+ // --------------------------------------------------------------------------
+ // STREAM AND METADATA APIS
+ // --------------------------------------------------------------------------
+
+ // --------------------------------------------------------------------------
+ // VALIDATION UTILS
+ // --------------------------------------------------------------------------
+
+ void SingleStreamDecoder::validateActiveStream(
+     std::optional<AVMediaType> avMediaType) {
+   auto errorMsg =
+       "Provided stream index=" + std::to_string(activeStreamIndex_) +
+       " was not previously added.";
+   TORCH_CHECK(activeStreamIndex_ != NO_ACTIVE_STREAM, errorMsg);
+   TORCH_CHECK(streamInfos_.count(activeStreamIndex_) > 0, errorMsg);
+
+   int allStreamMetadataSize =
+       static_cast<int>(containerMetadata_.allStreamMetadata.size());
+   TORCH_CHECK(
+       activeStreamIndex_ >= 0 && activeStreamIndex_ < allStreamMetadataSize,
+       "Invalid stream index=" + std::to_string(activeStreamIndex_) +
+           "; valid indices are in the range [0, " +
+           std::to_string(allStreamMetadataSize) + ").");
+
+   if (avMediaType.has_value()) {
+     TORCH_CHECK(
+         streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value(),
+         "The method you called isn't supported. ",
+         "If you're seeing this error, you are probably trying to call an ",
+         "unsupported method on an audio stream.");
+   }
+ }
+
+ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
+   TORCH_CHECK(
+       scannedAllStreams_,
+       "Must scan all streams to update metadata before calling ",
+       msg);
+ }
+
+ void SingleStreamDecoder::validateFrameIndex(
+     const StreamMetadata& streamMetadata,
+     int64_t frameIndex) {
+   if (frameIndex < 0) {
+     throw std::out_of_range(
+         "Invalid frame index=" + std::to_string(frameIndex) +
+         " for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
+         "; negative indices must have an absolute value less than the number of frames, "
+         "and the number of frames must be known.");
+   }
+
+   // Note that if we do not have the number of frames available in our
+   // metadata, then we assume that the frameIndex is valid.
+   std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
+   if (numFrames.has_value()) {
+     if (frameIndex >= numFrames.value()) {
+       throw std::out_of_range(
+           "Invalid frame index=" + std::to_string(frameIndex) +
+           " for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
+           "; must be less than " + std::to_string(numFrames.value()));
+     }
+   }
+ }
+
+ // --------------------------------------------------------------------------
+ // MORALLY PRIVATE UTILS
+ // --------------------------------------------------------------------------
+
+ SingleStreamDecoder::DecodeStats SingleStreamDecoder::getDecodeStats() const {
+   return decodeStats_;
+ }
+
+ std::ostream& operator<<(
+     std::ostream& os,
+     const SingleStreamDecoder::DecodeStats& stats) {
+   os << "DecodeStats{"
+      << "numFramesReceivedByDecoder=" << stats.numFramesReceivedByDecoder
+      << ", numPacketsRead=" << stats.numPacketsRead
+      << ", numPacketsSentToDecoder=" << stats.numPacketsSentToDecoder
+      << ", numSeeksAttempted=" << stats.numSeeksAttempted
+      << ", numSeeksSkipped=" << stats.numSeeksSkipped
+      << ", numFlushes=" << stats.numFlushes << "}";
+
+   return os;
+ }
+
+ void SingleStreamDecoder::resetDecodeStats() {
+   decodeStats_ = DecodeStats{};
+ }
+
+ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
+   validateActiveStream(AVMEDIA_TYPE_VIDEO);
+   validateScannedAllStreams("getPtsSecondsForFrame");
+
+   const auto& streamInfo = streamInfos_[activeStreamIndex_];
+   const auto& streamMetadata =
+       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+   validateFrameIndex(streamMetadata, frameIndex);
+
+   return ptsToSeconds(
+       streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);
+ }
+
+ std::string SingleStreamDecoder::getDeviceInterfaceDetails() const {
+   TORCH_CHECK(deviceInterface_ != nullptr, "Device interface doesn't exist.");
+   return deviceInterface_->getDetails();
+ }
+
+ } // namespace facebook::torchcodec