torchcodec 0.7.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchcodec might be problematic.

Files changed (67):
  1. torchcodec/__init__.py +16 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +123 -0
  7. torchcodec/_core/AVIOTensorContext.h +43 -0
  8. torchcodec/_core/CMakeLists.txt +292 -0
  9. torchcodec/_core/Cache.h +138 -0
  10. torchcodec/_core/CpuDeviceInterface.cpp +266 -0
  11. torchcodec/_core/CpuDeviceInterface.h +70 -0
  12. torchcodec/_core/CudaDeviceInterface.cpp +514 -0
  13. torchcodec/_core/CudaDeviceInterface.h +37 -0
  14. torchcodec/_core/DeviceInterface.cpp +79 -0
  15. torchcodec/_core/DeviceInterface.h +67 -0
  16. torchcodec/_core/Encoder.cpp +514 -0
  17. torchcodec/_core/Encoder.h +123 -0
  18. torchcodec/_core/FFMPEGCommon.cpp +421 -0
  19. torchcodec/_core/FFMPEGCommon.h +227 -0
  20. torchcodec/_core/FilterGraph.cpp +142 -0
  21. torchcodec/_core/FilterGraph.h +45 -0
  22. torchcodec/_core/Frame.cpp +32 -0
  23. torchcodec/_core/Frame.h +118 -0
  24. torchcodec/_core/Metadata.h +72 -0
  25. torchcodec/_core/SingleStreamDecoder.cpp +1715 -0
  26. torchcodec/_core/SingleStreamDecoder.h +380 -0
  27. torchcodec/_core/StreamOptions.h +53 -0
  28. torchcodec/_core/ValidationUtils.cpp +35 -0
  29. torchcodec/_core/ValidationUtils.h +21 -0
  30. torchcodec/_core/__init__.py +40 -0
  31. torchcodec/_core/_metadata.py +317 -0
  32. torchcodec/_core/custom_ops.cpp +727 -0
  33. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +300 -0
  34. torchcodec/_core/ops.py +455 -0
  35. torchcodec/_core/pybind_ops.cpp +87 -0
  36. torchcodec/_frame.py +145 -0
  37. torchcodec/_internally_replaced_utils.py +67 -0
  38. torchcodec/_samplers/__init__.py +7 -0
  39. torchcodec/_samplers/video_clip_sampler.py +430 -0
  40. torchcodec/decoders/__init__.py +11 -0
  41. torchcodec/decoders/_audio_decoder.py +177 -0
  42. torchcodec/decoders/_decoder_utils.py +52 -0
  43. torchcodec/decoders/_video_decoder.py +464 -0
  44. torchcodec/encoders/__init__.py +1 -0
  45. torchcodec/encoders/_audio_encoder.py +150 -0
  46. torchcodec/libtorchcodec_core4.dll +0 -0
  47. torchcodec/libtorchcodec_core5.dll +0 -0
  48. torchcodec/libtorchcodec_core6.dll +0 -0
  49. torchcodec/libtorchcodec_core7.dll +0 -0
  50. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  51. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  52. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  53. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  54. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  55. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  56. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  57. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  58. torchcodec/samplers/__init__.py +2 -0
  59. torchcodec/samplers/_common.py +84 -0
  60. torchcodec/samplers/_index_based.py +287 -0
  61. torchcodec/samplers/_time_based.py +350 -0
  62. torchcodec/version.py +2 -0
  63. torchcodec-0.7.0.dist-info/METADATA +242 -0
  64. torchcodec-0.7.0.dist-info/RECORD +67 -0
  65. torchcodec-0.7.0.dist-info/WHEEL +5 -0
  66. torchcodec-0.7.0.dist-info/licenses/LICENSE +28 -0
  67. torchcodec-0.7.0.dist-info/top_level.txt +2 -0
torchcodec/_core/custom_ops.cpp
@@ -0,0 +1,727 @@
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
+ // All rights reserved.
+ //
+ // This source code is licensed under the BSD-style license found in the
+ // LICENSE file in the root directory of this source tree.
+
+ #include <pybind11/pybind11.h>
+ #include <cstdint>
+ #include <sstream>
+ #include <string>
+ #include "c10/core/SymIntArrayRef.h"
+ #include "c10/util/Exception.h"
+ #include "src/torchcodec/_core/AVIOTensorContext.h"
+ #include "src/torchcodec/_core/Encoder.h"
+ #include "src/torchcodec/_core/SingleStreamDecoder.h"
+ #include "src/torchcodec/_core/ValidationUtils.h"
+
+ namespace facebook::torchcodec {
+
+ // ==============================
+ // Define the operators
+ // ==============================
+ // All instances of accepting the decoder as a tensor must be annotated with
+ // `Tensor(a!)`. The `(a!)` part normally indicates that the tensor is being
+ // mutated in place. We need it to make sure that torch.compile does not reorder
+ // calls to these functions. For more detail, see:
+ // https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
+ TORCH_LIBRARY(torchcodec_ns, m) {
+   m.impl_abstract_pystub(
+       "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
+   m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
+   m.def(
+       "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
+   m.def(
+       "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
+   m.def(
+       "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
+   m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
+   m.def(
+       "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
+   m.def(
+       "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
+   m.def(
+       "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
+   m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
+   m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
+   m.def(
+       "get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
+   m.def(
+       "get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
+   m.def(
+       "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
+   m.def(
+       "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
+   m.def(
+       "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
+   m.def(
+       "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
+   m.def(
+       "get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
+   m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
+   m.def("get_json_metadata(Tensor(a!) decoder) -> str");
+   m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
+   m.def(
+       "get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
+   m.def("_get_json_ffmpeg_library_versions() -> str");
+   m.def(
+       "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
+   m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
+ }
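Aside: these schemas are what the Python bindings in torchcodec._core.ops dispatch to. As a minimal sketch (not part of the diff), an op registered this way can also be invoked from C++ through the dispatcher; the unboxed signature below is an assumption based on recent PyTorch, where `str` maps to std::string_view and the file name is hypothetical:

    // Sketch only: look up the op registered above and call it unboxed.
    auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("torchcodec_ns::create_from_file", "")
        .typed<at::Tensor(std::string_view, std::optional<std::string_view>)>();
    at::Tensor decoderHandle = op.call("video.mp4", std::nullopt);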
+
+ namespace {
+
+ at::Tensor wrapDecoderPointerToTensor(
+     std::unique_ptr<SingleStreamDecoder> uniqueDecoder) {
+   SingleStreamDecoder* decoder = uniqueDecoder.release();
+
+   auto deleter = [decoder](void*) { delete decoder; };
+   at::Tensor tensor = at::from_blob(
+       decoder, {sizeof(SingleStreamDecoder*)}, deleter, {at::kLong});
+   auto videoDecoder =
+       static_cast<SingleStreamDecoder*>(tensor.mutable_data_ptr());
+   TORCH_CHECK_EQ(videoDecoder, decoder) << "videoDecoder=" << videoDecoder;
+   return tensor;
+ }
+
+ SingleStreamDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
+   TORCH_CHECK(
+       tensor.is_contiguous(),
+       "fake decoder tensor must be contiguous! This is an internal error, please report on the torchcodec issue tracker.");
+   void* buffer = tensor.mutable_data_ptr();
+   SingleStreamDecoder* decoder = static_cast<SingleStreamDecoder*>(buffer);
+   return decoder;
+ }
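Aside: the pointer round trip implemented by these two helpers can be exercised in isolation. A minimal sketch, assuming a SingleStreamDecoder constructed from a hypothetical file name:

    // Sketch only: wrap a decoder in a fake tensor, then recover the same pointer.
    auto owned = std::make_unique<SingleStreamDecoder>(
        "video.mp4", SingleStreamDecoder::SeekMode::exact);
    SingleStreamDecoder* raw = owned.get();
    at::Tensor handle = wrapDecoderPointerToTensor(std::move(owned));
    // The tensor now owns the decoder; its deleter frees it when the tensor dies.
    TORCH_CHECK(unwrapTensorToGetDecoder(handle) == raw);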
+
+ // The elements of this tuple are all tensors that represent a single frame:
+ // 1. The frame data, which is a multidimensional tensor.
+ // 2. A single float value for the pts in seconds.
+ // 3. A single float value for the duration in seconds.
+ // The reason we use Tensors for the second and third values is so we can run
+ // under torch.compile().
+ using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
+
+ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) {
+   return std::make_tuple(
+       frame.data,
+       torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)),
+       torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64)));
+ }
+
+ SingleStreamDecoder::FrameMappings makeFrameMappings(
+     std::tuple<at::Tensor, at::Tensor, at::Tensor> custom_frame_mappings) {
+   return SingleStreamDecoder::FrameMappings{
+       std::move(std::get<0>(custom_frame_mappings)),
+       std::move(std::get<1>(custom_frame_mappings)),
+       std::move(std::get<2>(custom_frame_mappings))};
+ }
+
+ // All elements of this tuple are tensors of the same leading dimension. The
+ // tuple represents the data for N total frames, where N is the size of the
+ // leading dimension of each stacked tensor. The elements are:
+ // 1. Stacked tensor of data for all N frames. Each frame is also a
+ //    multidimensional tensor.
+ // 2. Tensor of N pts values in seconds, where each pts is a single
+ //    float.
+ // 3. Tensor of N durations in seconds, where each duration is a
+ //    single float.
+ using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
+
+ OpsFrameBatchOutput makeOpsFrameBatchOutput(FrameBatchOutput& batch) {
+   return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
+ }
+
+ // The elements of this tuple are all tensors that represent the concatenation
+ // of multiple audio frames:
+ // 1. The frames data (concatenated).
+ // 2. A single float value for the pts of the first frame, in seconds.
+ using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
+
+ OpsAudioFramesOutput makeOpsAudioFramesOutput(AudioFramesOutput& audioFrames) {
+   return std::make_tuple(
+       audioFrames.data,
+       torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
+ }
+
+ std::string quoteValue(const std::string& value) {
+   return "\"" + value + "\"";
+ }
+
+ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
+   std::stringstream ss;
+   ss << "{\n";
+   auto it = metadataMap.begin();
+   while (it != metadataMap.end()) {
+     ss << "\"" << it->first << "\": " << it->second;
+     ++it;
+     if (it != metadataMap.end()) {
+       ss << ",\n";
+     } else {
+       ss << "\n";
+     }
+   }
+   ss << "}";
+
+   return ss.str();
+ }
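Note that mapToJson expects every value to already be serialized JSON: numeric values go in bare via std::to_string, while string values must first pass through quoteValue. A small illustrative sketch (values hypothetical):

    std::map<std::string, std::string> m;
    m["codec"] = quoteValue("h264");    // emitted as "codec": "h264"
    m["width"] = std::to_string(1920);  // emitted as "width": 1920
    std::string json = mapToJson(m);    // {"codec": "h264", "width": 1920}, one entry per line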
+
+ } // namespace
+
+ // ==============================
+ // Implementations for the operators
+ // ==============================
+
+ // Create a SingleStreamDecoder from file and wrap the pointer in a tensor.
+ at::Tensor create_from_file(
+     std::string_view filename,
+     std::optional<std::string_view> seek_mode = std::nullopt) {
+   std::string filenameStr(filename);
+
+   SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
+   if (seek_mode.has_value()) {
+     realSeek = seekModeFromString(seek_mode.value());
+   }
+
+   std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
+       std::make_unique<SingleStreamDecoder>(filenameStr, realSeek);
+
+   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
+ }
+
+ // Create a SingleStreamDecoder from the actual bytes of a video and wrap the
+ // pointer in a tensor. The SingleStreamDecoder will decode the provided bytes.
+ at::Tensor create_from_tensor(
+     at::Tensor video_tensor,
+     std::optional<std::string_view> seek_mode = std::nullopt) {
+   TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous");
+   TORCH_CHECK(
+       video_tensor.scalar_type() == torch::kUInt8,
+       "video_tensor must be kUInt8");
+
+   SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
+   if (seek_mode.has_value()) {
+     realSeek = seekModeFromString(seek_mode.value());
+   }
+
+   auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
+
+   std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
+       std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
+   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
+ }
+
+ at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
+   auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
+   std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
+   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
+ }
+
+ void _add_video_stream(
+     at::Tensor& decoder,
+     std::optional<int64_t> width = std::nullopt,
+     std::optional<int64_t> height = std::nullopt,
+     std::optional<int64_t> num_threads = std::nullopt,
+     std::optional<std::string_view> dimension_order = std::nullopt,
+     std::optional<int64_t> stream_index = std::nullopt,
+     std::optional<std::string_view> device = std::nullopt,
+     std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
+         custom_frame_mappings = std::nullopt,
+     std::optional<std::string_view> color_conversion_library = std::nullopt) {
+   VideoStreamOptions videoStreamOptions;
+   videoStreamOptions.width = width;
+   videoStreamOptions.height = height;
+   videoStreamOptions.ffmpegThreadCount = num_threads;
+
+   if (dimension_order.has_value()) {
+     std::string stdDimensionOrder{dimension_order.value()};
+     TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW");
+     videoStreamOptions.dimensionOrder = stdDimensionOrder;
+   }
+   if (color_conversion_library.has_value()) {
+     std::string stdColorConversionLibrary{color_conversion_library.value()};
+     if (stdColorConversionLibrary == "filtergraph") {
+       videoStreamOptions.colorConversionLibrary =
+           ColorConversionLibrary::FILTERGRAPH;
+     } else if (stdColorConversionLibrary == "swscale") {
+       videoStreamOptions.colorConversionLibrary =
+           ColorConversionLibrary::SWSCALE;
+     } else {
+       TORCH_CHECK(
+           false,
+           "Invalid color_conversion_library=",
+           stdColorConversionLibrary,
+           ". color_conversion_library must be either filtergraph or swscale.");
+     }
+   }
+   if (device.has_value()) {
+     videoStreamOptions.device = createTorchDevice(std::string(device.value()));
+   }
+   std::optional<SingleStreamDecoder::FrameMappings> converted_mappings =
+       custom_frame_mappings.has_value()
+       ? std::make_optional(makeFrameMappings(custom_frame_mappings.value()))
+       : std::nullopt;
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   videoDecoder->addVideoStream(
+       stream_index.value_or(-1), videoStreamOptions, converted_mappings);
+ }
+
+ // Add a new video stream at `stream_index` using the provided options.
+ void add_video_stream(
+     at::Tensor& decoder,
+     std::optional<int64_t> width = std::nullopt,
+     std::optional<int64_t> height = std::nullopt,
+     std::optional<int64_t> num_threads = std::nullopt,
+     std::optional<std::string_view> dimension_order = std::nullopt,
+     std::optional<int64_t> stream_index = std::nullopt,
+     std::optional<std::string_view> device = std::nullopt,
+     const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>&
+         custom_frame_mappings = std::nullopt) {
+   _add_video_stream(
+       decoder,
+       width,
+       height,
+       num_threads,
+       dimension_order,
+       stream_index,
+       device,
+       custom_frame_mappings);
+ }
+
+ void add_audio_stream(
+     at::Tensor& decoder,
+     std::optional<int64_t> stream_index = std::nullopt,
+     std::optional<int64_t> sample_rate = std::nullopt,
+     std::optional<int64_t> num_channels = std::nullopt) {
+   AudioStreamOptions audioStreamOptions;
+   audioStreamOptions.sampleRate = sample_rate;
+   audioStreamOptions.numChannels = num_channels;
+
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);
+ }
+
+ // Seek to a particular presentation timestamp in the video in seconds.
+ void seek_to_pts(at::Tensor& decoder, double seconds) {
+   auto videoDecoder =
+       static_cast<SingleStreamDecoder*>(decoder.mutable_data_ptr());
+   videoDecoder->setCursorPtsInSeconds(seconds);
+ }
+
+ // Get the next frame from the video as a tuple that has the frame data, pts and
+ // duration as tensors.
+ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   FrameOutput result;
+   try {
+     result = videoDecoder->getNextFrame();
+   } catch (const SingleStreamDecoder::EndOfFileException& e) {
+     C10_THROW_ERROR(IndexError, e.what());
+   }
+   return makeOpsFrameOutput(result);
+ }
+
+ // Return the frame that is visible at a given timestamp in seconds. Each frame
+ // in FFMPEG has a presentation timestamp and a duration. The frame visible at a
+ // given timestamp T has T >= PTS and T < PTS + Duration.
+ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   FrameOutput result;
+   try {
+     result = videoDecoder->getFramePlayedAt(seconds);
+   } catch (const SingleStreamDecoder::EndOfFileException& e) {
+     C10_THROW_ERROR(IndexError, e.what());
+   }
+   return makeOpsFrameOutput(result);
+ }
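As a worked example of the visibility rule above, with hypothetical numbers: in a constant 25 fps stream, frame i has pts = i / 25 s and duration = 0.04 s, so a query at T = 0.05 s satisfies 0.04 <= T < 0.08 and returns frame 1:

    // Sketch only: which frame is visible at T in a hypothetical constant-fps stream.
    double fps = 25.0;
    double T = 0.05;
    auto visibleIndex = static_cast<int64_t>(T * fps);  // == 1: pts(1) = 0.04 <= T < 0.08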
+
+ // Return the frame that is visible at a given index in the video.
+ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   auto result = videoDecoder->getFrameAtIndex(frame_index);
+   return makeOpsFrameOutput(result);
+ }
+
+ // Return the frames at the given indices for a given stream.
+ OpsFrameBatchOutput get_frames_at_indices(
+     at::Tensor& decoder,
+     at::IntArrayRef frame_indices) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   std::vector<int64_t> frameIndicesVec(
+       frame_indices.begin(), frame_indices.end());
+   auto result = videoDecoder->getFramesAtIndices(frameIndicesVec);
+   return makeOpsFrameBatchOutput(result);
+ }
+
+ // Return the frames inside a range as a single stacked Tensor. The range is
+ // defined as [start, stop).
+ OpsFrameBatchOutput get_frames_in_range(
+     at::Tensor& decoder,
+     int64_t start,
+     int64_t stop,
+     std::optional<int64_t> step = std::nullopt) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   auto result = videoDecoder->getFramesInRange(start, stop, step.value_or(1));
+   return makeOpsFrameBatchOutput(result);
+ }
+
+ // Return the frames at the given pts values for a given stream.
+ OpsFrameBatchOutput get_frames_by_pts(
+     at::Tensor& decoder,
+     at::ArrayRef<double> timestamps) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
+   auto result = videoDecoder->getFramesPlayedAt(timestampsVec);
+   return makeOpsFrameBatchOutput(result);
+ }
+
+ // Return the frames inside the range as a single stacked Tensor. The range is
+ // defined as [start_seconds, stop_seconds). The frames are stacked in pts
+ // order.
+ OpsFrameBatchOutput get_frames_by_pts_in_range(
+     at::Tensor& decoder,
+     double start_seconds,
+     double stop_seconds) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   auto result =
+       videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds);
+   return makeOpsFrameBatchOutput(result);
+ }
+
+ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
+     at::Tensor& decoder,
+     double start_seconds,
+     std::optional<double> stop_seconds = std::nullopt) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   auto result =
+       videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
+   return makeOpsAudioFramesOutput(result);
+ }
+
+ void encode_audio_to_file(
+     const at::Tensor& samples,
+     int64_t sample_rate,
+     std::string_view file_name,
+     std::optional<int64_t> bit_rate = std::nullopt,
+     std::optional<int64_t> num_channels = std::nullopt,
+     std::optional<int64_t> desired_sample_rate = std::nullopt) {
+   AudioStreamOptions audioStreamOptions;
+   audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
+   audioStreamOptions.numChannels =
+       validateOptionalInt64ToInt(num_channels, "num_channels");
+   audioStreamOptions.sampleRate =
+       validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
+   AudioEncoder(
+       samples,
+       validateInt64ToInt(sample_rate, "sample_rate"),
+       file_name,
+       audioStreamOptions)
+       .encode();
+ }
+
+ at::Tensor encode_audio_to_tensor(
+     const at::Tensor& samples,
+     int64_t sample_rate,
+     std::string_view format,
+     std::optional<int64_t> bit_rate = std::nullopt,
+     std::optional<int64_t> num_channels = std::nullopt,
+     std::optional<int64_t> desired_sample_rate = std::nullopt) {
+   auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
+   AudioStreamOptions audioStreamOptions;
+   audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
+   audioStreamOptions.numChannels =
+       validateOptionalInt64ToInt(num_channels, "num_channels");
+   audioStreamOptions.sampleRate =
+       validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
+   return AudioEncoder(
+       samples,
+       validateInt64ToInt(sample_rate, "sample_rate"),
+       format,
+       std::move(avioContextHolder),
+       audioStreamOptions)
+       .encodeToTensor();
+ }
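A minimal encoding sketch (not part of the diff), assuming samples are a 2D float32 tensor laid out as (num_channels, num_samples):

    // Sketch only: encode one second of hypothetical stereo noise to in-memory WAV bytes.
    at::Tensor samples = torch::rand({2, 44100});  // values in [0, 1)
    at::Tensor wavBytes = encode_audio_to_tensor(samples, 44100, "wav");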
+
+ // For testing only. We need to implement this operation as a core library
+ // function because what we're testing is round-tripping pts values as
+ // double-precision floating point numbers from C++ to Python and back to C++.
+ // We want to make sure that the value is preserved exactly, bit-for-bit, during
+ // this process.
+ //
+ // Returns true if, for the given decoder and frame index, the frame's pts value
+ // converted to seconds as a double is exactly pts_seconds_to_test. Returns
+ // false otherwise.
+ bool _test_frame_pts_equality(
+     at::Tensor& decoder,
+     int64_t frame_index,
+     double pts_seconds_to_test) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   return pts_seconds_to_test ==
+       videoDecoder->getPtsSecondsForFrame(frame_index);
+ }
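A minimal sketch of what bit-for-bit preservation means here (not part of the diff): equality of the 64-bit patterns of two doubles implies ==, so any lossy conversion during the C++ -> Python -> C++ round trip would make the comparison above fail:

    #include <cstdint>
    #include <cstring>

    // Sketch only: compare two doubles by their raw bit patterns.
    bool sameBits(double a, double b) {
      uint64_t ba, bb;
      std::memcpy(&ba, &a, sizeof(ba));
      std::memcpy(&bb, &b, sizeof(bb));
      return ba == bb;  // implies a == b for non-NaN values
    }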
+
+ torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   return videoDecoder->getKeyFrameIndices();
+ }
+
+ // Get the metadata from the video as a string.
+ std::string get_json_metadata(at::Tensor& decoder) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+
+   ContainerMetadata videoMetadata = videoDecoder->getContainerMetadata();
+   auto maybeBestVideoStreamIndex = videoMetadata.bestVideoStreamIndex;
+
+   std::map<std::string, std::string> metadataMap;
+   // Serialize the metadata into a string.
+   double durationSecondsFromHeader = 0;
+   if (maybeBestVideoStreamIndex.has_value() &&
+       videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
+           .durationSecondsFromHeader.has_value()) {
+     durationSecondsFromHeader =
+         videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
+             .durationSecondsFromHeader.value_or(0);
+   } else {
+     // Fallback to container-level duration if stream duration is not found.
+     durationSecondsFromHeader =
+         videoMetadata.durationSecondsFromHeader.value_or(0);
+   }
+   metadataMap["durationSecondsFromHeader"] =
+       std::to_string(durationSecondsFromHeader);
+
+   if (videoMetadata.bitRate.has_value()) {
+     metadataMap["bitRate"] = std::to_string(videoMetadata.bitRate.value());
+   }
+
+   if (maybeBestVideoStreamIndex.has_value()) {
+     auto streamMetadata =
+         videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex];
+     if (streamMetadata.numFramesFromContent.has_value()) {
+       metadataMap["numFramesFromHeader"] =
+           std::to_string(*streamMetadata.numFramesFromContent);
+     } else if (streamMetadata.numFramesFromHeader.has_value()) {
+       metadataMap["numFramesFromHeader"] =
+           std::to_string(*streamMetadata.numFramesFromHeader);
+     }
+     if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) {
+       metadataMap["beginStreamSecondsFromContent"] =
+           std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent);
+     }
+     if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) {
+       metadataMap["endStreamSecondsFromContent"] =
+           std::to_string(*streamMetadata.endStreamPtsSecondsFromContent);
+     }
+     if (streamMetadata.codecName.has_value()) {
+       metadataMap["codec"] = quoteValue(streamMetadata.codecName.value());
+     }
+     if (streamMetadata.width.has_value()) {
+       metadataMap["width"] = std::to_string(*streamMetadata.width);
+     }
+     if (streamMetadata.height.has_value()) {
+       metadataMap["height"] = std::to_string(*streamMetadata.height);
+     }
+     if (streamMetadata.averageFpsFromHeader.has_value()) {
+       metadataMap["averageFpsFromHeader"] =
+           std::to_string(*streamMetadata.averageFpsFromHeader);
+     }
+   }
+   if (videoMetadata.bestVideoStreamIndex.has_value()) {
+     metadataMap["bestVideoStreamIndex"] =
+         std::to_string(*videoMetadata.bestVideoStreamIndex);
+   }
+   if (videoMetadata.bestAudioStreamIndex.has_value()) {
+     metadataMap["bestAudioStreamIndex"] =
+         std::to_string(*videoMetadata.bestAudioStreamIndex);
+   }
+
+   return mapToJson(metadataMap);
+ }
+
+ // Get the container metadata as a string.
+ std::string get_container_json_metadata(at::Tensor& decoder) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+
+   auto containerMetadata = videoDecoder->getContainerMetadata();
+
+   std::map<std::string, std::string> map;
+
+   if (containerMetadata.durationSecondsFromHeader.has_value()) {
+     map["durationSecondsFromHeader"] =
+         std::to_string(*containerMetadata.durationSecondsFromHeader);
+   }
+
+   if (containerMetadata.bitRate.has_value()) {
+     map["bitRate"] = std::to_string(*containerMetadata.bitRate);
+   }
+
+   if (containerMetadata.bestVideoStreamIndex.has_value()) {
+     map["bestVideoStreamIndex"] =
+         std::to_string(*containerMetadata.bestVideoStreamIndex);
+   }
+   if (containerMetadata.bestAudioStreamIndex.has_value()) {
+     map["bestAudioStreamIndex"] =
+         std::to_string(*containerMetadata.bestAudioStreamIndex);
+   }
+
+   map["numStreams"] =
+       std::to_string(containerMetadata.allStreamMetadata.size());
+
+   return mapToJson(map);
+ }
+
+ // Get the stream metadata as a string.
+ std::string get_stream_json_metadata(
+     at::Tensor& decoder,
+     int64_t stream_index) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   auto allStreamMetadata =
+       videoDecoder->getContainerMetadata().allStreamMetadata;
+   if (stream_index < 0 ||
+       stream_index >= static_cast<int64_t>(allStreamMetadata.size())) {
+     throw std::out_of_range(
+         "stream_index out of bounds: " + std::to_string(stream_index));
+   }
+
+   auto streamMetadata = allStreamMetadata[stream_index];
+
+   std::map<std::string, std::string> map;
+
+   if (streamMetadata.durationSecondsFromHeader.has_value()) {
+     map["durationSecondsFromHeader"] =
+         std::to_string(*streamMetadata.durationSecondsFromHeader);
+   }
+   if (streamMetadata.bitRate.has_value()) {
+     map["bitRate"] = std::to_string(*streamMetadata.bitRate);
+   }
+   if (streamMetadata.numFramesFromContent.has_value()) {
+     map["numFramesFromContent"] =
+         std::to_string(*streamMetadata.numFramesFromContent);
+   }
+   if (streamMetadata.numFramesFromHeader.has_value()) {
+     map["numFramesFromHeader"] =
+         std::to_string(*streamMetadata.numFramesFromHeader);
+   }
+   if (streamMetadata.beginStreamSecondsFromHeader.has_value()) {
+     map["beginStreamSecondsFromHeader"] =
+         std::to_string(*streamMetadata.beginStreamSecondsFromHeader);
+   }
+   if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) {
+     map["beginStreamSecondsFromContent"] =
+         std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent);
+   }
+   if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) {
+     map["endStreamSecondsFromContent"] =
+         std::to_string(*streamMetadata.endStreamPtsSecondsFromContent);
+   }
+   if (streamMetadata.codecName.has_value()) {
+     map["codec"] = quoteValue(streamMetadata.codecName.value());
+   }
+   if (streamMetadata.width.has_value()) {
+     map["width"] = std::to_string(*streamMetadata.width);
+   }
+   if (streamMetadata.height.has_value()) {
+     map["height"] = std::to_string(*streamMetadata.height);
+   }
+   if (streamMetadata.sampleAspectRatio.has_value()) {
+     map["sampleAspectRatioNum"] =
+         std::to_string((*streamMetadata.sampleAspectRatio).num);
+     map["sampleAspectRatioDen"] =
+         std::to_string((*streamMetadata.sampleAspectRatio).den);
+   }
+   if (streamMetadata.averageFpsFromHeader.has_value()) {
+     map["averageFpsFromHeader"] =
+         std::to_string(*streamMetadata.averageFpsFromHeader);
+   }
+   if (streamMetadata.sampleRate.has_value()) {
+     map["sampleRate"] = std::to_string(*streamMetadata.sampleRate);
+   }
+   if (streamMetadata.numChannels.has_value()) {
+     map["numChannels"] = std::to_string(*streamMetadata.numChannels);
+   }
+   if (streamMetadata.sampleFormat.has_value()) {
+     map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value());
+   }
+   if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) {
+     map["mediaType"] = quoteValue("video");
+   } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) {
+     map["mediaType"] = quoteValue("audio");
+   } else {
+     map["mediaType"] = quoteValue("other");
+   }
+   return mapToJson(map);
+ }
+
+ // Returns version information about the various FFMPEG libraries that are
+ // loaded in the program's address space.
+ // TODO: ideally we'd have a more robust way of getting the ffmpeg version;
+ // we're using av_version_info(), which is not standardized and shouldn't be
+ // parsed by code (which we do!). See
+ // https://github.com/pytorch/torchcodec/issues/100
+ std::string _get_json_ffmpeg_library_versions() {
+   std::stringstream ss;
+   ss << "{\n";
+
+   unsigned int version = avfilter_version();
+   ss << "\"libavfilter\": [" << AV_VERSION_MAJOR(version) << ", "
+      << AV_VERSION_MINOR(version) << ", " << AV_VERSION_MICRO(version)
+      << "],\n";
+   version = avutil_version();
+   ss << "\"libavutil\": [" << AV_VERSION_MAJOR(version) << ", "
+      << AV_VERSION_MINOR(version) << ", " << AV_VERSION_MICRO(version)
+      << "],\n";
+   version = avcodec_version();
+   ss << "\"libavcodec\": [" << AV_VERSION_MAJOR(version) << ", "
+      << AV_VERSION_MINOR(version) << ", " << AV_VERSION_MICRO(version)
+      << "],\n";
+   version = avformat_version();
+   ss << "\"libavformat\": [" << AV_VERSION_MAJOR(version) << ", "
+      << AV_VERSION_MINOR(version) << ", " << AV_VERSION_MICRO(version)
+      << "],\n";
+   ss << "\"ffmpeg_version\": \"" << av_version_info() << "\"\n";
+   ss << "}\n";
+
+   return ss.str();
+ }
+
+ // Scans video packets to get more accurate metadata like frame count, exact
+ // keyframe positions, etc. Exact keyframe positions are useful for efficient
+ // accurate seeking. Note that this function reads the entire video but it does
+ // not decode frames. Reading a video file is much cheaper than decoding it.
+ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
+   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+   videoDecoder->scanFileAndUpdateMetadataAndIndex();
+ }
+
+ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
+   m.impl("create_from_file", &create_from_file);
+   m.impl("create_from_tensor", &create_from_tensor);
+   m.impl("_convert_to_tensor", &_convert_to_tensor);
+   m.impl(
+       "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
+ }
+
+ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
+   m.impl("encode_audio_to_file", &encode_audio_to_file);
+   m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
+   m.impl("seek_to_pts", &seek_to_pts);
+   m.impl("add_video_stream", &add_video_stream);
+   m.impl("_add_video_stream", &_add_video_stream);
+   m.impl("add_audio_stream", &add_audio_stream);
+   m.impl("get_next_frame", &get_next_frame);
+   m.impl("_get_key_frame_indices", &_get_key_frame_indices);
+   m.impl("get_json_metadata", &get_json_metadata);
+   m.impl("get_container_json_metadata", &get_container_json_metadata);
+   m.impl("get_stream_json_metadata", &get_stream_json_metadata);
+   m.impl("get_frame_at_pts", &get_frame_at_pts);
+   m.impl("get_frame_at_index", &get_frame_at_index);
+   m.impl("get_frames_at_indices", &get_frames_at_indices);
+   m.impl("get_frames_in_range", &get_frames_in_range);
+   m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
+   m.impl("get_frames_by_pts_in_range_audio", &get_frames_by_pts_in_range_audio);
+   m.impl("get_frames_by_pts", &get_frames_by_pts);
+   m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
+   m.impl(
+       "scan_all_streams_to_update_metadata",
+       &scan_all_streams_to_update_metadata);
+ }
+
+ } // namespace facebook::torchcodec
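Taken together, a minimal consumer sketch of this translation unit (file name hypothetical; in the shipped wheel these ops are reached from Python through torchcodec._core.ops rather than called directly):

    // Sketch only: decode the first frame of a hypothetical file on CPU.
    at::Tensor dec = create_from_file("video.mp4", "exact");
    add_video_stream(dec);  // default options: best video stream is selected
    auto [frame, pts, duration] = get_next_frame(dec);
    // frame is the decoded image; pts and duration are float64 scalar tensors.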