torchcodec-0.7.0-cp310-cp310-win_amd64.whl → torchcodec-0.8.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (61)
  1. torchcodec/_core/BetaCudaDeviceInterface.cpp +636 -0
  2. torchcodec/_core/BetaCudaDeviceInterface.h +191 -0
  3. torchcodec/_core/CMakeLists.txt +36 -3
  4. torchcodec/_core/CUDACommon.cpp +315 -0
  5. torchcodec/_core/CUDACommon.h +46 -0
  6. torchcodec/_core/CpuDeviceInterface.cpp +189 -108
  7. torchcodec/_core/CpuDeviceInterface.h +81 -19
  8. torchcodec/_core/CudaDeviceInterface.cpp +211 -368
  9. torchcodec/_core/CudaDeviceInterface.h +33 -6
  10. torchcodec/_core/DeviceInterface.cpp +57 -19
  11. torchcodec/_core/DeviceInterface.h +97 -16
  12. torchcodec/_core/Encoder.cpp +302 -9
  13. torchcodec/_core/Encoder.h +51 -1
  14. torchcodec/_core/FFMPEGCommon.cpp +189 -2
  15. torchcodec/_core/FFMPEGCommon.h +18 -0
  16. torchcodec/_core/FilterGraph.cpp +28 -21
  17. torchcodec/_core/FilterGraph.h +15 -1
  18. torchcodec/_core/Frame.cpp +17 -7
  19. torchcodec/_core/Frame.h +15 -61
  20. torchcodec/_core/Metadata.h +2 -2
  21. torchcodec/_core/NVDECCache.cpp +70 -0
  22. torchcodec/_core/NVDECCache.h +104 -0
  23. torchcodec/_core/SingleStreamDecoder.cpp +202 -198
  24. torchcodec/_core/SingleStreamDecoder.h +39 -14
  25. torchcodec/_core/StreamOptions.h +16 -6
  26. torchcodec/_core/Transform.cpp +60 -0
  27. torchcodec/_core/Transform.h +59 -0
  28. torchcodec/_core/__init__.py +1 -0
  29. torchcodec/_core/custom_ops.cpp +180 -32
  30. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
  31. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  32. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  33. torchcodec/_core/ops.py +86 -43
  34. torchcodec/_core/pybind_ops.cpp +22 -59
  35. torchcodec/_samplers/video_clip_sampler.py +7 -19
  36. torchcodec/decoders/__init__.py +1 -0
  37. torchcodec/decoders/_decoder_utils.py +61 -1
  38. torchcodec/decoders/_video_decoder.py +56 -20
  39. torchcodec/libtorchcodec_core4.dll +0 -0
  40. torchcodec/libtorchcodec_core5.dll +0 -0
  41. torchcodec/libtorchcodec_core6.dll +0 -0
  42. torchcodec/libtorchcodec_core7.dll +0 -0
  43. torchcodec/libtorchcodec_core8.dll +0 -0
  44. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  45. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  46. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  47. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  48. torchcodec/libtorchcodec_custom_ops8.dll +0 -0
  49. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  50. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  51. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  52. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  53. torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
  54. torchcodec/samplers/_time_based.py +8 -0
  55. torchcodec/version.py +1 -1
  56. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/METADATA +24 -13
  57. torchcodec-0.8.0.dist-info/RECORD +80 -0
  58. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/WHEEL +1 -1
  59. torchcodec-0.7.0.dist-info/RECORD +0 -67
  60. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/licenses/LICENSE +0 -0
  61. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/top_level.txt +0 -0
@@ -17,16 +17,6 @@
 namespace facebook::torchcodec {
 namespace {
 
-double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
-  // To perform the multiplication before the division, av_q2d is not used
-  return static_cast<double>(pts) * timeBase.num / timeBase.den;
-}
-
-int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
-  return static_cast<int64_t>(
-      std::round(seconds * timeBase.den / timeBase.num));
-}
-
 // Some videos aren't properly encoded and do not specify pts values for
 // packets, and thus for frames. Unset values correspond to INT64_MIN. When that
 // happens, we fallback to the dts value which hopefully exists and is correct.
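
The removed ptsToSeconds/secondsToClosestPts helpers are still called later in this file (e.g. in readCustomFrameMappingsUpdateMetadataAndIndex), so they were evidently moved rather than deleted; FFMPEGCommon.h/.cpp grow in this diff and are a plausible new home, though that is an assumption. For reference, a minimal sketch of the conversion they perform:

```cpp
#include <cmath>
#include <cstdint>
extern "C" {
#include <libavutil/rational.h>
}

// pts <-> seconds conversion against an FFmpeg time base. Multiplying before
// dividing (instead of using av_q2d) preserves precision for large pts values.
double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
  return static_cast<double>(pts) * timeBase.num / timeBase.den;
}

int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) {
  return static_cast<int64_t>(
      std::round(seconds * timeBase.den / timeBase.num));
}
// With a 1/90000 time base: pts 450000 -> 5.0 s, and 5.0 s -> pts 450000.
```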
@@ -322,19 +312,35 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
 void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
     int streamIndex,
     FrameMappings customFrameMappings) {
-  auto& all_frames = customFrameMappings.all_frames;
-  auto& is_key_frame = customFrameMappings.is_key_frame;
-  auto& duration = customFrameMappings.duration;
+  TORCH_CHECK(
+      customFrameMappings.all_frames.dtype() == torch::kLong &&
+          customFrameMappings.is_key_frame.dtype() == torch::kBool &&
+          customFrameMappings.duration.dtype() == torch::kLong,
+      "all_frames and duration tensors must be int64 dtype, and is_key_frame tensor must be a bool dtype.");
+  const torch::Tensor& all_frames =
+      customFrameMappings.all_frames.to(torch::kLong);
+  const torch::Tensor& is_key_frame =
+      customFrameMappings.is_key_frame.to(torch::kBool);
+  const torch::Tensor& duration = customFrameMappings.duration.to(torch::kLong);
   TORCH_CHECK(
       all_frames.size(0) == is_key_frame.size(0) &&
           is_key_frame.size(0) == duration.size(0),
       "all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.");
 
+  // Allocate vectors using num frames to reduce reallocations
+  int64_t numFrames = all_frames.size(0);
+  streamInfos_[streamIndex].allFrames.reserve(numFrames);
+  streamInfos_[streamIndex].keyFrames.reserve(numFrames);
+  // Use accessor to efficiently access tensor elements
+  auto pts_data = all_frames.accessor<int64_t, 1>();
+  auto is_key_frame_data = is_key_frame.accessor<bool, 1>();
+  auto duration_data = duration.accessor<int64_t, 1>();
+
   auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
 
-  streamMetadata.beginStreamPtsFromContent = all_frames[0].item<int64_t>();
+  streamMetadata.beginStreamPtsFromContent = pts_data[0];
   streamMetadata.endStreamPtsFromContent =
-      all_frames[-1].item<int64_t>() + duration[-1].item<int64_t>();
+      pts_data[numFrames - 1] + duration_data[numFrames - 1];
 
   auto avStream = formatContext_->streams[streamIndex];
   streamMetadata.beginStreamPtsSecondsFromContent = ptsToSeconds(
@@ -343,17 +349,16 @@ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
   streamMetadata.endStreamPtsSecondsFromContent = ptsToSeconds(
       *streamMetadata.endStreamPtsFromContent, avStream->time_base);
 
-  streamMetadata.numFramesFromContent = all_frames.size(0);
-  for (int64_t i = 0; i < all_frames.size(0); ++i) {
+  streamMetadata.numFramesFromContent = numFrames;
+  for (int64_t i = 0; i < numFrames; ++i) {
     FrameInfo frameInfo;
-    frameInfo.pts = all_frames[i].item<int64_t>();
-    frameInfo.isKeyFrame = is_key_frame[i].item<bool>();
+    frameInfo.pts = pts_data[i];
+    frameInfo.isKeyFrame = is_key_frame_data[i];
     streamInfos_[streamIndex].allFrames.push_back(frameInfo);
     if (frameInfo.isKeyFrame) {
      streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
     }
   }
-  // Sort all frames by their pts
   sortAllFrames();
 }
 
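The move from per-element .item<T>() calls to accessor<T, 1>() is the main win here: item() materializes a Scalar for every read, while an accessor reads through a raw pointer. A standalone sketch of the pattern (hypothetical function, not torchcodec's API):

```cpp
#include <torch/torch.h>

// Sum a 1-D int64 CPU tensor via an accessor: pointer-speed element reads.
int64_t sumFramePts(const torch::Tensor& allFrames) {
  TORCH_CHECK(allFrames.dtype() == torch::kLong, "expected int64 tensor");
  auto pts = allFrames.accessor<int64_t, 1>(); // CPU tensors only
  int64_t total = 0;
  for (int64_t i = 0; i < allFrames.size(0); ++i) {
    total += pts[i]; // no per-element Scalar materialization, unlike .item<>()
  }
  return total;
}
```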
@@ -384,6 +389,7 @@ void SingleStreamDecoder::addStream(
     int streamIndex,
     AVMediaType mediaType,
     const torch::Device& device,
+    const std::string_view deviceVariant,
     std::optional<int> ffmpegThreadCount) {
   TORCH_CHECK(
       activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -412,8 +418,6 @@ void SingleStreamDecoder::addStream(
   streamInfo.stream = formatContext_->streams[activeStreamIndex_];
   streamInfo.avMediaType = mediaType;
 
-  deviceInterface_ = createDeviceInterface(device);
-
   // This should never happen, checking just to be safe.
   TORCH_CHECK(
       streamInfo.stream->codecpar->codec_type == mediaType,
@@ -421,14 +425,18 @@ void SingleStreamDecoder::addStream(
       activeStreamIndex_,
       " which is of the wrong media type.");
 
+  deviceInterface_ = createDeviceInterface(device, deviceVariant);
+  TORCH_CHECK(
+      deviceInterface_ != nullptr,
+      "Failed to create device interface. This should never happen, please report.");
+  deviceInterface_->initialize(streamInfo.stream, formatContext_);
+
   // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
   // addStream() which is supposed to be generic
   if (mediaType == AVMEDIA_TYPE_VIDEO) {
-    if (deviceInterface_) {
-      avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
-          deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
-              .value_or(avCodec));
-    }
+    avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
+        deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
+            .value_or(avCodec));
   }
 
   AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
@@ -442,17 +450,15 @@ void SingleStreamDecoder::addStream(
   streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
   streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;
 
-  // TODO_CODE_QUALITY same as above.
-  if (mediaType == AVMEDIA_TYPE_VIDEO) {
-    if (deviceInterface_) {
-      deviceInterface_->initializeContext(codecContext);
-    }
-  }
-
+  // Note that we must make sure to register the hardware device context
+  // with the codec context before calling avcodec_open2(). Otherwise, decoding
+  // will happen on the CPU and not the hardware device.
+  deviceInterface_->registerHardwareDeviceWithCodec(codecContext);
   retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
   TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
 
   codecContext->time_base = streamInfo.stream->time_base;
+
   containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
       std::string(avcodec_get_name(codecContext->codec_id));
 
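registerHardwareDeviceWithCodec() is torchcodec's own wrapper; the underlying FFmpeg requirement is that hw_device_ctx be attached to the codec context before avcodec_open2(), exactly as the new comment warns. A hedged sketch using plain FFmpeg calls, with CUDA picked only for illustration:

```cpp
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/hwcontext.h>
}

// Attach a CUDA hardware device context, then open the codec. If
// hw_device_ctx is set only after avcodec_open2(), FFmpeg silently falls
// back to software decoding: that is the ordering bug the comment warns about.
bool openWithCuda(AVCodecContext* codecCtx, const AVCodec* codec) {
  AVBufferRef* hwDeviceCtx = nullptr;
  if (av_hwdevice_ctx_create(
          &hwDeviceCtx, AV_HWDEVICE_TYPE_CUDA,
          /*device=*/nullptr, /*opts=*/nullptr, /*flags=*/0) < 0) {
    return false;
  }
  codecCtx->hw_device_ctx = av_buffer_ref(hwDeviceCtx); // register first
  av_buffer_unref(&hwDeviceCtx);
  return avcodec_open2(codecCtx, codec, nullptr) >= 0;  // then open
}
```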
@@ -469,12 +475,18 @@
 
 void SingleStreamDecoder::addVideoStream(
     int streamIndex,
+    std::vector<Transform*>& transforms,
     const VideoStreamOptions& videoStreamOptions,
     std::optional<FrameMappings> customFrameMappings) {
+  TORCH_CHECK(
+      transforms.empty() || videoStreamOptions.device == torch::kCPU,
+      " Transforms are only supported for CPU devices.");
+
   addStream(
       streamIndex,
       AVMEDIA_TYPE_VIDEO,
       videoStreamOptions.device,
+      videoStreamOptions.deviceVariant,
       videoStreamOptions.ffmpegThreadCount);
 
   auto& streamMetadata =
@@ -501,8 +513,24 @@ void SingleStreamDecoder::addVideoStream(
         customFrameMappings.has_value(),
         "Missing frame mappings when custom_frame_mappings seek mode is set.");
     readCustomFrameMappingsUpdateMetadataAndIndex(
-        streamIndex, customFrameMappings.value());
+        activeStreamIndex_, customFrameMappings.value());
   }
+
+  metadataDims_ =
+      FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
+  for (auto& transform : transforms) {
+    TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
+    if (transform->getOutputFrameDims().has_value()) {
+      resizedOutputDims_ = transform->getOutputFrameDims().value();
+    }
+
+    // Note that we are claiming ownership of the transform objects passed in to
+    // us.
+    transforms_.push_back(std::unique_ptr<Transform>(transform));
+  }
+
+  deviceInterface_->initializeVideo(
+      videoStreamOptions, transforms_, resizedOutputDims_);
 }
 
 void SingleStreamDecoder::addAudioStream(
@@ -587,11 +615,18 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
 }
 
 FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
-    const std::vector<int64_t>& frameIndices) {
+    const torch::Tensor& frameIndices) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
 
-  auto indicesAreSorted =
-      std::is_sorted(frameIndices.begin(), frameIndices.end());
+  auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
+
+  bool indicesAreSorted = true;
+  for (int64_t i = 1; i < frameIndices.numel(); ++i) {
+    if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) {
+      indicesAreSorted = false;
+      break;
+    }
+  }
 
   std::vector<size_t> argsort;
   if (!indicesAreSorted) {
@@ -599,27 +634,29 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
     // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
     // to use to decode the frames
     // and argsort is [ 1, 3, 2, 0]
-    argsort.resize(frameIndices.size());
+    argsort.resize(frameIndices.numel());
     for (size_t i = 0; i < argsort.size(); ++i) {
       argsort[i] = i;
     }
     std::sort(
-        argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
-          return frameIndices[a] < frameIndices[b];
+        argsort.begin(),
+        argsort.end(),
+        [&frameIndicesAccessor](size_t a, size_t b) {
+          return frameIndicesAccessor[a] < frameIndicesAccessor[b];
         });
   }
 
-  const auto& streamMetadata =
-      containerMetadata_.allStreamMetadata[activeStreamIndex_];
   const auto& streamInfo = streamInfos_[activeStreamIndex_];
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
   FrameBatchOutput frameBatchOutput(
-      frameIndices.size(), videoStreamOptions, streamMetadata);
+      frameIndices.numel(),
+      resizedOutputDims_.value_or(metadataDims_),
+      videoStreamOptions.device);
 
   auto previousIndexInVideo = -1;
-  for (size_t f = 0; f < frameIndices.size(); ++f) {
+  for (int64_t f = 0; f < frameIndices.numel(); ++f) {
     auto indexInOutput = indicesAreSorted ? f : argsort[f];
-    auto indexInVideo = frameIndices[indexInOutput];
+    auto indexInVideo = frameIndicesAccessor[indexInOutput];
 
     if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
       // Avoid decoding the same frame twice
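
The argsort machinery lets the decoder visit frames in ascending order (sequential decoding is cheap) while writing each decoded frame into the slot the caller asked for. A self-contained sketch of the same idea over a plain vector:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Given unsorted, possibly duplicated indices, compute the output slot each
// decode in sorted order should fill.
std::vector<size_t> decodeOrder(const std::vector<int64_t>& frameIndices) {
  std::vector<size_t> argsort(frameIndices.size());
  std::iota(argsort.begin(), argsort.end(), 0);
  std::sort(argsort.begin(), argsort.end(), [&](size_t a, size_t b) {
    return frameIndices[a] < frameIndices[b];
  });
  // e.g. frameIndices = {13, 10, 12, 11} -> argsort = {1, 3, 2, 0}:
  // decode 10 into slot 1, 11 into slot 3, 12 into slot 2, 13 into slot 0.
  // Equal neighbors in the sorted order can reuse the previous decode, which
  // is the "avoid decoding the same frame twice" branch above.
  return argsort;
}
```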
@@ -657,8 +694,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
   TORCH_CHECK(
       step > 0, "Step must be greater than 0; is " + std::to_string(step));
 
-  // Note that if we do not have the number of frames available in our metadata,
-  // then we assume that the upper part of the range is valid.
+  // Note that if we do not have the number of frames available in our
+  // metadata, then we assume that the upper part of the range is valid.
   std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
   if (numFrames.has_value()) {
     TORCH_CHECK(
@@ -671,7 +708,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
   const auto& videoStreamOptions = streamInfo.videoStreamOptions;
   FrameBatchOutput frameBatchOutput(
-      numOutputFrames, videoStreamOptions, streamMetadata);
+      numOutputFrames,
+      resizedOutputDims_.value_or(metadataDims_),
+      videoStreamOptions.device);
 
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     FrameOutput frameOutput =
@@ -687,9 +726,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
   double lastDecodedStartTime =
-      ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
+      ptsToSeconds(lastDecodedAvFramePts_, streamInfo.timeBase);
   double lastDecodedEndTime = ptsToSeconds(
-      streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration,
+      lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_,
       streamInfo.timeBase);
   if (seconds >= lastDecodedStartTime && seconds < lastDecodedEndTime) {
     // We are in the same frame as the one we just returned. However, since we
@@ -709,9 +748,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
       // FFMPEG seeked past the frame we are looking for even though we
       // set max_ts to be our needed timestamp in avformat_seek_file()
      // in maybeSeekToBeforeDesiredPts().
-      // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
-      // In this case we return the very next frame instead of throwing an
-      // exception.
+      // This could be a bug in FFMPEG:
+      // https://trac.ffmpeg.org/ticket/11137 In this case we return the
+      // very next frame instead of throwing an exception.
       // TODO: Maybe log to stderr for Debug builds?
       return true;
     }
@@ -725,7 +764,7 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
 }
 
 FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
-    const std::vector<double>& timestamps) {
+    const torch::Tensor& timestamps) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
 
   const auto& streamMetadata =
@@ -739,9 +778,13 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
   // avoid decoding that unique frame twice is to convert the input timestamps
   // to indices, and leverage the de-duplication logic of getFramesAtIndices.
 
-  std::vector<int64_t> frameIndices(timestamps.size());
-  for (size_t i = 0; i < timestamps.size(); ++i) {
-    auto frameSeconds = timestamps[i];
+  torch::Tensor frameIndices =
+      torch::empty({timestamps.numel()}, torch::kInt64);
+  auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
+  auto timestampsAccessor = timestamps.accessor<double, 1>();
+
+  for (int64_t i = 0; i < timestamps.numel(); ++i) {
+    auto frameSeconds = timestampsAccessor[i];
     TORCH_CHECK(
         frameSeconds >= minSeconds,
         "frame pts is " + std::to_string(frameSeconds) +
@@ -758,7 +801,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
             ".");
     }
 
-    frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
+    frameIndicesAccessor[i] = secondsToIndexLowerBound(frameSeconds);
   }
 
   return getFramesAtIndices(frameIndices);
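
secondsToIndexLowerBound's internals are not shown in this diff; under the assumption that it searches the scanned pts index, the mapping amounts to "the last frame whose pts is <= the requested time". A hypothetical standalone sketch:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Map a timestamp to the index of the frame displayed at that time, given
// per-frame start times in seconds (sorted ascending).
int64_t secondsToIndexLowerBoundSketch(
    double seconds, const std::vector<double>& framePtsSeconds) {
  auto it = std::upper_bound(
      framePtsSeconds.begin(), framePtsSeconds.end(), seconds);
  return static_cast<int64_t>(it - framePtsSeconds.begin()) - 1;
}
// framePtsSeconds = {0.0, 0.04, 0.08}: 0.05 -> index 1, 0.04 -> index 1,
// so duplicate timestamps collapse onto one frame index, and
// getFramesAtIndices() decodes that frame only once.
```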
@@ -791,13 +834,16 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
   // interval B: [0.2, 0.15)
   //
   // Both intervals take place between the pts values for frame 0 and frame 1,
-  // which by our abstract player, means that both intervals map to frame 0. By
-  // the definition of a half open interval, interval A should return no frames.
-  // Interval B should return frame 0. However, for both A and B, the individual
-  // values of the intervals will map to the same frame indices below. Hence, we
-  // need this special case below.
+  // which by our abstract player, means that both intervals map to frame 0.
+  // By the definition of a half open interval, interval A should return no
+  // frames. Interval B should return frame 0. However, for both A and B, the
+  // individual values of the intervals will map to the same frame indices
+  // below. Hence, we need this special case below.
   if (startSeconds == stopSeconds) {
-    FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata);
+    FrameBatchOutput frameBatchOutput(
+        0,
+        resizedOutputDims_.value_or(metadataDims_),
+        videoStreamOptions.device);
     frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
     return frameBatchOutput;
   }
@@ -809,8 +855,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
       "; must be greater than or equal to " + std::to_string(minSeconds) +
           ".");
 
-  // Note that if we can't determine the maximum seconds from the metadata, then
-  // we assume upper range is valid.
+  // Note that if we can't determine the maximum seconds from the metadata,
+  // then we assume upper range is valid.
   std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
   if (maxSeconds.has_value()) {
     TORCH_CHECK(
@@ -842,7 +888,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
   int64_t numFrames = stopFrameIndex - startFrameIndex;
 
   FrameBatchOutput frameBatchOutput(
-      numFrames, videoStreamOptions, streamMetadata);
+      numFrames,
+      resizedOutputDims_.value_or(metadataDims_),
+      videoStreamOptions.device);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     FrameOutput frameOutput =
         getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
@@ -863,25 +911,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
 // `numChannels` values. An audio frame, or a sequence thereof, is always
 // converted into a tensor of shape `(numChannels, numSamplesPerChannel)`.
 //
-// The notion of 'frame' in audio isn't what users want to interact with. Users
-// want to interact with samples. The C++ and core APIs return frames, because
-// we want those to be close to FFmpeg concepts, but the higher-level public
-// APIs expose samples. As a result:
+// The notion of 'frame' in audio isn't what users want to interact with.
+// Users want to interact with samples. The C++ and core APIs return frames,
+// because we want those to be close to FFmpeg concepts, but the higher-level
+// public APIs expose samples. As a result:
 // - We don't expose index-based APIs for audio, because that would mean
-//   exposing the concept of audio frame. For now, we think exposing time-based
-//   APIs is more natural.
-// - We never perform a scan for audio streams. We don't need to, since we won't
+//   exposing the concept of audio frame. For now, we think exposing
+//   time-based APIs is more natural.
+// - We never perform a scan for audio streams. We don't need to, since we
+//   won't
 //   be converting timestamps to indices. That's why we enforce the seek_mode
-//   to be "approximate" (which is slightly misleading, because technically the
-//   output samples will be at their exact positions. But this incongruence is
-//   only exposed at the C++/core private levels).
+//   to be "approximate" (which is slightly misleading, because technically
+//   the output samples will be at their exact positions. But this
+//   incongruence is only exposed at the C++/core private levels).
 //
 // Audio frames are of variable dimensions: in the same stream, a frame can
 // contain 1024 samples and the next one may contain 512 [1]. This makes it
 // impossible to stack audio frames in the same way we can stack video frames.
-// This is one of the main reasons we cannot reuse the same pre-allocation logic
-// we have for videos in getFramesPlayedInRange(): pre-allocating a batch
-// requires constant (and known) frame dimensions. That's also why
+// This is one of the main reasons we cannot reuse the same pre-allocation
+// logic we have for videos in getFramesPlayedInRange(): pre-allocating a
+// batch requires constant (and known) frame dimensions. That's also why
 // *concatenated* along the samples dimension, not stacked.
 //
 // [IMPORTANT!] There is one key invariant that we must respect when decoding
@@ -889,10 +938,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
 //
 // BEFORE DECODING FRAME i, WE MUST DECODE ALL FRAMES j < i.
 //
-// Always. Why? We don't know. What we know is that if we don't, we get clipped,
-// incorrect audio as output [2]. All other (correct) libraries like TorchAudio
-// or Decord do something similar, whether it was intended or not. This has a
-// few implications:
+// Always. Why? We don't know. What we know is that if we don't, we get
+// clipped, incorrect audio as output [2]. All other (correct) libraries like
+// TorchAudio or Decord do something similar, whether it was intended or not.
+// This has a few implications:
 // - The **only** place we're allowed to seek to in an audio stream is the
 //   stream's beginning. This ensures that if we need a frame, we'll have
 //   decoded all previous frames.
@@ -900,8 +949,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
 //   call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually
 //   seek.
 // - We try not to seek, when we can avoid it. Typically if the next frame we
-//   need is in the future, we don't seek back to the beginning, we just decode
-//   all the frames in-between.
+//   need is in the future, we don't seek back to the beginning, we just
+//   decode all the frames in-between.
 //
 // [2] If you're brave and curious, you can read the long "Seek offset for
 // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which
@@ -928,11 +977,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
   }
 
   auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
-  if (startPts < streamInfo.lastDecodedAvFramePts +
-          streamInfo.lastDecodedAvFrameDuration) {
-    // If we need to seek backwards, then we have to seek back to the beginning
-    // of the stream.
-    // See [Audio Decoding Design].
+  if (startPts < lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_) {
+    // If we need to seek backwards, then we have to seek back to the
+    // beginning of the stream. See [Audio Decoding Design].
     setCursor(INT64_MIN);
   }
 
@@ -966,9 +1013,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
     // stop decoding more frames. Note that if we were to use [begin, end),
     // which may seem more natural, then we would decode the frame starting at
     // stopSeconds, which isn't what we want!
-    auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
-        streamInfo.lastDecodedAvFrameDuration;
-    finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
+    auto lastDecodedAvFrameEnd =
+        lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_;
+    finished |= (lastDecodedAvFramePts_) <= stopPts &&
         (stopPts <= lastDecodedAvFrameEnd);
   }
 
@@ -1035,18 +1082,16 @@ I P P P I P P P I P P I P P I P
 bool SingleStreamDecoder::canWeAvoidSeeking() const {
   const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
-    // For audio, we only need to seek if a backwards seek was requested within
-    // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
-    // For more context, see [Audio Decoding Design]
+    // For audio, we only need to seek if a backwards seek was requested
+    // within getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was
+    // called. For more context, see [Audio Decoding Design]
     return !cursorWasJustSet_;
   }
-  int64_t lastDecodedAvFramePts =
-      streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
-  if (cursor_ < lastDecodedAvFramePts) {
+  if (cursor_ < lastDecodedAvFramePts_) {
     // We can never skip a seek if we are seeking backwards.
     return false;
   }
-  if (lastDecodedAvFramePts == cursor_) {
+  if (lastDecodedAvFramePts_ == cursor_) {
     // We are seeking to the exact same frame as we are currently at. Without
     // caching we have to rewind back and decode the frame again.
     // TODO: https://github.com/pytorch/torchcodec/issues/84 we could
@@ -1056,7 +1101,7 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const {
   // We are seeking forwards.
   // We can only skip a seek if both lastDecodedAvFramePts and
   // cursor_ share the same keyframe.
-  int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
+  int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_);
   int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_);
   return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
       lastDecodedAvFrameIndex == targetKeyFrameIndex;
@@ -1105,6 +1150,8 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
 
   decodeStats_.numFlushes++;
   avcodec_flush_buffers(streamInfo.codecContext.get());
+
+  deviceInterface_->flush();
 }
 
 // --------------------------------------------------------------------------
@@ -1123,15 +1170,23 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
   }
 
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
-
-  // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
   AutoAVPacket autoAVPacket;
   int status = AVSUCCESS;
   bool reachedEOF = false;
+
+  // TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on
+  // if/else blocks to dispatch to the interface or to FFmpeg, consider *always*
+  // dispatching to the interface. The default implementation of the interface's
+  // receiveFrame and sendPacket could just be calling avcodec_receive_frame and
+  // avcodec_send_packet. This would make the decoding loop even more generic.
   while (true) {
-    status =
-        avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
+    if (deviceInterface_->canDecodePacketDirectly()) {
+      status = deviceInterface_->receiveFrame(avFrame);
+    } else {
+      status =
+          avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
+    }
 
     if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
       // Non-retriable error
@@ -1154,7 +1209,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
 
     if (reachedEOF) {
       // We don't have any more packets to receive. So keep on pulling frames
-      // from its internal buffers.
+      // from decoder's internal buffers.
      continue;
     }
 
@@ -1166,11 +1221,14 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
     decodeStats_.numPacketsRead++;
 
     if (status == AVERROR_EOF) {
-      // End of file reached. We must drain the codec by sending a nullptr
-      // packet.
-      status = avcodec_send_packet(
-          streamInfo.codecContext.get(),
-          /*avpkt=*/nullptr);
+      // End of file reached. We must drain the decoder
+      if (deviceInterface_->canDecodePacketDirectly()) {
+        status = deviceInterface_->sendEOFPacket();
+      } else {
+        status = avcodec_send_packet(
+            streamInfo.codecContext.get(),
+            /*avpkt=*/nullptr);
+      }
       TORCH_CHECK(
           status >= AVSUCCESS,
           "Could not flush decoder: ",
@@ -1195,7 +1253,11 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
 
     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
-    status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
+    if (deviceInterface_->canDecodePacketDirectly()) {
+      status = deviceInterface_->sendPacket(packet);
+    } else {
+      status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
+    }
     TORCH_CHECK(
         status >= AVSUCCESS,
         "Could not push packet to decoder: ",
@@ -1216,14 +1278,15 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
         getFFMPEGErrorStringFromErrorCode(status));
   }
 
-  // Note that we don't flush the decoder when we reach EOF (even though that's
-  // mentioned in https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html).
-  // This is because we may have packets internally in the decoder that we
-  // haven't received as frames. Eventually we will either hit AVERROR_EOF from
-  // av_receive_frame() or the user will have seeked to a different location in
-  // the file and that will flush the decoder.
-  streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame);
-  streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
+  // Note that we don't flush the decoder when we reach EOF (even though
+  // that's mentioned in
+  // https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). This is
+  // because we may have packets internally in the decoder that we haven't
+  // received as frames. Eventually we will either hit AVERROR_EOF from
+  // av_receive_frame() or the user will have seeked to a different location
+  // in the file and that will flush the decoder.
+  lastDecodedAvFramePts_ = getPtsOrDts(avFrame);
+  lastDecodedAvFrameDuration_ = getDuration(avFrame);
 
   return avFrame;
 }
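
For context, the 'else' branches above are the canonical FFmpeg send/receive protocol, which the new device-interface methods (sendPacket/receiveFrame/sendEOFPacket) mirror. A minimal sketch of that protocol, assuming a single-stream input for brevity:

```cpp
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

// Canonical FFmpeg decode loop: receive until the decoder asks for input
// (EAGAIN), feed it packets, and drain with a nullptr packet at end of file.
int decodeAllFrames(AVFormatContext* fmt, AVCodecContext* ctx) {
  AVPacket* pkt = av_packet_alloc();
  AVFrame* frame = av_frame_alloc();
  bool reachedEOF = false;
  int status = 0;
  while (true) {
    status = avcodec_receive_frame(ctx, frame);
    if (status == AVERROR_EOF) {
      break;                               // decoder fully drained
    }
    if (status == 0) {
      /* consume the frame here */
      av_frame_unref(frame);
      continue;
    }
    if (status != AVERROR(EAGAIN)) {
      break;                               // non-retriable error
    }
    if (reachedEOF) {
      continue;                            // keep pulling buffered frames
    }
    status = av_read_frame(fmt, pkt);      // decoder needs more input
    if (status == AVERROR_EOF) {
      avcodec_send_packet(ctx, nullptr);   // nullptr packet enters drain mode
      reachedEOF = true;
    } else if (status >= 0) {
      avcodec_send_packet(ctx, pkt);
      av_packet_unref(pkt);
    } else {
      break;                               // read error
    }
  }
  av_frame_free(&frame);
  av_packet_free(&pkt);
  return status == AVERROR_EOF ? 0 : status;
}
```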
@@ -1246,13 +1309,9 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
       formatContext_->streams[activeStreamIndex_]->time_base);
   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
     convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
-  } else if (deviceInterface_) {
+  } else {
     deviceInterface_->convertAVFrameToFrameOutput(
-        streamInfo.videoStreamOptions,
-        streamInfo.timeBase,
-        avFrame,
-        frameOutput,
-        preAllocatedOutputTensor);
+        avFrame, frameOutput, preAllocatedOutputTensor);
   }
   return frameOutput;
 }
@@ -1288,8 +1347,8 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
 
   UniqueAVFrame convertedAVFrame;
   if (mustConvert) {
-    if (!streamInfo.swrContext) {
-      streamInfo.swrContext.reset(createSwrContext(
+    if (!swrContext_) {
+      swrContext_.reset(createSwrContext(
          srcSampleFormat,
           outSampleFormat,
           srcSampleRate,
@@ -1299,7 +1358,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
     }
 
     convertedAVFrame = convertAudioAVFrameSamples(
-        streamInfo.swrContext,
+        swrContext_,
         srcAVFrame,
         outSampleFormat,
         outSampleRate,
@@ -1347,15 +1406,15 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
 std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
   // When sample rate conversion is involved, swresample buffers some of the
   // samples in-between calls to swr_convert (see the libswresample docs).
-  // That's because the last few samples in a given frame require future samples
-  // from the next frame to be properly converted. This function flushes out the
-  // samples that are stored in swresample's buffers.
+  // That's because the last few samples in a given frame require future
+  // samples from the next frame to be properly converted. This function
+  // flushes out the samples that are stored in swresample's buffers.
   auto& streamInfo = streamInfos_[activeStreamIndex_];
-  if (!streamInfo.swrContext) {
+  if (!swrContext_) {
     return std::nullopt;
   }
   auto numRemainingSamples = // this is an upper bound
-      swr_get_out_samples(streamInfo.swrContext.get(), 0);
+      swr_get_out_samples(swrContext_.get(), 0);
 
   if (numRemainingSamples == 0) {
     return std::nullopt;
@@ -1372,11 +1431,7 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
   }
 
   auto actualNumRemainingSamples = swr_convert(
-      streamInfo.swrContext.get(),
-      outputBuffers.data(),
-      numRemainingSamples,
-      nullptr,
-      0);
+      swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0);
 
   return lastSamples.narrow(
       /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
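
Passing a null input to swr_convert() is libswresample's documented drain call, and swr_get_out_samples(ctx, 0) bounds how much is still buffered. The same drain, sketched independently of torchcodec's wrappers:

```cpp
extern "C" {
#include <libswresample/swresample.h>
}
#include <cstdint>

// Drain samples buffered inside an SwrContext after the last real frame.
// Returns the number of samples actually flushed, which may be less than
// the upper bound reported by swr_get_out_samples().
int flushSwr(SwrContext* swr, uint8_t** outBuffers) {
  int upperBound = swr_get_out_samples(swr, /*in_samples=*/0);
  if (upperBound <= 0) {
    return 0;
  }
  // nullptr input + 0 count means "no new input, emit what you have buffered".
  return swr_convert(swr, outBuffers, upperBound, nullptr, 0);
}
```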
@@ -1386,25 +1441,10 @@
 // OUTPUT ALLOCATION AND SHAPE CONVERSION
 // --------------------------------------------------------------------------
 
-FrameBatchOutput::FrameBatchOutput(
-    int64_t numFrames,
-    const VideoStreamOptions& videoStreamOptions,
-    const StreamMetadata& streamMetadata)
-    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
-  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
-      videoStreamOptions, streamMetadata);
-  int height = frameDims.height;
-  int width = frameDims.width;
-  data = allocateEmptyHWCTensor(
-      height, width, videoStreamOptions.device, numFrames);
-}
-
-// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
-// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
-// or 4D.
-// Calling permute() is guaranteed to return a view as per the docs:
-// https://pytorch.org/docs/stable/generated/torch.permute.html
+// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require
+// so. The [N] leading batch-dimension is optional i.e. the input tensor can
+// be 3D or 4D. Calling permute() is guaranteed to return a view as per the
+// docs: https://pytorch.org/docs/stable/generated/torch.permute.html
 torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW(
     torch::Tensor& hwcTensor) {
   if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
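
The FrameBatchOutput constructor moved out of this file (Frame.cpp/Frame.h also change in this diff, a plausible new home, though that is inferred). The surviving comment's claim about permute() is easy to demonstrate: the HWC-to-CHW conversion is a zero-copy view. A small sketch:

```cpp
#include <torch/torch.h>

// HWC -> CHW (and NHWC -> NCHW) as a zero-copy view, as maybePermuteHWC2CHW
// relies on. Works for 3-D (single frame) and 4-D (batch) tensors.
torch::Tensor hwcToChw(const torch::Tensor& hwc) {
  if (hwc.dim() == 3) {
    return hwc.permute({2, 0, 1}); // H,W,C -> C,H,W
  }
  TORCH_CHECK(hwc.dim() == 4, "expected a 3-D or 4-D tensor");
  return hwc.permute({0, 3, 1, 2}); // N,H,W,C -> N,C,H,W
}

// Usage: the result shares storage with the input.
//   auto hwc = torch::zeros({2, 720, 1280, 3});
//   auto chw = hwcToChw(hwc);
//   assert(chw.data_ptr() == hwc.data_ptr()); // view, not a copy
```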
@@ -1624,8 +1664,8 @@ void SingleStreamDecoder::validateFrameIndex(
         "and the number of frames must be known.");
   }
 
-  // Note that if we do not have the number of frames available in our metadata,
-  // then we assume that the frameIndex is valid.
+  // Note that if we do not have the number of frames available in our
+  // metadata, then we assume that the frameIndex is valid.
   std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
   if (numFrames.has_value()) {
     if (frameIndex >= numFrames.value()) {
@@ -1676,40 +1716,4 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
       streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);
 }
 
-// --------------------------------------------------------------------------
-// FrameDims APIs
-// --------------------------------------------------------------------------
-
-FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
-  return FrameDims(resizedAVFrame.height, resizedAVFrame.width);
-}
-
-FrameDims getHeightAndWidthFromOptionsOrMetadata(
-    const VideoStreamOptions& videoStreamOptions,
-    const StreamMetadata& streamMetadata) {
-  return FrameDims(
-      videoStreamOptions.height.value_or(*streamMetadata.height),
-      videoStreamOptions.width.value_or(*streamMetadata.width));
-}
-
-FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-    const VideoStreamOptions& videoStreamOptions,
-    const UniqueAVFrame& avFrame) {
-  return FrameDims(
-      videoStreamOptions.height.value_or(avFrame->height),
-      videoStreamOptions.width.value_or(avFrame->width));
-}
-
-SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
-  if (seekMode == "exact") {
-    return SingleStreamDecoder::SeekMode::exact;
-  } else if (seekMode == "approximate") {
-    return SingleStreamDecoder::SeekMode::approximate;
-  } else if (seekMode == "custom_frame_mappings") {
-    return SingleStreamDecoder::SeekMode::custom_frame_mappings;
-  } else {
-    TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
-  }
-}
-
 } // namespace facebook::torchcodec