torchcodec-0.7.0-cp313-cp313-win_amd64.whl → torchcodec-0.8.1-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchcodec might be problematic.
Files changed (66)
  1. torchcodec/_core/AVIOTensorContext.cpp +23 -16
  2. torchcodec/_core/AVIOTensorContext.h +2 -1
  3. torchcodec/_core/BetaCudaDeviceInterface.cpp +718 -0
  4. torchcodec/_core/BetaCudaDeviceInterface.h +193 -0
  5. torchcodec/_core/CMakeLists.txt +18 -3
  6. torchcodec/_core/CUDACommon.cpp +330 -0
  7. torchcodec/_core/CUDACommon.h +51 -0
  8. torchcodec/_core/Cache.h +6 -20
  9. torchcodec/_core/CpuDeviceInterface.cpp +195 -108
  10. torchcodec/_core/CpuDeviceInterface.h +84 -19
  11. torchcodec/_core/CudaDeviceInterface.cpp +227 -376
  12. torchcodec/_core/CudaDeviceInterface.h +38 -6
  13. torchcodec/_core/DeviceInterface.cpp +57 -19
  14. torchcodec/_core/DeviceInterface.h +97 -16
  15. torchcodec/_core/Encoder.cpp +346 -9
  16. torchcodec/_core/Encoder.h +62 -1
  17. torchcodec/_core/FFMPEGCommon.cpp +190 -3
  18. torchcodec/_core/FFMPEGCommon.h +27 -1
  19. torchcodec/_core/FilterGraph.cpp +30 -22
  20. torchcodec/_core/FilterGraph.h +15 -1
  21. torchcodec/_core/Frame.cpp +22 -7
  22. torchcodec/_core/Frame.h +15 -61
  23. torchcodec/_core/Metadata.h +2 -2
  24. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  25. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  26. torchcodec/_core/NVDECCache.cpp +60 -0
  27. torchcodec/_core/NVDECCache.h +102 -0
  28. torchcodec/_core/SingleStreamDecoder.cpp +196 -201
  29. torchcodec/_core/SingleStreamDecoder.h +42 -15
  30. torchcodec/_core/StreamOptions.h +16 -6
  31. torchcodec/_core/Transform.cpp +87 -0
  32. torchcodec/_core/Transform.h +84 -0
  33. torchcodec/_core/__init__.py +4 -0
  34. torchcodec/_core/custom_ops.cpp +257 -32
  35. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
  36. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  37. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  38. torchcodec/_core/ops.py +147 -44
  39. torchcodec/_core/pybind_ops.cpp +22 -59
  40. torchcodec/_samplers/video_clip_sampler.py +7 -19
  41. torchcodec/decoders/__init__.py +1 -0
  42. torchcodec/decoders/_decoder_utils.py +61 -1
  43. torchcodec/decoders/_video_decoder.py +46 -20
  44. torchcodec/libtorchcodec_core4.dll +0 -0
  45. torchcodec/libtorchcodec_core5.dll +0 -0
  46. torchcodec/libtorchcodec_core6.dll +0 -0
  47. torchcodec/libtorchcodec_core7.dll +0 -0
  48. torchcodec/libtorchcodec_core8.dll +0 -0
  49. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  50. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  51. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  52. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  53. torchcodec/libtorchcodec_custom_ops8.dll +0 -0
  54. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  55. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  56. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  57. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  58. torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
  59. torchcodec/samplers/_time_based.py +8 -0
  60. torchcodec/version.py +1 -1
  61. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/METADATA +29 -16
  62. torchcodec-0.8.1.dist-info/RECORD +82 -0
  63. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/WHEEL +1 -1
  64. torchcodec-0.7.0.dist-info/RECORD +0 -67
  65. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/licenses/LICENSE +0 -0
  66. {torchcodec-0.7.0.dist-info → torchcodec-0.8.1.dist-info}/top_level.txt +0 -0
@@ -12,21 +12,12 @@
  #include <sstream>
  #include <stdexcept>
  #include <string_view>
+ #include "Metadata.h"
  #include "torch/types.h"

  namespace facebook::torchcodec {
  namespace {

- double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
-   // To perform the multiplication before the division, av_q2d is not used
-   return static_cast<double>(pts) * timeBase.num / timeBase.den;
- }
-
- int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
-   return static_cast<int64_t>(
-       std::round(seconds * timeBase.den / timeBase.num));
- }
-
  // Some videos aren't properly encoded and do not specify pts values for
  // packets, and thus for frames. Unset values correspond to INT64_MIN. When that
  // happens, we fallback to the dts value which hopefully exists and is correct.
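The two helpers deleted above are not gone from the codebase: the new #include "Metadata.h" suggests they now live alongside the metadata utilities. For readers unfamiliar with FFmpeg time bases, here is a self-contained sketch of the same arithmetic (Rational is a stand-in for FFmpeg's AVRational):

    #include <cmath>
    #include <cstdint>

    // Stand-in for FFmpeg's AVRational time base, e.g. {1, 90000} for a
    // 90 kHz clock.
    struct Rational {
      int num;
      int den;
    };

    // pts -> seconds. Multiplying before dividing (rather than using av_q2d)
    // preserves precision for large pts values.
    double ptsToSeconds(int64_t pts, const Rational& timeBase) {
      return static_cast<double>(pts) * timeBase.num / timeBase.den;
    }

    // seconds -> nearest pts tick on the same clock.
    int64_t secondsToClosestPts(double seconds, const Rational& timeBase) {
      return static_cast<int64_t>(
          std::round(seconds * timeBase.den / timeBase.num));
    }

For example, with a {1, 90000} time base, pts 90000 converts to exactly 1.0 second, and 0.5 seconds converts back to pts 45000.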
@@ -322,19 +313,35 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
  void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
      int streamIndex,
      FrameMappings customFrameMappings) {
-   auto& all_frames = customFrameMappings.all_frames;
-   auto& is_key_frame = customFrameMappings.is_key_frame;
-   auto& duration = customFrameMappings.duration;
+   TORCH_CHECK(
+       customFrameMappings.all_frames.dtype() == torch::kLong &&
+           customFrameMappings.is_key_frame.dtype() == torch::kBool &&
+           customFrameMappings.duration.dtype() == torch::kLong,
+       "all_frames and duration tensors must be int64 dtype, and is_key_frame tensor must be a bool dtype.");
+   const torch::Tensor& all_frames =
+       customFrameMappings.all_frames.to(torch::kLong);
+   const torch::Tensor& is_key_frame =
+       customFrameMappings.is_key_frame.to(torch::kBool);
+   const torch::Tensor& duration = customFrameMappings.duration.to(torch::kLong);
    TORCH_CHECK(
        all_frames.size(0) == is_key_frame.size(0) &&
            is_key_frame.size(0) == duration.size(0),
        "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.");

+   // Allocate vectors using num frames to reduce reallocations
+   int64_t numFrames = all_frames.size(0);
+   streamInfos_[streamIndex].allFrames.reserve(numFrames);
+   streamInfos_[streamIndex].keyFrames.reserve(numFrames);
+   // Use accessor to efficiently access tensor elements
+   auto pts_data = all_frames.accessor<int64_t, 1>();
+   auto is_key_frame_data = is_key_frame.accessor<bool, 1>();
+   auto duration_data = duration.accessor<int64_t, 1>();
+
    auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];

-   streamMetadata.beginStreamPtsFromContent = all_frames[0].item<int64_t>();
+   streamMetadata.beginStreamPtsFromContent = pts_data[0];
    streamMetadata.endStreamPtsFromContent =
-       all_frames[-1].item<int64_t>() + duration[-1].item<int64_t>();
+       pts_data[numFrames - 1] + duration_data[numFrames - 1];

    auto avStream = formatContext_->streams[streamIndex];
    streamMetadata.beginStreamPtsSecondsFromContent = ptsToSeconds(
@@ -343,17 +350,16 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
    streamMetadata.endStreamPtsSecondsFromContent = ptsToSeconds(
        *streamMetadata.endStreamPtsFromContent, avStream->time_base);

-   streamMetadata.numFramesFromContent = all_frames.size(0);
-   for (int64_t i = 0; i < all_frames.size(0); ++i) {
+   streamMetadata.numFramesFromContent = numFrames;
+   for (int64_t i = 0; i < numFrames; ++i) {
      FrameInfo frameInfo;
-     frameInfo.pts = all_frames[i].item<int64_t>();
-     frameInfo.isKeyFrame = is_key_frame[i].item<bool>();
+     frameInfo.pts = pts_data[i];
+     frameInfo.isKeyFrame = is_key_frame_data[i];
      streamInfos_[streamIndex].allFrames.push_back(frameInfo);
      if (frameInfo.isKeyFrame) {
        streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
      }
    }
-   // Sort all frames by their pts
    sortAllFrames();
  }

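The switch above from per-element item<T>() calls to torch::TensorAccessor avoids a full dispatch round-trip per element. A minimal sketch of the accessor pattern, independent of the decoder:

    #include <torch/torch.h>

    // accessor<T, N> yields raw-pointer-speed element access to a CPU tensor
    // of dtype T and rank N; item<T>() by contrast dispatches per element.
    int64_t sumValues(const torch::Tensor& values) {
      TORCH_CHECK(values.dim() == 1 && values.dtype() == torch::kLong);
      auto data = values.accessor<int64_t, 1>();
      int64_t total = 0;
      for (int64_t i = 0; i < values.size(0); ++i) {
        total += data[i]; // no per-element dispatch here
      }
      return total;
    }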
@@ -384,6 +390,7 @@ void SingleStreamDecoder::addStream(
      int streamIndex,
      AVMediaType mediaType,
      const torch::Device& device,
+     const std::string_view deviceVariant,
      std::optional<int> ffmpegThreadCount) {
    TORCH_CHECK(
        activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -412,8 +419,6 @@ void SingleStreamDecoder::addStream(
    streamInfo.stream = formatContext_->streams[activeStreamIndex_];
    streamInfo.avMediaType = mediaType;

-   deviceInterface_ = createDeviceInterface(device);
-
    // This should never happen, checking just to be safe.
    TORCH_CHECK(
        streamInfo.stream->codecpar->codec_type == mediaType,
@@ -421,19 +426,22 @@ void SingleStreamDecoder::addStream(
        activeStreamIndex_,
        " which is of the wrong media type.");

+   deviceInterface_ = createDeviceInterface(device, deviceVariant);
+   TORCH_CHECK(
+       deviceInterface_ != nullptr,
+       "Failed to create device interface. This should never happen, please report.");
+
    // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
    // addStream() which is supposed to be generic
    if (mediaType == AVMEDIA_TYPE_VIDEO) {
-     if (deviceInterface_) {
-       avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
-           deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
-               .value_or(avCodec));
-     }
+     avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
+         deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
+             .value_or(avCodec));
    }

    AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
    TORCH_CHECK(codecContext != nullptr);
-   streamInfo.codecContext.reset(codecContext);
+   streamInfo.codecContext = makeSharedAVCodecContext(codecContext);

    int retVal = avcodec_parameters_to_context(
        streamInfo.codecContext.get(), streamInfo.stream->codecpar);
@@ -442,19 +450,22 @@ void SingleStreamDecoder::addStream(
    streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
    streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;

-   // TODO_CODE_QUALITY same as above.
-   if (mediaType == AVMEDIA_TYPE_VIDEO) {
-     if (deviceInterface_) {
-       deviceInterface_->initializeContext(codecContext);
-     }
-   }
-
+   // Note that we must make sure to register the hardware device context
+   // with the codec context before calling avcodec_open2(). Otherwise, decoding
+   // will happen on the CPU and not the hardware device.
+   deviceInterface_->registerHardwareDeviceWithCodec(
+       streamInfo.codecContext.get());
    retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
    TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));

-   codecContext->time_base = streamInfo.stream->time_base;
+   streamInfo.codecContext->time_base = streamInfo.stream->time_base;
+
+   // Initialize the device interface with the codec context
+   deviceInterface_->initialize(
+       streamInfo.stream, formatContext_, streamInfo.codecContext);
+
    containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
-       std::string(avcodec_get_name(codecContext->codec_id));
+       std::string(avcodec_get_name(streamInfo.codecContext->codec_id));

    // We will only need packets from the active stream, so we tell FFmpeg to
    // discard packets from the other streams. Note that av_read_frame() may still
@@ -469,12 +480,18 @@ void SingleStreamDecoder::addStream(

  void SingleStreamDecoder::addVideoStream(
      int streamIndex,
+     std::vector<Transform*>& transforms,
      const VideoStreamOptions& videoStreamOptions,
      std::optional<FrameMappings> customFrameMappings) {
+   TORCH_CHECK(
+       transforms.empty() || videoStreamOptions.device == torch::kCPU,
+       " Transforms are only supported for CPU devices.");
+
    addStream(
        streamIndex,
        AVMEDIA_TYPE_VIDEO,
        videoStreamOptions.device,
+       videoStreamOptions.deviceVariant,
        videoStreamOptions.ffmpegThreadCount);

    auto& streamMetadata =
@@ -501,8 +518,25 @@ void SingleStreamDecoder::addVideoStream(
        customFrameMappings.has_value(),
        "Missing frame mappings when custom_frame_mappings seek mode is set.");
    readCustomFrameMappingsUpdateMetadataAndIndex(
-       streamIndex, customFrameMappings.value());
+       activeStreamIndex_, customFrameMappings.value());
    }
+
+   metadataDims_ =
+       FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
+   for (auto& transform : transforms) {
+     TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
+     if (transform->getOutputFrameDims().has_value()) {
+       resizedOutputDims_ = transform->getOutputFrameDims().value();
+     }
+     transform->validate(streamMetadata);
+
+     // Note that we are claiming ownership of the transform objects passed in to
+     // us.
+     transforms_.push_back(std::unique_ptr<Transform>(transform));
+   }
+
+   deviceInterface_->initializeVideo(
+       videoStreamOptions, transforms_, resizedOutputDims_);
  }

  void SingleStreamDecoder::addAudioStream(
@@ -587,11 +621,18 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
  }

  FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
-     const std::vector<int64_t>& frameIndices) {
+     const torch::Tensor& frameIndices) {
    validateActiveStream(AVMEDIA_TYPE_VIDEO);

-   auto indicesAreSorted =
-       std::is_sorted(frameIndices.begin(), frameIndices.end());
+   auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
+
+   bool indicesAreSorted = true;
+   for (int64_t i = 1; i < frameIndices.numel(); ++i) {
+     if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) {
+       indicesAreSorted = false;
+       break;
+     }
+   }

    std::vector<size_t> argsort;
    if (!indicesAreSorted) {
@@ -599,27 +640,29 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
      // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
      // to use to decode the frames
      // and argsort is [ 1, 3, 2, 0]
-     argsort.resize(frameIndices.size());
+     argsort.resize(frameIndices.numel());
      for (size_t i = 0; i < argsort.size(); ++i) {
        argsort[i] = i;
      }
      std::sort(
-         argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
-           return frameIndices[a] < frameIndices[b];
+         argsort.begin(),
+         argsort.end(),
+         [&frameIndicesAccessor](size_t a, size_t b) {
+           return frameIndicesAccessor[a] < frameIndicesAccessor[b];
          });
    }

-   const auto& streamMetadata =
-       containerMetadata_.allStreamMetadata[activeStreamIndex_];
    const auto& streamInfo = streamInfos_[activeStreamIndex_];
    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
    FrameBatchOutput frameBatchOutput(
-       frameIndices.size(), videoStreamOptions, streamMetadata);
+       frameIndices.numel(),
+       resizedOutputDims_.value_or(metadataDims_),
+       videoStreamOptions.device);

    auto previousIndexInVideo = -1;
-   for (size_t f = 0; f < frameIndices.size(); ++f) {
+   for (int64_t f = 0; f < frameIndices.numel(); ++f) {
      auto indexInOutput = indicesAreSorted ? f : argsort[f];
-     auto indexInVideo = frameIndices[indexInOutput];
+     auto indexInVideo = frameIndicesAccessor[indexInOutput];

      if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
        // Avoid decoding the same frame twice
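The argsort indirection above lets the decoder visit frames in ascending order (cheap, sequential access) while still writing each result to the slot matching the caller's original request order. A self-contained sketch of the technique:

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Returns the permutation that sorts `indices` ascending. Decoding in
    // that order is efficient; argsort[f] says which requested slot the
    // f-th decoded frame belongs to.
    std::vector<size_t> argsortIndices(const std::vector<int64_t>& indices) {
      std::vector<size_t> argsort(indices.size());
      std::iota(argsort.begin(), argsort.end(), 0);
      std::sort(argsort.begin(), argsort.end(), [&](size_t a, size_t b) {
        return indices[a] < indices[b];
      });
      return argsort;
    }

For the input {13, 10, 12, 11} this returns {1, 3, 2, 0}, matching the worked example in the comment above.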
@@ -657,8 +700,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
    TORCH_CHECK(
        step > 0, "Step must be greater than 0; is " + std::to_string(step));

-   // Note that if we do not have the number of frames available in our metadata,
-   // then we assume that the upper part of the range is valid.
+   // Note that if we do not have the number of frames available in our
+   // metadata, then we assume that the upper part of the range is valid.
    std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
    if (numFrames.has_value()) {
      TORCH_CHECK(
@@ -671,7 +714,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
    int64_t numOutputFrames = std::ceil((stop - start) / double(step));
    const auto& videoStreamOptions = streamInfo.videoStreamOptions;
    FrameBatchOutput frameBatchOutput(
-       numOutputFrames, videoStreamOptions, streamMetadata);
+       numOutputFrames,
+       resizedOutputDims_.value_or(metadataDims_),
+       videoStreamOptions.device);

    for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
      FrameOutput frameOutput =
@@ -687,9 +732,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
    validateActiveStream(AVMEDIA_TYPE_VIDEO);
    StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
    double lastDecodedStartTime =
-       ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
+       ptsToSeconds(lastDecodedAvFramePts_, streamInfo.timeBase);
    double lastDecodedEndTime = ptsToSeconds(
-       streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration,
+       lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_,
        streamInfo.timeBase);
    if (seconds >= lastDecodedStartTime && seconds < lastDecodedEndTime) {
      // We are in the same frame as the one we just returned. However, since we
@@ -709,9 +754,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
      // FFMPEG seeked past the frame we are looking for even though we
      // set max_ts to be our needed timestamp in avformat_seek_file()
      // in maybeSeekToBeforeDesiredPts().
-     // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
-     // In this case we return the very next frame instead of throwing an
-     // exception.
+     // This could be a bug in FFMPEG:
+     // https://trac.ffmpeg.org/ticket/11137 In this case we return the
+     // very next frame instead of throwing an exception.
      // TODO: Maybe log to stderr for Debug builds?
      return true;
    }
@@ -725,7 +770,7 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
  }

  FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
-     const std::vector<double>& timestamps) {
+     const torch::Tensor& timestamps) {
    validateActiveStream(AVMEDIA_TYPE_VIDEO);

    const auto& streamMetadata =
  const auto& streamMetadata =
@@ -739,9 +784,13 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
    // avoid decoding that unique frame twice is to convert the input timestamps
    // to indices, and leverage the de-duplication logic of getFramesAtIndices.

-   std::vector<int64_t> frameIndices(timestamps.size());
-   for (size_t i = 0; i < timestamps.size(); ++i) {
-     auto frameSeconds = timestamps[i];
+   torch::Tensor frameIndices =
+       torch::empty({timestamps.numel()}, torch::kInt64);
+   auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
+   auto timestampsAccessor = timestamps.accessor<double, 1>();
+
+   for (int64_t i = 0; i < timestamps.numel(); ++i) {
+     auto frameSeconds = timestampsAccessor[i];
      TORCH_CHECK(
          frameSeconds >= minSeconds,
          "frame pts is " + std::to_string(frameSeconds) +
@@ -758,7 +807,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
            ".");
      }

-     frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
+     frameIndicesAccessor[i] = secondsToIndexLowerBound(frameSeconds);
    }

    return getFramesAtIndices(frameIndices);
@@ -791,13 +840,16 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
    // interval B: [0.2, 0.15)
    //
    // Both intervals take place between the pts values for frame 0 and frame 1,
-   // which by our abstract player, means that both intervals map to frame 0. By
-   // the definition of a half open interval, interval A should return no frames.
-   // Interval B should return frame 0. However, for both A and B, the individual
-   // values of the intervals will map to the same frame indices below. Hence, we
-   // need this special case below.
+   // which by our abstract player, means that both intervals map to frame 0.
+   // By the definition of a half open interval, interval A should return no
+   // frames. Interval B should return frame 0. However, for both A and B, the
+   // individual values of the intervals will map to the same frame indices
+   // below. Hence, we need this special case below.
    if (startSeconds == stopSeconds) {
-     FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata);
+     FrameBatchOutput frameBatchOutput(
+         0,
+         resizedOutputDims_.value_or(metadataDims_),
+         videoStreamOptions.device);
      frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
      return frameBatchOutput;
    }
@@ -809,8 +861,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
        "; must be greater than or equal to " + std::to_string(minSeconds) +
            ".");

-   // Note that if we can't determine the maximum seconds from the metadata, then
-   // we assume upper range is valid.
+   // Note that if we can't determine the maximum seconds from the metadata,
+   // then we assume upper range is valid.
    std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
    if (maxSeconds.has_value()) {
      TORCH_CHECK(
@@ -842,7 +894,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
    int64_t numFrames = stopFrameIndex - startFrameIndex;

    FrameBatchOutput frameBatchOutput(
-       numFrames, videoStreamOptions, streamMetadata);
+       numFrames,
+       resizedOutputDims_.value_or(metadataDims_),
+       videoStreamOptions.device);
    for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
      FrameOutput frameOutput =
          getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
@@ -863,25 +917,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
  // `numChannels` values. An audio frame, or a sequence thereof, is always
  // converted into a tensor of shape `(numChannels, numSamplesPerChannel)`.
  //
- // The notion of 'frame' in audio isn't what users want to interact with. Users
- // want to interact with samples. The C++ and core APIs return frames, because
- // we want those to be close to FFmpeg concepts, but the higher-level public
- // APIs expose samples. As a result:
+ // The notion of 'frame' in audio isn't what users want to interact with.
+ // Users want to interact with samples. The C++ and core APIs return frames,
+ // because we want those to be close to FFmpeg concepts, but the higher-level
+ // public APIs expose samples. As a result:
  // - We don't expose index-based APIs for audio, because that would mean
- //   exposing the concept of audio frame. For now, we think exposing time-based
- //   APIs is more natural.
- // - We never perform a scan for audio streams. We don't need to, since we won't
+ //   exposing the concept of audio frame. For now, we think exposing
+ //   time-based APIs is more natural.
+ // - We never perform a scan for audio streams. We don't need to, since we
+ //   won't
  //   be converting timestamps to indices. That's why we enforce the seek_mode
- //   to be "approximate" (which is slightly misleading, because technically the
- //   output samples will be at their exact positions. But this incongruence is
- //   only exposed at the C++/core private levels).
+ //   to be "approximate" (which is slightly misleading, because technically
+ //   the output samples will be at their exact positions. But this
+ //   incongruence is only exposed at the C++/core private levels).
  //
  // Audio frames are of variable dimensions: in the same stream, a frame can
  // contain 1024 samples and the next one may contain 512 [1]. This makes it
  // impossible to stack audio frames in the same way we can stack video frames.
- // This is one of the main reasons we cannot reuse the same pre-allocation logic
- // we have for videos in getFramesPlayedInRange(): pre-allocating a batch
- // requires constant (and known) frame dimensions. That's also why
+ // This is one of the main reasons we cannot reuse the same pre-allocation
+ // logic we have for videos in getFramesPlayedInRange(): pre-allocating a
+ // batch requires constant (and known) frame dimensions. That's also why
  // *concatenated* along the samples dimension, not stacked.
  //
  // [IMPORTANT!] There is one key invariant that we must respect when decoding
@@ -889,10 +944,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
  //
  // BEFORE DECODING FRAME i, WE MUST DECODE ALL FRAMES j < i.
  //
- // Always. Why? We don't know. What we know is that if we don't, we get clipped,
- // incorrect audio as output [2]. All other (correct) libraries like TorchAudio
- // or Decord do something similar, whether it was intended or not. This has a
- // few implications:
+ // Always. Why? We don't know. What we know is that if we don't, we get
+ // clipped, incorrect audio as output [2]. All other (correct) libraries like
+ // TorchAudio or Decord do something similar, whether it was intended or not.
+ // This has a few implications:
  // - The **only** place we're allowed to seek to in an audio stream is the
  //   stream's beginning. This ensures that if we need a frame, we'll have
  //   decoded all previous frames.
@@ -900,8 +955,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
  //   call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually
  //   seek.
  // - We try not to seek, when we can avoid it. Typically if the next frame we
- //   need is in the future, we don't seek back to the beginning, we just decode
- //   all the frames in-between.
+ //   need is in the future, we don't seek back to the beginning, we just
+ //   decode all the frames in-between.
  //
  // [2] If you're brave and curious, you can read the long "Seek offset for
  // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which
@@ -928,11 +983,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
    }

    auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
-   if (startPts < streamInfo.lastDecodedAvFramePts +
-           streamInfo.lastDecodedAvFrameDuration) {
-     // If we need to seek backwards, then we have to seek back to the beginning
-     // of the stream.
-     // See [Audio Decoding Design].
+   if (startPts < lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_) {
+     // If we need to seek backwards, then we have to seek back to the
+     // beginning of the stream. See [Audio Decoding Design].
      setCursor(INT64_MIN);
    }

@@ -966,9 +1019,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
      // stop decoding more frames. Note that if we were to use [begin, end),
      // which may seem more natural, then we would decode the frame starting at
      // stopSeconds, which isn't what we want!
-     auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
-         streamInfo.lastDecodedAvFrameDuration;
-     finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
+     auto lastDecodedAvFrameEnd =
+         lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_;
+     finished |= (lastDecodedAvFramePts_) <= stopPts &&
          (stopPts <= lastDecodedAvFrameEnd);
    }

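Since audio frames carry a variable number of samples, the per-frame (numChannels, numSamples) tensors are concatenated along the samples dimension rather than stacked, as the [Audio Decoding Design] note above explains. A minimal sketch:

    #include <torch/torch.h>
    #include <vector>

    // Frames of shape (C, S_i) with varying S_i cannot be stacked into a
    // (N, C, S) batch, but they concatenate cleanly along dim 1.
    torch::Tensor concatAudioFrames(const std::vector<torch::Tensor>& frames) {
      return torch::cat(frames, /*dim=*/1);
    }

For example, frames of shape (2, 1024) and (2, 512) yield a single (2, 1536) tensor.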
@@ -1035,18 +1088,16 @@ I P P P I P P P I P P I P P I P
  bool SingleStreamDecoder::canWeAvoidSeeking() const {
    const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
    if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
-     // For audio, we only need to seek if a backwards seek was requested within
-     // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
-     // For more context, see [Audio Decoding Design]
+     // For audio, we only need to seek if a backwards seek was requested
+     // within getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was
+     // called. For more context, see [Audio Decoding Design]
      return !cursorWasJustSet_;
    }
-   int64_t lastDecodedAvFramePts =
-       streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
-   if (cursor_ < lastDecodedAvFramePts) {
+   if (cursor_ < lastDecodedAvFramePts_) {
      // We can never skip a seek if we are seeking backwards.
      return false;
    }
-   if (lastDecodedAvFramePts == cursor_) {
+   if (lastDecodedAvFramePts_ == cursor_) {
      // We are seeking to the exact same frame as we are currently at. Without
      // caching we have to rewind back and decode the frame again.
      // TODO: https://github.com/pytorch/torchcodec/issues/84 we could
@@ -1056,7 +1107,7 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const {
    // We are seeking forwards.
    // We can only skip a seek if both lastDecodedAvFramePts and
    // cursor_ share the same keyframe.
-   int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
+   int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_);
    int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_);
    return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
        lastDecodedAvFrameIndex == targetKeyFrameIndex;
@@ -1104,7 +1155,7 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
        getFFMPEGErrorStringFromErrorCode(status));

    decodeStats_.numFlushes++;
-   avcodec_flush_buffers(streamInfo.codecContext.get());
+   deviceInterface_->flush();
  }

  // --------------------------------------------------------------------------
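avcodec_flush_buffers() is now hidden behind deviceInterface_->flush(), letting hardware backends clear their own queues as well. A plausible default implementation, assuming the base interface keeps a handle to the shared codec context (the member name below is an assumption, not torchcodec's actual field):

    // Hypothetical sketch of a default DeviceInterface::flush(); the real
    // interface lives in DeviceInterface.h and its members may differ.
    void DeviceInterface::flush() {
      // Drop any packets/frames buffered inside the FFmpeg decoder. Hardware
      // subclasses can override this to also clear device-side queues.
      avcodec_flush_buffers(codecContext_.get());
    }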
@@ -1122,16 +1173,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
      cursorWasJustSet_ = false;
    }

-   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
-
-   // Need to get the next frame or error from PopFrame.
    UniqueAVFrame avFrame(av_frame_alloc());
    AutoAVPacket autoAVPacket;
    int status = AVSUCCESS;
    bool reachedEOF = false;
+
+   // The default implementation uses avcodec_receive_frame and
+   // avcodec_send_packet, while specialized interfaces can override for
+   // hardware-specific optimizations.
    while (true) {
-     status =
-         avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
+     status = deviceInterface_->receiveFrame(avFrame);

      if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
        // Non-retriable error
@@ -1154,7 +1205,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(

      if (reachedEOF) {
        // We don't have any more packets to receive. So keep on pulling frames
-       // from its internal buffers.
+       // from decoder's internal buffers.
        continue;
      }

@@ -1166,11 +1217,8 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
      decodeStats_.numPacketsRead++;

      if (status == AVERROR_EOF) {
-       // End of file reached. We must drain the codec by sending a nullptr
-       // packet.
-       status = avcodec_send_packet(
-           streamInfo.codecContext.get(),
-           /*avpkt=*/nullptr);
+       // End of file reached. We must drain the decoder
+       status = deviceInterface_->sendEOFPacket();
        TORCH_CHECK(
            status >= AVSUCCESS,
            "Could not flush decoder: ",
@@ -1195,7 +1243,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(

      // We got a valid packet. Send it to the decoder, and we'll receive it in
      // the next iteration.
-     status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
+     status = deviceInterface_->sendPacket(packet);
      TORCH_CHECK(
          status >= AVSUCCESS,
          "Could not push packet to decoder: ",
@@ -1216,14 +1264,15 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
          getFFMPEGErrorStringFromErrorCode(status));
    }

-   // Note that we don't flush the decoder when we reach EOF (even though that's
-   // mentioned in https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html).
-   // This is because we may have packets internally in the decoder that we
-   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
-   // av_receive_frame() or the user will have seeked to a different location in
-   // the file and that will flush the decoder.
-   streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
-   streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
+   // Note that we don't flush the decoder when we reach EOF (even though
+   // that's mentioned in
+   // https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). This is
+   // because we may have packets internally in the decoder that we haven't
+   // received as frames. Eventually we will either hit AVERROR_EOF from
+   // av_receive_frame() or the user will have seeked to a different location
+   // in the file and that will flush the decoder.
+   lastDecodedAvFramePts_ = getPtsOrDts(avFrame);
+   lastDecodedAvFrameDuration_ = getDuration(avFrame);

    return avFrame;
  }
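The loop above is FFmpeg's standard send/receive decoding protocol, now routed through the device interface so hardware backends can substitute their own packet handling. Stripped of TorchCodec specifics, the core protocol looks like this (readPacket is an assumed helper that returns the next packet, or nullptr at end of file):

    extern "C" {
    #include <libavcodec/avcodec.h>
    }

    // Minimal send/receive decode loop: keep receiving until the decoder
    // asks for input (EAGAIN), then feed it the next packet. A nullptr
    // packet starts draining, after which receiving eventually returns
    // AVERROR_EOF.
    int decodeNextFrame(AVCodecContext* ctx, AVFrame* frame,
                        AVPacket* (*readPacket)()) {
      while (true) {
        int status = avcodec_receive_frame(ctx, frame);
        if (status == 0) {
          return 0; // got a frame
        }
        if (status != AVERROR(EAGAIN)) {
          return status; // real error, or AVERROR_EOF once fully drained
        }
        AVPacket* pkt = readPacket(); // nullptr signals end of file
        status = avcodec_send_packet(ctx, pkt);
        if (status < 0 && status != AVERROR_EOF) {
          return status;
        }
      }
    }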
@@ -1246,13 +1295,9 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
        formatContext_->streams[activeStreamIndex_]->time_base);
    if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
      convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
-   } else if (deviceInterface_) {
+   } else {
      deviceInterface_->convertAVFrameToFrameOutput(
-         streamInfo.videoStreamOptions,
-         streamInfo.timeBase,
-         avFrame,
-         frameOutput,
-         preAllocatedOutputTensor);
+         avFrame, frameOutput, preAllocatedOutputTensor);
    }
    return frameOutput;
  }
@@ -1288,8 +1333,8 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(

    UniqueAVFrame convertedAVFrame;
    if (mustConvert) {
-     if (!streamInfo.swrContext) {
-       streamInfo.swrContext.reset(createSwrContext(
+     if (!swrContext_) {
+       swrContext_.reset(createSwrContext(
            srcSampleFormat,
            outSampleFormat,
            srcSampleRate,
@@ -1299,7 +1344,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
    }

    convertedAVFrame = convertAudioAVFrameSamples(
-       streamInfo.swrContext,
+       swrContext_,
        srcAVFrame,
        outSampleFormat,
        outSampleRate,
@@ -1347,15 +1392,15 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
  std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
    // When sample rate conversion is involved, swresample buffers some of the
    // samples in-between calls to swr_convert (see the libswresample docs).
-   // That's because the last few samples in a given frame require future samples
-   // from the next frame to be properly converted. This function flushes out the
-   // samples that are stored in swresample's buffers.
+   // That's because the last few samples in a given frame require future
+   // samples from the next frame to be properly converted. This function
+   // flushes out the samples that are stored in swresample's buffers.
    auto& streamInfo = streamInfos_[activeStreamIndex_];
-   if (!streamInfo.swrContext) {
+   if (!swrContext_) {
      return std::nullopt;
    }
    auto numRemainingSamples = // this is an upper bound
-       swr_get_out_samples(streamInfo.swrContext.get(), 0);
+       swr_get_out_samples(swrContext_.get(), 0);

    if (numRemainingSamples == 0) {
      return std::nullopt;
@@ -1372,11 +1417,7 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
    }

    auto actualNumRemainingSamples = swr_convert(
-       streamInfo.swrContext.get(),
-       outputBuffers.data(),
-       numRemainingSamples,
-       nullptr,
-       0);
+       swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0);

    return lastSamples.narrow(
        /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
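The flush idiom used here is libswresample's documented one: calling swr_convert() with a null input drains the resampler's internal buffer. A condensed standalone sketch:

    extern "C" {
    #include <libswresample/swresample.h>
    }

    // Drains swresample's internal buffer into `out`. `out` must have room
    // for swr_get_out_samples(swr, 0) samples, which is an upper bound.
    // Returns the number of samples actually written, or a negative error.
    int flushResampler(SwrContext* swr, uint8_t** out, int maxOutSamples) {
      return swr_convert(swr, out, maxOutSamples, /*in=*/nullptr, /*in_count=*/0);
    }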
@@ -1386,25 +1427,10 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
  // OUTPUT ALLOCATION AND SHAPE CONVERSION
  // --------------------------------------------------------------------------

- FrameBatchOutput::FrameBatchOutput(
-     int64_t numFrames,
-     const VideoStreamOptions& videoStreamOptions,
-     const StreamMetadata& streamMetadata)
-     : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-       durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
-   auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
-       videoStreamOptions, streamMetadata);
-   int height = frameDims.height;
-   int width = frameDims.width;
-   data = allocateEmptyHWCTensor(
-       height, width, videoStreamOptions.device, numFrames);
- }
-
- // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
- // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
- // or 4D.
- // Calling permute() is guaranteed to return a view as per the docs:
- // https://pytorch.org/docs/stable/generated/torch.permute.html
+ // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require
+ // so. The [N] leading batch-dimension is optional i.e. the input tensor can
+ // be 3D or 4D. Calling permute() is guaranteed to return a view as per the
+ // docs: https://pytorch.org/docs/stable/generated/torch.permute.html
  torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW(
      torch::Tensor& hwcTensor) {
    if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
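As the rewrapped comment notes, permute() returns a view, so the HWC-to-CHW conversion allocates nothing. A minimal sketch handling the optional batch dimension:

    #include <torch/torch.h>

    // HWC -> CHW (or NHWC -> NCHW) as a zero-copy view; call .contiguous()
    // afterwards only if a packed memory layout is required.
    torch::Tensor hwcToChw(const torch::Tensor& hwc) {
      TORCH_CHECK(hwc.dim() == 3 || hwc.dim() == 4, "expected HWC or NHWC");
      return hwc.dim() == 3 ? hwc.permute({2, 0, 1})
                            : hwc.permute({0, 3, 1, 2});
    }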
@@ -1624,8 +1650,8 @@ void SingleStreamDecoder::validateFrameIndex(
        "and the number of frames must be known.");
    }

-   // Note that if we do not have the number of frames available in our metadata,
-   // then we assume that the frameIndex is valid.
+   // Note that if we do not have the number of frames available in our
+   // metadata, then we assume that the frameIndex is valid.
    std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
    if (numFrames.has_value()) {
      if (frameIndex >= numFrames.value()) {
@@ -1676,40 +1702,9 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
        streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);
  }

- // --------------------------------------------------------------------------
- // FrameDims APIs
- // --------------------------------------------------------------------------
-
- FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
-   return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
- }
-
- FrameDims getHeightAndWidthFromOptionsOrMetadata(
-     const VideoStreamOptions& videoStreamOptions,
-     const StreamMetadata& streamMetadata) {
-   return FrameDims(
-       videoStreamOptions.height.value_or(*streamMetadata.height),
-       videoStreamOptions.width.value_or(*streamMetadata.width));
- }
-
- FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-     const VideoStreamOptions& videoStreamOptions,
-     const UniqueAVFrame& avFrame) {
-   return FrameDims(
-       videoStreamOptions.height.value_or(avFrame->height),
-       videoStreamOptions.width.value_or(avFrame->width));
- }
-
- SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
-   if (seekMode == "exact") {
-     return SingleStreamDecoder::SeekMode::exact;
-   } else if (seekMode == "approximate") {
-     return SingleStreamDecoder::SeekMode::approximate;
-   } else if (seekMode == "custom_frame_mappings") {
-     return SingleStreamDecoder::SeekMode::custom_frame_mappings;
-   } else {
-     TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
-   }
+ std::string SingleStreamDecoder::getDeviceInterfaceDetails() const {
+   TORCH_CHECK(deviceInterface_ != nullptr, "Device interface doesn't exist.");
+   return deviceInterface_->getDetails();
  }

  } // namespace facebook::torchcodec