torchcodec-0.7.0-cp310-cp310-win_amd64.whl → torchcodec-0.8.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchcodec might be problematic. See the release advisory linked from the registry page for more details.

Files changed (61)
  1. torchcodec/_core/BetaCudaDeviceInterface.cpp +636 -0
  2. torchcodec/_core/BetaCudaDeviceInterface.h +191 -0
  3. torchcodec/_core/CMakeLists.txt +36 -3
  4. torchcodec/_core/CUDACommon.cpp +315 -0
  5. torchcodec/_core/CUDACommon.h +46 -0
  6. torchcodec/_core/CpuDeviceInterface.cpp +189 -108
  7. torchcodec/_core/CpuDeviceInterface.h +81 -19
  8. torchcodec/_core/CudaDeviceInterface.cpp +211 -368
  9. torchcodec/_core/CudaDeviceInterface.h +33 -6
  10. torchcodec/_core/DeviceInterface.cpp +57 -19
  11. torchcodec/_core/DeviceInterface.h +97 -16
  12. torchcodec/_core/Encoder.cpp +302 -9
  13. torchcodec/_core/Encoder.h +51 -1
  14. torchcodec/_core/FFMPEGCommon.cpp +189 -2
  15. torchcodec/_core/FFMPEGCommon.h +18 -0
  16. torchcodec/_core/FilterGraph.cpp +28 -21
  17. torchcodec/_core/FilterGraph.h +15 -1
  18. torchcodec/_core/Frame.cpp +17 -7
  19. torchcodec/_core/Frame.h +15 -61
  20. torchcodec/_core/Metadata.h +2 -2
  21. torchcodec/_core/NVDECCache.cpp +70 -0
  22. torchcodec/_core/NVDECCache.h +104 -0
  23. torchcodec/_core/SingleStreamDecoder.cpp +202 -198
  24. torchcodec/_core/SingleStreamDecoder.h +39 -14
  25. torchcodec/_core/StreamOptions.h +16 -6
  26. torchcodec/_core/Transform.cpp +60 -0
  27. torchcodec/_core/Transform.h +59 -0
  28. torchcodec/_core/__init__.py +1 -0
  29. torchcodec/_core/custom_ops.cpp +180 -32
  30. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +61 -1
  31. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  32. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  33. torchcodec/_core/ops.py +86 -43
  34. torchcodec/_core/pybind_ops.cpp +22 -59
  35. torchcodec/_samplers/video_clip_sampler.py +7 -19
  36. torchcodec/decoders/__init__.py +1 -0
  37. torchcodec/decoders/_decoder_utils.py +61 -1
  38. torchcodec/decoders/_video_decoder.py +56 -20
  39. torchcodec/libtorchcodec_core4.dll +0 -0
  40. torchcodec/libtorchcodec_core5.dll +0 -0
  41. torchcodec/libtorchcodec_core6.dll +0 -0
  42. torchcodec/libtorchcodec_core7.dll +0 -0
  43. torchcodec/libtorchcodec_core8.dll +0 -0
  44. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  45. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  46. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  47. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  48. torchcodec/libtorchcodec_custom_ops8.dll +0 -0
  49. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  50. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  51. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  52. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  53. torchcodec/libtorchcodec_pybind_ops8.pyd +0 -0
  54. torchcodec/samplers/_time_based.py +8 -0
  55. torchcodec/version.py +1 -1
  56. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/METADATA +24 -13
  57. torchcodec-0.8.0.dist-info/RECORD +80 -0
  58. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/WHEEL +1 -1
  59. torchcodec-0.7.0.dist-info/RECORD +0 -67
  60. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/licenses/LICENSE +0 -0
  61. {torchcodec-0.7.0.dist-info → torchcodec-0.8.0.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@
17
17
  #include "src/torchcodec/_core/FFMPEGCommon.h"
18
18
  #include "src/torchcodec/_core/Frame.h"
19
19
  #include "src/torchcodec/_core/StreamOptions.h"
20
+ #include "src/torchcodec/_core/Transform.h"
20
21
 
21
22
  namespace facebook::torchcodec {
22
23
 
@@ -83,6 +84,7 @@ class SingleStreamDecoder {
83
84
 
84
85
  void addVideoStream(
85
86
  int streamIndex,
87
+ std::vector<Transform*>& transforms,
86
88
  const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(),
87
89
  std::optional<FrameMappings> customFrameMappings = std::nullopt);
88
90
  void addAudioStream(
@@ -106,7 +108,7 @@ class SingleStreamDecoder {
106
108
 
107
109
  // Returns frames at the given indices for a given stream as a single stacked
108
110
  // Tensor.
109
- FrameBatchOutput getFramesAtIndices(const std::vector<int64_t>& frameIndices);
111
+ FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices);
110
112
 
111
113
  // Returns frames within a given range. The range is defined by [start, stop).
112
114
  // The values retrieved from the range are: [start, start+step,
@@ -121,7 +123,7 @@ class SingleStreamDecoder {
121
123
  // seconds=5.999, etc.
122
124
  FrameOutput getFramePlayedAt(double seconds);
123
125
 
124
- FrameBatchOutput getFramesPlayedAt(const std::vector<double>& timestamps);
126
+ FrameBatchOutput getFramesPlayedAt(const torch::Tensor& timestamps);
125
127
 
126
128
  // Returns frames within a given pts range. The range is defined by
127
129
  // [startSeconds, stopSeconds) with respect to the pts values for frames. The
@@ -226,17 +228,8 @@ class SingleStreamDecoder {
226
228
  std::vector<FrameInfo> keyFrames;
227
229
  std::vector<FrameInfo> allFrames;
228
230
 
229
- // TODO since the decoder is single-stream, these should be decoder fields,
230
- // not streamInfo fields. And they should be defined right next to
231
- // `cursor_`, with joint documentation.
232
- int64_t lastDecodedAvFramePts = 0;
233
- int64_t lastDecodedAvFrameDuration = 0;
234
231
  VideoStreamOptions videoStreamOptions;
235
232
  AudioStreamOptions audioStreamOptions;
236
-
237
- // color-conversion fields. Only one of FilterGraphContext and
238
- // UniqueSwsContext should be non-null.
239
- UniqueSwrContext swrContext;
240
233
  };
241
234
 
242
235
  // --------------------------------------------------------------------------
@@ -318,6 +311,7 @@ class SingleStreamDecoder {
318
311
  int streamIndex,
319
312
  AVMediaType mediaType,
320
313
  const torch::Device& device = torch::kCPU,
314
+ const std::string_view deviceVariant = "default",
321
315
  std::optional<int> ffmpegThreadCount = std::nullopt);
322
316
 
323
317
  // Returns the "best" stream index for a given media type. The "best" is
@@ -356,16 +350,49 @@ class SingleStreamDecoder {
356
350
  const int NO_ACTIVE_STREAM = -2;
357
351
  int activeStreamIndex_ = NO_ACTIVE_STREAM;
358
352
 
359
- bool cursorWasJustSet_ = false;
360
353
  // The desired position of the cursor in the stream. We send frames >= this
361
354
  // pts to the user when they request a frame.
362
355
  int64_t cursor_ = INT64_MIN;
356
+ bool cursorWasJustSet_ = false;
357
+ int64_t lastDecodedAvFramePts_ = 0;
358
+ int64_t lastDecodedAvFrameDuration_ = 0;
359
+
360
+ // Audio only. We cache it for performance. The video equivalents live in
361
+ // deviceInterface_. We store swrContext_ here because we only handle audio
362
+ // on the CPU.
363
+ UniqueSwrContext swrContext_;
364
+
363
365
  // Stores various internal decoding stats.
364
366
  DecodeStats decodeStats_;
367
+
365
368
  // Stores the AVIOContext for the input buffer.
366
369
  std::unique_ptr<AVIOContextHolder> avioContextHolder_;
370
+
371
+ // We will receive a vector of transforms upon adding a stream and store it
372
+ // here. However, we need to know if any of those operations change the
373
+ // dimensions of the output frame. If they do, we need to figure out what are
374
+ // the final dimensions of the output frame after ALL transformations. We
375
+ // figure this out as soon as we receive the transforms. If any of the
376
+ // transforms change the final output frame dimensions, we store that in
377
+ // resizedOutputDims_. If resizedOutputDims_ has no value, that means there
378
+ // are no transforms that change the output frame dimensions.
379
+ //
380
+ // The priority order for output frame dimension is:
381
+ //
382
+ // 1. resizedOutputDims_; the resize requested by the user always takes
383
+ // priority.
384
+ // 2. The dimensions of the actual decoded AVFrame. This can change
385
+ // per-decoded frame, and is unknown in SingleStreamDecoder. Only the
386
+ // DeviceInterface learns it immediately after decoding a raw frame but
387
+ // before the color transformation.
388
+ // 3. metadataDims_; the dimensions we learned from the metadata.
389
+ std::vector<std::unique_ptr<Transform>> transforms_;
390
+ std::optional<FrameDims> resizedOutputDims_;
391
+ FrameDims metadataDims_;
392
+
367
393
  // Whether or not we have already scanned all streams to update the metadata.
368
394
  bool scannedAllStreams_ = false;
395
+
369
396
  // Tracks that we've already been initialized.
370
397
  bool initialized_ = false;
371
398
  };
@@ -375,6 +402,4 @@ std::ostream& operator<<(
375
402
  std::ostream& os,
376
403
  const SingleStreamDecoder::DecodeStats& stats);
377
404
 
378
- SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode);
379
-
380
405
  } // namespace facebook::torchcodec
@@ -9,11 +9,11 @@
9
9
  #include <torch/types.h>
10
10
  #include <optional>
11
11
  #include <string>
12
+ #include <string_view>
12
13
 
13
14
  namespace facebook::torchcodec {
14
15
 
15
16
  enum ColorConversionLibrary {
16
- // TODO: Add an AUTO option later.
17
17
  // Use the libavfilter library for color conversion.
18
18
  FILTERGRAPH,
19
19
  // Use the libswscale library for color conversion.
@@ -28,16 +28,26 @@ struct VideoStreamOptions {
28
28
  // utilize all cores. If not set, it will be the default FFMPEG behavior for
29
29
  // the given codec.
30
30
  std::optional<int> ffmpegThreadCount;
31
+
31
32
  // Currently the dimension order can be either NHWC or NCHW.
32
33
  // H=height, W=width, C=channel.
33
34
  std::string dimensionOrder = "NCHW";
34
- // The output height and width of the frame. If not specified, the output
35
- // is the same as the original video.
36
- std::optional<int> width;
37
- std::optional<int> height;
38
- std::optional<ColorConversionLibrary> colorConversionLibrary;
35
+
36
+ // By default we have to use filtergraph, as it is more general. We can only
37
+ // use swscale when we have met strict requirements. See
38
+ // CpuDeviceInterface::initialize() for the logic.
39
+ ColorConversionLibrary colorConversionLibrary =
40
+ ColorConversionLibrary::FILTERGRAPH;
41
+
39
42
  // By default we use CPU for decoding for both C++ and python users.
40
43
  torch::Device device = torch::kCPU;
44
+ // Device variant (e.g., "default", "beta", etc.)
45
+ std::string_view deviceVariant = "default";
46
+
47
+ // Encoding options
48
+ // TODO-VideoEncoder: Consider adding other optional fields here
49
+ // (bit rate, gop size, max b frames, preset)
50
+ std::optional<int> crf;
41
51
  };
42
52
 
43
53
  struct AudioStreamOptions {
@@ -0,0 +1,60 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include "src/torchcodec/_core/Transform.h"
8
+ #include <torch/types.h>
9
+ #include "src/torchcodec/_core/FFMPEGCommon.h"
10
+
11
+ namespace facebook::torchcodec {
12
+
13
+ namespace {
14
+
15
+ std::string toFilterGraphInterpolation(
16
+ ResizeTransform::InterpolationMode mode) {
17
+ switch (mode) {
18
+ case ResizeTransform::InterpolationMode::BILINEAR:
19
+ return "bilinear";
20
+ default:
21
+ TORCH_CHECK(
22
+ false,
23
+ "Unknown interpolation mode: " +
24
+ std::to_string(static_cast<int>(mode)));
25
+ }
26
+ }
27
+
28
+ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
29
+ switch (mode) {
30
+ case ResizeTransform::InterpolationMode::BILINEAR:
31
+ return SWS_BILINEAR;
32
+ default:
33
+ TORCH_CHECK(
34
+ false,
35
+ "Unknown interpolation mode: " +
36
+ std::to_string(static_cast<int>(mode)));
37
+ }
38
+ }
39
+
40
+ } // namespace
41
+
42
+ std::string ResizeTransform::getFilterGraphCpu() const {
43
+ return "scale=" + std::to_string(outputDims_.width) + ":" +
44
+ std::to_string(outputDims_.height) +
45
+ ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_);
46
+ }
47
+
48
+ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
49
+ return outputDims_;
50
+ }
51
+
52
+ bool ResizeTransform::isResize() const {
53
+ return true;
54
+ }
55
+
56
+ int ResizeTransform::getSwsFlags() const {
57
+ return toSwsInterpolation(interpolationMode_);
58
+ }
59
+
60
+ } // namespace facebook::torchcodec
@@ -0,0 +1,59 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <optional>
10
+ #include <string>
11
+ #include "src/torchcodec/_core/Frame.h"
12
+
13
+ namespace facebook::torchcodec {
14
+
15
+ class Transform {
16
+ public:
17
+ virtual std::string getFilterGraphCpu() const = 0;
18
+ virtual ~Transform() = default;
19
+
20
+ // If the transformation does not change the output frame dimensions, then
21
+ // there is no need to override this member function. The default
22
+ // implementation returns an empty optional, indicating that the output frame
23
+ // has the same dimensions as the input frame.
24
+ //
25
+ // If the transformation does change the output frame dimensions, then it
26
+ // must override this member function and return the output frame dimensions.
27
+ virtual std::optional<FrameDims> getOutputFrameDims() const {
28
+ return std::nullopt;
29
+ }
30
+
31
+ // The ResizeTransform is special, because it is the only transform that
32
+ // swscale can handle.
33
+ virtual bool isResize() const {
34
+ return false;
35
+ }
36
+ };
37
+
38
+ class ResizeTransform : public Transform {
39
+ public:
40
+ enum class InterpolationMode { BILINEAR };
41
+
42
+ ResizeTransform(const FrameDims& dims)
43
+ : outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {}
44
+
45
+ ResizeTransform(const FrameDims& dims, InterpolationMode interpolationMode)
46
+ : outputDims_(dims), interpolationMode_(interpolationMode) {}
47
+
48
+ std::string getFilterGraphCpu() const override;
49
+ std::optional<FrameDims> getOutputFrameDims() const override;
50
+ bool isResize() const override;
51
+
52
+ int getSwsFlags() const;
53
+
54
+ private:
55
+ FrameDims outputDims_;
56
+ InterpolationMode interpolationMode_;
57
+ };
58
+
59
+ } // namespace facebook::torchcodec
@@ -25,6 +25,7 @@ from .ops import (
25
25
  encode_audio_to_file,
26
26
  encode_audio_to_file_like,
27
27
  encode_audio_to_tensor,
28
+ encode_video_to_file,
28
29
  get_ffmpeg_library_versions,
29
30
  get_frame_at_index,
30
31
  get_frame_at_pts,
@@ -10,6 +10,7 @@
10
10
  #include <string>
11
11
  #include "c10/core/SymIntArrayRef.h"
12
12
  #include "c10/util/Exception.h"
13
+ #include "src/torchcodec/_core/AVIOFileLikeContext.h"
13
14
  #include "src/torchcodec/_core/AVIOTensorContext.h"
14
15
  #include "src/torchcodec/_core/Encoder.h"
15
16
  #include "src/torchcodec/_core/SingleStreamDecoder.h"
@@ -31,15 +32,20 @@ TORCH_LIBRARY(torchcodec_ns, m) {
31
32
  m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
32
33
  m.def(
33
34
  "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
35
+ m.def(
36
+ "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
34
37
  m.def(
35
38
  "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
39
+ m.def(
40
+ "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
36
41
  m.def(
37
42
  "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
38
- m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
39
43
  m.def(
40
- "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
44
+ "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
41
45
  m.def(
42
- "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
46
+ "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
47
+ m.def(
48
+ "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
43
49
  m.def(
44
50
  "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
45
51
  m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
@@ -49,7 +55,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
49
55
  m.def(
50
56
  "get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
51
57
  m.def(
52
- "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
58
+ "get_frames_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor)");
53
59
  m.def(
54
60
  "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
55
61
  m.def(
@@ -57,7 +63,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
57
63
  m.def(
58
64
  "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
59
65
  m.def(
60
- "get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
66
+ "get_frames_by_pts(Tensor(a!) decoder, *, Tensor timestamps) -> (Tensor, Tensor, Tensor)");
61
67
  m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
62
68
  m.def("get_json_metadata(Tensor(a!) decoder) -> str");
63
69
  m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
@@ -165,6 +171,81 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
165
171
  return ss.str();
166
172
  }
167
173
 
174
+ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
175
+ if (seekMode == "exact") {
176
+ return SingleStreamDecoder::SeekMode::exact;
177
+ } else if (seekMode == "approximate") {
178
+ return SingleStreamDecoder::SeekMode::approximate;
179
+ } else if (seekMode == "custom_frame_mappings") {
180
+ return SingleStreamDecoder::SeekMode::custom_frame_mappings;
181
+ } else {
182
+ TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
183
+ }
184
+ }
185
+
186
+ int checkedToPositiveInt(const std::string& str) {
187
+ int ret = 0;
188
+ try {
189
+ ret = std::stoi(str);
190
+ } catch (const std::invalid_argument&) {
191
+ TORCH_CHECK(false, "String cannot be converted to an int:" + str);
192
+ } catch (const std::out_of_range&) {
193
+ TORCH_CHECK(false, "String would become integer out of range:" + str);
194
+ }
195
+ TORCH_CHECK(ret > 0, "String must be a positive integer:" + str);
196
+ return ret;
197
+ }
198
+
199
+ // Resize transform specs take the form:
200
+ //
201
+ // "resize, <height>, <width>"
202
+ //
203
+ // Where "resize" is the string literal and <height> and <width> are positive
204
+ // integers.
205
+ Transform* makeResizeTransform(
206
+ const std::vector<std::string>& resizeTransformSpec) {
207
+ TORCH_CHECK(
208
+ resizeTransformSpec.size() == 3,
209
+ "resizeTransformSpec must have 3 elements including its name");
210
+ int height = checkedToPositiveInt(resizeTransformSpec[1]);
211
+ int width = checkedToPositiveInt(resizeTransformSpec[2]);
212
+ return new ResizeTransform(FrameDims(height, width));
213
+ }
214
+
215
+ std::vector<std::string> split(const std::string& str, char delimiter) {
216
+ std::vector<std::string> tokens;
217
+ std::string token;
218
+ std::istringstream tokenStream(str);
219
+ while (std::getline(tokenStream, token, delimiter)) {
220
+ tokens.push_back(token);
221
+ }
222
+ return tokens;
223
+ }
224
+
225
+ // The transformSpecsRaw string is always in the format:
226
+ //
227
+ // "name1, param1, param2, ...; name2, param1, param2, ...; ..."
228
+ //
229
+ // Where "nameX" is the name of the transform, and "paramX" are the parameters.
230
+ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
231
+ std::vector<Transform*> transforms;
232
+ std::vector<std::string> transformSpecs = split(transformSpecsRaw, ';');
233
+ for (const std::string& transformSpecRaw : transformSpecs) {
234
+ std::vector<std::string> transformSpec = split(transformSpecRaw, ',');
235
+ TORCH_CHECK(
236
+ transformSpec.size() >= 1,
237
+ "Invalid transform spec: " + transformSpecRaw);
238
+
239
+ auto name = transformSpec[0];
240
+ if (name == "resize") {
241
+ transforms.push_back(makeResizeTransform(transformSpec));
242
+ } else {
243
+ TORCH_CHECK(false, "Invalid transform name: " + name);
244
+ }
245
+ }
246
+ return transforms;
247
+ }
248
+
168
249
  } // namespace
169
250
 
170
251
  // ==============================
@@ -203,33 +284,47 @@ at::Tensor create_from_tensor(
203
284
  realSeek = seekModeFromString(seek_mode.value());
204
285
  }
205
286
 
206
- auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
287
+ auto avioContextHolder =
288
+ std::make_unique<AVIOFromTensorContext>(video_tensor);
207
289
 
208
290
  std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
209
- std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
291
+ std::make_unique<SingleStreamDecoder>(
292
+ std::move(avioContextHolder), realSeek);
210
293
  return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
211
294
  }
212
295
 
213
- at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
214
- auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
215
- std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
296
+ at::Tensor _create_from_file_like(
297
+ int64_t file_like_context,
298
+ std::optional<std::string_view> seek_mode) {
299
+ auto fileLikeContext =
300
+ reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
301
+ TORCH_CHECK(
302
+ fileLikeContext != nullptr, "file_like_context must be a valid pointer");
303
+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
304
+
305
+ SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
306
+ if (seek_mode.has_value()) {
307
+ realSeek = seekModeFromString(seek_mode.value());
308
+ }
309
+
310
+ std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
311
+ std::make_unique<SingleStreamDecoder>(
312
+ std::move(avioContextHolder), realSeek);
216
313
  return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
217
314
  }
218
315
 
219
316
  void _add_video_stream(
220
317
  at::Tensor& decoder,
221
- std::optional<int64_t> width = std::nullopt,
222
- std::optional<int64_t> height = std::nullopt,
223
318
  std::optional<int64_t> num_threads = std::nullopt,
224
319
  std::optional<std::string_view> dimension_order = std::nullopt,
225
320
  std::optional<int64_t> stream_index = std::nullopt,
226
- std::optional<std::string_view> device = std::nullopt,
321
+ std::string_view device = "cpu",
322
+ std::string_view device_variant = "default",
323
+ std::string_view transform_specs = "",
227
324
  std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
228
325
  custom_frame_mappings = std::nullopt,
229
326
  std::optional<std::string_view> color_conversion_library = std::nullopt) {
230
327
  VideoStreamOptions videoStreamOptions;
231
- videoStreamOptions.width = width;
232
- videoStreamOptions.height = height;
233
328
  videoStreamOptions.ffmpegThreadCount = num_threads;
234
329
 
235
330
  if (dimension_order.has_value()) {
@@ -253,37 +348,46 @@ void _add_video_stream(
253
348
  ". color_conversion_library must be either filtergraph or swscale.");
254
349
  }
255
350
  }
256
- if (device.has_value()) {
257
- videoStreamOptions.device = createTorchDevice(std::string(device.value()));
258
- }
351
+
352
+ validateDeviceInterface(std::string(device), std::string(device_variant));
353
+
354
+ videoStreamOptions.device = torch::Device(std::string(device));
355
+ videoStreamOptions.deviceVariant = device_variant;
356
+
357
+ std::vector<Transform*> transforms =
358
+ makeTransforms(std::string(transform_specs));
359
+
259
360
  std::optional<SingleStreamDecoder::FrameMappings> converted_mappings =
260
361
  custom_frame_mappings.has_value()
261
362
  ? std::make_optional(makeFrameMappings(custom_frame_mappings.value()))
262
363
  : std::nullopt;
263
364
  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
264
365
  videoDecoder->addVideoStream(
265
- stream_index.value_or(-1), videoStreamOptions, converted_mappings);
366
+ stream_index.value_or(-1),
367
+ transforms,
368
+ videoStreamOptions,
369
+ converted_mappings);
266
370
  }
267
371
 
268
372
  // Add a new video stream at `stream_index` using the provided options.
269
373
  void add_video_stream(
270
374
  at::Tensor& decoder,
271
- std::optional<int64_t> width = std::nullopt,
272
- std::optional<int64_t> height = std::nullopt,
273
375
  std::optional<int64_t> num_threads = std::nullopt,
274
376
  std::optional<std::string_view> dimension_order = std::nullopt,
275
377
  std::optional<int64_t> stream_index = std::nullopt,
276
- std::optional<std::string_view> device = std::nullopt,
378
+ std::string_view device = "cpu",
379
+ std::string_view device_variant = "default",
380
+ std::string_view transform_specs = "",
277
381
  const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>&
278
382
  custom_frame_mappings = std::nullopt) {
279
383
  _add_video_stream(
280
384
  decoder,
281
- width,
282
- height,
283
385
  num_threads,
284
386
  dimension_order,
285
387
  stream_index,
286
388
  device,
389
+ device_variant,
390
+ transform_specs,
287
391
  custom_frame_mappings);
288
392
  }
289
393
 
@@ -344,11 +448,9 @@ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
344
448
  // Return the frames at given indices for a given stream
345
449
  OpsFrameBatchOutput get_frames_at_indices(
346
450
  at::Tensor& decoder,
347
- at::IntArrayRef frame_indices) {
451
+ const at::Tensor& frame_indices) {
348
452
  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
349
- std::vector<int64_t> frameIndicesVec(
350
- frame_indices.begin(), frame_indices.end());
351
- auto result = videoDecoder->getFramesAtIndices(frameIndicesVec);
453
+ auto result = videoDecoder->getFramesAtIndices(frame_indices);
352
454
  return makeOpsFrameBatchOutput(result);
353
455
  }
354
456
 
@@ -367,10 +469,9 @@ OpsFrameBatchOutput get_frames_in_range(
367
469
  // Return the frames at given ptss for a given stream
368
470
  OpsFrameBatchOutput get_frames_by_pts(
369
471
  at::Tensor& decoder,
370
- at::ArrayRef<double> timestamps) {
472
+ const at::Tensor& timestamps) {
371
473
  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
372
- std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
373
- auto result = videoDecoder->getFramesPlayedAt(timestampsVec);
474
+ auto result = videoDecoder->getFramesPlayedAt(timestamps);
374
475
  return makeOpsFrameBatchOutput(result);
375
476
  }
376
477
 
@@ -397,6 +498,21 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
397
498
  return makeOpsAudioFramesOutput(result);
398
499
  }
399
500
 
501
+ void encode_video_to_file(
502
+ const at::Tensor& frames,
503
+ int64_t frame_rate,
504
+ std::string_view file_name,
505
+ std::optional<int64_t> crf = std::nullopt) {
506
+ VideoStreamOptions videoStreamOptions;
507
+ videoStreamOptions.crf = crf;
508
+ VideoEncoder(
509
+ frames,
510
+ validateInt64ToInt(frame_rate, "frame_rate"),
511
+ file_name,
512
+ videoStreamOptions)
513
+ .encode();
514
+ }
515
+
400
516
  void encode_audio_to_file(
401
517
  const at::Tensor& samples,
402
518
  int64_t sample_rate,
@@ -441,6 +557,36 @@ at::Tensor encode_audio_to_tensor(
441
557
  .encodeToTensor();
442
558
  }
443
559
 
560
+ void _encode_audio_to_file_like(
561
+ const at::Tensor& samples,
562
+ int64_t sample_rate,
563
+ std::string_view format,
564
+ int64_t file_like_context,
565
+ std::optional<int64_t> bit_rate = std::nullopt,
566
+ std::optional<int64_t> num_channels = std::nullopt,
567
+ std::optional<int64_t> desired_sample_rate = std::nullopt) {
568
+ auto fileLikeContext =
569
+ reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
570
+ TORCH_CHECK(
571
+ fileLikeContext != nullptr, "file_like_context must be a valid pointer");
572
+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
573
+
574
+ AudioStreamOptions audioStreamOptions;
575
+ audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
576
+ audioStreamOptions.numChannels =
577
+ validateOptionalInt64ToInt(num_channels, "num_channels");
578
+ audioStreamOptions.sampleRate =
579
+ validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
580
+
581
+ AudioEncoder encoder(
582
+ samples,
583
+ validateInt64ToInt(sample_rate, "sample_rate"),
584
+ format,
585
+ std::move(avioContextHolder),
586
+ audioStreamOptions);
587
+ encoder.encode();
588
+ }
589
+
444
590
  // For testing only. We need to implement this operation as a core library
445
591
  // function because what we're testing is round-tripping pts values as
446
592
  // double-precision floating point numbers from C++ to Python and back to C++.
@@ -694,14 +840,16 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
694
840
  TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
695
841
  m.impl("create_from_file", &create_from_file);
696
842
  m.impl("create_from_tensor", &create_from_tensor);
697
- m.impl("_convert_to_tensor", &_convert_to_tensor);
843
+ m.impl("_create_from_file_like", &_create_from_file_like);
698
844
  m.impl(
699
845
  "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
700
846
  }
701
847
 
702
848
  TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
703
849
  m.impl("encode_audio_to_file", &encode_audio_to_file);
850
+ m.impl("encode_video_to_file", &encode_video_to_file);
704
851
  m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
852
+ m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
705
853
  m.impl("seek_to_pts", &seek_to_pts);
706
854
  m.impl("add_video_stream", &add_video_stream);
707
855
  m.impl("_add_video_stream", &_add_video_stream);