torchcodec-0.7.0-cp311-cp311-win_amd64.whl

This diff represents the content of a publicly available package version that has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.

Potentially problematic release: this version of torchcodec might be problematic.

Files changed (67)
  1. torchcodec/__init__.py +16 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +123 -0
  7. torchcodec/_core/AVIOTensorContext.h +43 -0
  8. torchcodec/_core/CMakeLists.txt +292 -0
  9. torchcodec/_core/Cache.h +138 -0
  10. torchcodec/_core/CpuDeviceInterface.cpp +266 -0
  11. torchcodec/_core/CpuDeviceInterface.h +70 -0
  12. torchcodec/_core/CudaDeviceInterface.cpp +514 -0
  13. torchcodec/_core/CudaDeviceInterface.h +37 -0
  14. torchcodec/_core/DeviceInterface.cpp +79 -0
  15. torchcodec/_core/DeviceInterface.h +67 -0
  16. torchcodec/_core/Encoder.cpp +514 -0
  17. torchcodec/_core/Encoder.h +123 -0
  18. torchcodec/_core/FFMPEGCommon.cpp +421 -0
  19. torchcodec/_core/FFMPEGCommon.h +227 -0
  20. torchcodec/_core/FilterGraph.cpp +142 -0
  21. torchcodec/_core/FilterGraph.h +45 -0
  22. torchcodec/_core/Frame.cpp +32 -0
  23. torchcodec/_core/Frame.h +118 -0
  24. torchcodec/_core/Metadata.h +72 -0
  25. torchcodec/_core/SingleStreamDecoder.cpp +1715 -0
  26. torchcodec/_core/SingleStreamDecoder.h +380 -0
  27. torchcodec/_core/StreamOptions.h +53 -0
  28. torchcodec/_core/ValidationUtils.cpp +35 -0
  29. torchcodec/_core/ValidationUtils.h +21 -0
  30. torchcodec/_core/__init__.py +40 -0
  31. torchcodec/_core/_metadata.py +317 -0
  32. torchcodec/_core/custom_ops.cpp +727 -0
  33. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +300 -0
  34. torchcodec/_core/ops.py +455 -0
  35. torchcodec/_core/pybind_ops.cpp +87 -0
  36. torchcodec/_frame.py +145 -0
  37. torchcodec/_internally_replaced_utils.py +67 -0
  38. torchcodec/_samplers/__init__.py +7 -0
  39. torchcodec/_samplers/video_clip_sampler.py +430 -0
  40. torchcodec/decoders/__init__.py +11 -0
  41. torchcodec/decoders/_audio_decoder.py +177 -0
  42. torchcodec/decoders/_decoder_utils.py +52 -0
  43. torchcodec/decoders/_video_decoder.py +464 -0
  44. torchcodec/encoders/__init__.py +1 -0
  45. torchcodec/encoders/_audio_encoder.py +150 -0
  46. torchcodec/libtorchcodec_core4.dll +0 -0
  47. torchcodec/libtorchcodec_core5.dll +0 -0
  48. torchcodec/libtorchcodec_core6.dll +0 -0
  49. torchcodec/libtorchcodec_core7.dll +0 -0
  50. torchcodec/libtorchcodec_custom_ops4.dll +0 -0
  51. torchcodec/libtorchcodec_custom_ops5.dll +0 -0
  52. torchcodec/libtorchcodec_custom_ops6.dll +0 -0
  53. torchcodec/libtorchcodec_custom_ops7.dll +0 -0
  54. torchcodec/libtorchcodec_pybind_ops4.pyd +0 -0
  55. torchcodec/libtorchcodec_pybind_ops5.pyd +0 -0
  56. torchcodec/libtorchcodec_pybind_ops6.pyd +0 -0
  57. torchcodec/libtorchcodec_pybind_ops7.pyd +0 -0
  58. torchcodec/samplers/__init__.py +2 -0
  59. torchcodec/samplers/_common.py +84 -0
  60. torchcodec/samplers/_index_based.py +287 -0
  61. torchcodec/samplers/_time_based.py +350 -0
  62. torchcodec/version.py +2 -0
  63. torchcodec-0.7.0.dist-info/METADATA +242 -0
  64. torchcodec-0.7.0.dist-info/RECORD +67 -0
  65. torchcodec-0.7.0.dist-info/WHEEL +5 -0
  66. torchcodec-0.7.0.dist-info/licenses/LICENSE +28 -0
  67. torchcodec-0.7.0.dist-info/top_level.txt +2 -0
torchcodec/decoders/_decoder_utils.py
@@ -0,0 +1,52 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import io
+ from pathlib import Path
+
+ from typing import Union
+
+ from torch import Tensor
+ from torchcodec import _core as core
+
+ ERROR_REPORTING_INSTRUCTIONS = """
+ This should never happen. Please report an issue following the steps in
+ https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml.
+ """
+
+
+ def create_decoder(
+     *,
+     source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
+     seek_mode: str,
+ ) -> Tensor:
+     if isinstance(source, str):
+         return core.create_from_file(source, seek_mode)
+     elif isinstance(source, Path):
+         return core.create_from_file(str(source), seek_mode)
+     elif isinstance(source, io.RawIOBase) or isinstance(source, io.BufferedReader):
+         return core.create_from_file_like(source, seek_mode)
+     elif isinstance(source, bytes):
+         return core.create_from_bytes(source, seek_mode)
+     elif isinstance(source, Tensor):
+         return core.create_from_tensor(source, seek_mode)
+     elif isinstance(source, io.TextIOBase):
+         raise TypeError(
+             "source is for reading text, likely from open(..., 'r'). Try with 'rb' for binary reading?"
+         )
+     elif hasattr(source, "read") and hasattr(source, "seek"):
+         # This check must be after checking for text-based reading. Also placing
+         # it last in general to be defensive: hasattr is a blunt instrument. We
+         # could use the inspect module to check for methods with the right
+         # signature.
+         return core.create_from_file_like(source, seek_mode)
+
+     raise TypeError(
+         f"Unknown source type: {type(source)}. "
+         "Supported types are str, Path, bytes, Tensor and file-like objects with "
+         "read(self, size: int) -> bytes and "
+         "seek(self, offset: int, whence: int) -> int methods."
+     )
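
The dispatch above means file-like sources must be opened in binary mode: a text-mode handle hits the io.TextIOBase branch and raises a TypeError suggesting 'rb'. A minimal sketch of that behaviour, assuming a placeholder path video.mp4 and the public VideoDecoder wrapper (defined in torchcodec/decoders/_video_decoder.py, shown next):

    from torchcodec.decoders import VideoDecoder

    with open("video.mp4", "rb") as f:  # binary mode provides read()/seek() returning bytes
        decoder = VideoDecoder(f)       # routed through core.create_from_file_like
        first_frame = decoder[0]
    # open("video.mp4", "r") would instead raise TypeError from the io.TextIOBase branch.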
torchcodec/decoders/_video_decoder.py
@@ -0,0 +1,464 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import io
+ import json
+ import numbers
+ from pathlib import Path
+ from typing import Literal, Optional, Tuple, Union
+
+ import torch
+ from torch import device as torch_device, Tensor
+
+ from torchcodec import _core as core, Frame, FrameBatch
+ from torchcodec.decoders._decoder_utils import (
+     create_decoder,
+     ERROR_REPORTING_INSTRUCTIONS,
+ )
+
+
+ class VideoDecoder:
+     """A single-stream video decoder.
+
+     Args:
+         source (str, ``pathlib.Path``, bytes, ``torch.Tensor`` or file-like object): The source of the video:
+
+             - If ``str``: a local path or a URL to a video file.
+             - If ``pathlib.Path``: a path to a local video file.
+             - If ``bytes`` object or ``torch.Tensor``: the raw encoded video data.
+             - If file-like object: we read video data from the object on demand. The object must
+               expose the methods `read(self, size: int) -> bytes` and
+               `seek(self, offset: int, whence: int) -> int`. Read more in:
+               :ref:`sphx_glr_generated_examples_decoding_file_like.py`.
+         stream_index (int, optional): Specifies which stream in the video to decode frames from.
+             Note that this index is absolute across all media types. If left unspecified, then
+             the :term:`best stream` is used.
+         dimension_order (str, optional): The dimension order of the decoded frames.
+             This can be either "NCHW" (default) or "NHWC", where N is the batch
+             size, C is the number of channels, H is the height, and W is the
+             width of the frames.
+
+             .. note::
+
+                 Frames are natively decoded in NHWC format by the underlying
+                 FFmpeg implementation. Converting those into NCHW format is a
+                 cheap no-copy operation that allows these frames to be
+                 transformed using the `torchvision transforms
+                 <https://pytorch.org/vision/stable/transforms.html>`_.
+         num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
+             Use 1 for single-threaded decoding, which may be best if you are running multiple
+             instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
+             decoding, which is best if you are running a single instance of ``VideoDecoder``.
+             Passing 0 lets FFmpeg decide on the number of threads.
+             Default: 1.
+         device (str or torch.device, optional): The device to use for decoding. Default: "cpu".
+         seek_mode (str, optional): Determines if frame access will be "exact" or
+             "approximate". Exact guarantees that requesting frame i will always
+             return frame i, but doing so requires an initial :term:`scan` of the
+             file. Approximate is faster as it avoids scanning the file, but less
+             accurate as it uses the file's metadata to calculate where i
+             probably is. Default: "exact".
+             Read more about this parameter in:
+             :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
+
+     Attributes:
+         metadata (VideoStreamMetadata): Metadata of the video stream.
+         stream_index (int): The stream index that this decoder is retrieving frames from. If a
+             stream index was provided at initialization, this is the same value. If it was left
+             unspecified, this is the :term:`best stream`.
+     """
+
+     def __init__(
+         self,
+         source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
+         *,
+         stream_index: Optional[int] = None,
+         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
+         num_ffmpeg_threads: int = 1,
+         device: Optional[Union[str, torch_device]] = "cpu",
+         seek_mode: Literal["exact", "approximate"] = "exact",
+     ):
+         torch._C._log_api_usage_once("torchcodec.decoders.VideoDecoder")
+         allowed_seek_modes = ("exact", "approximate")
+         if seek_mode not in allowed_seek_modes:
+             raise ValueError(
+                 f"Invalid seek mode ({seek_mode}). "
+                 f"Supported values are {', '.join(allowed_seek_modes)}."
+             )
+
+         custom_frame_mappings = None
+         # Validate seek_mode and custom_frame_mappings are not mismatched
+         if custom_frame_mappings is not None and seek_mode == "approximate":
+             raise ValueError(
+                 "custom_frame_mappings is incompatible with seek_mode='approximate'. "
+                 "Use seek_mode='custom_frame_mappings' or leave it unspecified to automatically use custom frame mappings."
+             )
+
+         # Auto-select custom_frame_mappings seek_mode and process data when mappings are provided
+         custom_frame_mappings_data = None
+         if custom_frame_mappings is not None:
+             seek_mode = "custom_frame_mappings"  # type: ignore[assignment]
+             custom_frame_mappings_data = _read_custom_frame_mappings(
+                 custom_frame_mappings
+             )
+
+         self._decoder = create_decoder(source=source, seek_mode=seek_mode)
+
+         allowed_dimension_orders = ("NCHW", "NHWC")
+         if dimension_order not in allowed_dimension_orders:
+             raise ValueError(
+                 f"Invalid dimension order ({dimension_order}). "
+                 f"Supported values are {', '.join(allowed_dimension_orders)}."
+             )
+
+         if num_ffmpeg_threads is None:
+             raise ValueError(f"{num_ffmpeg_threads = } should be an int.")
+
+         if isinstance(device, torch_device):
+             device = str(device)
+
+         core.add_video_stream(
+             self._decoder,
+             stream_index=stream_index,
+             dimension_order=dimension_order,
+             num_threads=num_ffmpeg_threads,
+             device=device,
+             custom_frame_mappings=custom_frame_mappings_data,
+         )
+
+         (
+             self.metadata,
+             self.stream_index,
+             self._begin_stream_seconds,
+             self._end_stream_seconds,
+             self._num_frames,
+         ) = _get_and_validate_stream_metadata(
+             decoder=self._decoder, stream_index=stream_index
+         )
+
+     def __len__(self) -> int:
+         return self._num_frames
+
+     def _getitem_int(self, key: int) -> Tensor:
+         assert isinstance(key, int)
+
+         frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
+         return frame_data
+
+     def _getitem_slice(self, key: slice) -> Tensor:
+         assert isinstance(key, slice)
+
+         start, stop, step = key.indices(len(self))
+         frame_data, *_ = core.get_frames_in_range(
+             self._decoder,
+             start=start,
+             stop=stop,
+             step=step,
+         )
+         return frame_data
+
+     def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
+         """Return frame or frames as tensors, at the given index or range.
+
+         .. note::
+
+             If you need to decode multiple frames, we recommend using the batch
+             methods instead, since they are faster:
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_at`, and
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_in_range`.
+
+         Args:
+             key(int or slice): The index or range of frame(s) to retrieve.
+
+         Returns:
+             torch.Tensor: The frame or frames at the given index or range.
+         """
+         if isinstance(key, numbers.Integral):
+             return self._getitem_int(int(key))
+         elif isinstance(key, slice):
+             return self._getitem_slice(key)
+
+         raise TypeError(
+             f"Unsupported key type: {type(key)}. Supported types are int and slice."
+         )
+
+     def _get_key_frame_indices(self) -> list[int]:
+         return core._get_key_frame_indices(self._decoder)
+
+     def get_frame_at(self, index: int) -> Frame:
+         """Return a single frame at the given index.
+
+         .. note::
+
+             If you need to decode multiple frames, we recommend using the batch
+             methods instead, since they are faster:
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_at`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_in_range`.
+
+         Args:
+             index (int): The index of the frame to retrieve.
+
+         Returns:
+             Frame: The frame at the given index.
+         """
+         data, pts_seconds, duration_seconds = core.get_frame_at_index(
+             self._decoder, frame_index=index
+         )
+         return Frame(
+             data=data,
+             pts_seconds=pts_seconds.item(),
+             duration_seconds=duration_seconds.item(),
+         )
+
+     def get_frames_at(self, indices: list[int]) -> FrameBatch:
+         """Return frames at the given indices.
+
+         Args:
+             indices (list of int): The indices of the frames to retrieve.
+
+         Returns:
+             FrameBatch: The frames at the given indices.
+         """
+         if isinstance(indices, torch.Tensor):
+             # TODO we should avoid converting tensors to lists and just let the
+             # core ops and C++ code natively accept tensors. See
+             # https://github.com/pytorch/torchcodec/issues/879
+             indices = indices.to(torch.int).tolist()
+
+         data, pts_seconds, duration_seconds = core.get_frames_at_indices(
+             self._decoder, frame_indices=indices
+         )
+         return FrameBatch(
+             data=data,
+             pts_seconds=pts_seconds,
+             duration_seconds=duration_seconds,
+         )
+
+     def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
+         """Return multiple frames at the given index range.
+
+         Frames are in [start, stop).
+
+         Args:
+             start (int): Index of the first frame to retrieve.
+             stop (int): End of indexing range (exclusive, as per Python
+                 conventions).
+             step (int, optional): Step size between frames. Default: 1.
+
+         Returns:
+             FrameBatch: The frames within the specified range.
+         """
+         # Adjust start / stop indices to enable indexing semantics, ex. [-10, 1000] returns the last 10 frames
+         start, stop, step = slice(start, stop, step).indices(self._num_frames)
+         frames = core.get_frames_in_range(
+             self._decoder,
+             start=start,
+             stop=stop,
+             step=step,
+         )
+         return FrameBatch(*frames)
+
+     def get_frame_played_at(self, seconds: float) -> Frame:
+         """Return a single frame played at the given timestamp in seconds.
+
+         .. note::
+
+             If you need to decode multiple frames, we recommend using the batch
+             methods instead, since they are faster:
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_at`,
+             :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_in_range`.
+
+         Args:
+             seconds (float): The time stamp in seconds when the frame is played.
+
+         Returns:
+             Frame: The frame that is played at ``seconds``.
+         """
+         if not self._begin_stream_seconds <= seconds < self._end_stream_seconds:
+             raise IndexError(
+                 f"Invalid pts in seconds: {seconds}. "
+                 f"It must be greater than or equal to {self._begin_stream_seconds} "
+                 f"and less than {self._end_stream_seconds}."
+             )
+         data, pts_seconds, duration_seconds = core.get_frame_at_pts(
+             self._decoder, seconds
+         )
+         return Frame(
+             data=data,
+             pts_seconds=pts_seconds.item(),
+             duration_seconds=duration_seconds.item(),
+         )
+
+     def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
+         """Return frames played at the given timestamps in seconds.
+
+         Args:
+             seconds (list of float): The timestamps in seconds when the frames are played.
+
+         Returns:
+             FrameBatch: The frames that are played at ``seconds``.
+         """
+         if isinstance(seconds, torch.Tensor):
+             # TODO we should avoid converting tensors to lists and just let the
+             # core ops and C++ code natively accept tensors. See
+             # https://github.com/pytorch/torchcodec/issues/879
+             seconds = seconds.to(torch.float).tolist()
+
+         data, pts_seconds, duration_seconds = core.get_frames_by_pts(
+             self._decoder, timestamps=seconds
+         )
+         return FrameBatch(
+             data=data,
+             pts_seconds=pts_seconds,
+             duration_seconds=duration_seconds,
+         )
+
+     def get_frames_played_in_range(
+         self, start_seconds: float, stop_seconds: float
+     ) -> FrameBatch:
+         """Returns multiple frames in the given range.
+
+         Frames are in the half open range [start_seconds, stop_seconds). Each
+         returned frame's :term:`pts`, in seconds, is inside of the half open
+         range.
+
+         Args:
+             start_seconds (float): Time, in seconds, of the start of the
+                 range.
+             stop_seconds (float): Time, in seconds, of the end of the
+                 range. As a half open range, the end is excluded.
+
+         Returns:
+             FrameBatch: The frames within the specified range.
+         """
+         if not start_seconds <= stop_seconds:
+             raise ValueError(
+                 f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
+             )
+         if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
+             raise ValueError(
+                 f"Invalid start seconds: {start_seconds}. "
+                 f"It must be greater than or equal to {self._begin_stream_seconds} "
+                 f"and less than or equal to {self._end_stream_seconds}."
+             )
+         if not stop_seconds <= self._end_stream_seconds:
+             raise ValueError(
+                 f"Invalid stop seconds: {stop_seconds}. "
+                 f"It must be less than or equal to {self._end_stream_seconds}."
+             )
+         frames = core.get_frames_by_pts_in_range(
+             self._decoder,
+             start_seconds=start_seconds,
+             stop_seconds=stop_seconds,
+         )
+         return FrameBatch(*frames)
+
+
+ def _get_and_validate_stream_metadata(
+     *,
+     decoder: Tensor,
+     stream_index: Optional[int] = None,
+ ) -> Tuple[core._metadata.VideoStreamMetadata, int, float, float, int]:
+
+     container_metadata = core.get_container_metadata(decoder)
+
+     if stream_index is None:
+         if (stream_index := container_metadata.best_video_stream_index) is None:
+             raise ValueError(
+                 "The best video stream is unknown and there is no specified stream. "
+                 + ERROR_REPORTING_INSTRUCTIONS
+             )
+
+     metadata = container_metadata.streams[stream_index]
+     assert isinstance(metadata, core._metadata.VideoStreamMetadata)  # mypy
+
+     if metadata.begin_stream_seconds is None:
+         raise ValueError(
+             "The minimum pts value in seconds is unknown. "
+             + ERROR_REPORTING_INSTRUCTIONS
+         )
+     begin_stream_seconds = metadata.begin_stream_seconds
+
+     if metadata.end_stream_seconds is None:
+         raise ValueError(
+             "The maximum pts value in seconds is unknown. "
+             + ERROR_REPORTING_INSTRUCTIONS
+         )
+     end_stream_seconds = metadata.end_stream_seconds
+
+     if metadata.num_frames is None:
+         raise ValueError(
+             "The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
+         )
+     num_frames = metadata.num_frames
+
+     return (
+         metadata,
+         stream_index,
+         begin_stream_seconds,
+         end_stream_seconds,
+         num_frames,
+     )
+
+
+ def _read_custom_frame_mappings(
+     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
+ ) -> tuple[Tensor, Tensor, Tensor]:
+     """Parse custom frame mappings from JSON data and extract frame metadata.
+
+     Args:
+         custom_frame_mappings: JSON data containing frame metadata, provided as:
+             - A JSON string (str, bytes)
+             - A file-like object with a read() method
+
+     Returns:
+         A tuple of three tensors:
+         - all_frames (Tensor): Presentation timestamps (PTS) for each frame
+         - is_key_frame (Tensor): Boolean tensor indicating which frames are key frames
+         - duration (Tensor): Duration of each frame
+     """
+     try:
+         input_data = (
+             json.load(custom_frame_mappings)
+             if hasattr(custom_frame_mappings, "read")
+             else json.loads(custom_frame_mappings)
+         )
+     except json.JSONDecodeError as e:
+         raise ValueError(
+             f"Invalid custom frame mappings: {e}. It should be a valid JSON string or a file-like object."
+         ) from e
+
+     if not input_data or "frames" not in input_data:
+         raise ValueError(
+             "Invalid custom frame mappings. The input is empty or missing the required 'frames' key."
+         )
+
+     first_frame = input_data["frames"][0]
+     pts_key = next((key for key in ("pts", "pkt_pts") if key in first_frame), None)
+     duration_key = next(
+         (key for key in ("duration", "pkt_duration") if key in first_frame), None
+     )
+     key_frame_present = "key_frame" in first_frame
+
+     if not pts_key or not duration_key or not key_frame_present:
+         raise ValueError(
+             "Invalid custom frame mappings. The 'pts'/'pkt_pts', 'duration'/'pkt_duration', and 'key_frame' keys are required in the frame metadata."
+         )
+
+     frame_data = [
+         (float(frame[pts_key]), frame["key_frame"], float(frame[duration_key]))
+         for frame in input_data["frames"]
+     ]
+     all_frames, is_key_frame, duration = map(torch.tensor, zip(*frame_data))
+     if not (len(all_frames) == len(is_key_frame) == len(duration)):
+         raise ValueError("Mismatched lengths in frame index data")
+     return all_frames, is_key_frame, duration
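
The decoder above supports both index-based access (__getitem__, get_frame_at, get_frames_at, get_frames_in_range) and time-based access (get_frame_played_at, get_frames_played_at, get_frames_played_in_range). A minimal usage sketch, assuming a local file at the placeholder path video.mp4:

    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder("video.mp4", seek_mode="exact")
    print(decoder.metadata)                              # VideoStreamMetadata of the selected stream
    first = decoder[0]                                   # single frame tensor, NCHW by default
    batch = decoder.get_frames_at([0, 10, 20])           # FrameBatch; batch methods are faster
    at_two_seconds = decoder.get_frame_played_at(2.0)    # Frame with pts_seconds and duration_seconds
    first_second = decoder.get_frames_played_in_range(0.0, 1.0)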
torchcodec/encoders/__init__.py
@@ -0,0 +1 @@
+ from ._audio_encoder import AudioEncoder  # noqa
torchcodec/encoders/_audio_encoder.py
@@ -0,0 +1,150 @@
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import torch
+ from torch import Tensor
+
+ from torchcodec import _core
+
+
+ class AudioEncoder:
+     """An audio encoder.
+
+     Args:
+         samples (``torch.Tensor``): The samples to encode. This must be a 2D
+             tensor of shape ``(num_channels, num_samples)``, or a 1D tensor in
+             which case ``num_channels = 1`` is assumed. Values must be float
+             values in ``[-1, 1]``.
+         sample_rate (int): The sample rate of the **input** ``samples``. The
+             sample rate of the encoded output can be specified using the
+             encoding methods (``to_file``, etc.).
+     """
+
+     def __init__(self, samples: Tensor, *, sample_rate: int):
+         torch._C._log_api_usage_once("torchcodec.encoders.AudioEncoder")
+         # Some of these checks are also done in C++: it's OK, they're cheap, and
+         # doing them here allows to surface them when the AudioEncoder is
+         # instantiated, rather than later when the encoding methods are called.
+         if not isinstance(samples, Tensor):
+             raise ValueError(
+                 f"Expected samples to be a Tensor, got {type(samples) = }."
+             )
+         if samples.ndim == 1:
+             # make it 2D and assume 1 channel
+             samples = torch.unsqueeze(samples, 0)
+         if samples.ndim != 2:
+             raise ValueError(f"Expected 1D or 2D samples, got {samples.shape = }.")
+         if samples.dtype != torch.float32:
+             raise ValueError(f"Expected float32 samples, got {samples.dtype = }.")
+         if sample_rate <= 0:
+             raise ValueError(f"{sample_rate = } must be > 0.")
+
+         self._samples = samples
+         self._sample_rate = sample_rate
+
+     def to_file(
+         self,
+         dest: Union[str, Path],
+         *,
+         bit_rate: Optional[int] = None,
+         num_channels: Optional[int] = None,
+         sample_rate: Optional[int] = None,
+     ) -> None:
+         """Encode samples into a file.
+
+         Args:
+             dest (str or ``pathlib.Path``): The path to the output file, e.g.
+                 ``audio.mp3``. The extension of the file determines the audio
+                 format and container.
+             bit_rate (int, optional): The output bit rate. Encoders typically
+                 support a finite set of bit rate values, so ``bit_rate`` will be
+                 matched to one of those supported values. The default is chosen
+                 by FFmpeg.
+             num_channels (int, optional): The number of channels of the encoded
+                 output samples. By default, the number of channels of the input
+                 ``samples`` is used.
+             sample_rate (int, optional): The sample rate of the encoded output.
+                 By default, the sample rate of the input ``samples`` is used.
+         """
+         _core.encode_audio_to_file(
+             samples=self._samples,
+             sample_rate=self._sample_rate,
+             filename=str(dest),
+             bit_rate=bit_rate,
+             num_channels=num_channels,
+             desired_sample_rate=sample_rate,
+         )
+
+     def to_tensor(
+         self,
+         format: str,
+         *,
+         bit_rate: Optional[int] = None,
+         num_channels: Optional[int] = None,
+         sample_rate: Optional[int] = None,
+     ) -> Tensor:
+         """Encode samples into raw bytes, as a 1D uint8 Tensor.
+
+         Args:
+             format (str): The format of the encoded samples, e.g. "mp3", "wav"
+                 or "flac".
+             bit_rate (int, optional): The output bit rate. Encoders typically
+                 support a finite set of bit rate values, so ``bit_rate`` will be
+                 matched to one of those supported values. The default is chosen
+                 by FFmpeg.
+             num_channels (int, optional): The number of channels of the encoded
+                 output samples. By default, the number of channels of the input
+                 ``samples`` is used.
+             sample_rate (int, optional): The sample rate of the encoded output.
+                 By default, the sample rate of the input ``samples`` is used.
+
+         Returns:
+             Tensor: The raw encoded bytes as 1D uint8 Tensor.
+         """
+         return _core.encode_audio_to_tensor(
+             samples=self._samples,
+             sample_rate=self._sample_rate,
+             format=format,
+             bit_rate=bit_rate,
+             num_channels=num_channels,
+             desired_sample_rate=sample_rate,
+         )
+
+     def to_file_like(
+         self,
+         file_like,
+         format: str,
+         *,
+         bit_rate: Optional[int] = None,
+         num_channels: Optional[int] = None,
+         sample_rate: Optional[int] = None,
+     ) -> None:
+         """Encode samples into a file-like object.
+
+         Args:
+             file_like: A file-like object that supports ``write()`` and
+                 ``seek()`` methods, such as io.BytesIO(), an open file in binary
+                 write mode, etc. Methods must have the following signature:
+                 ``write(data: bytes) -> int`` and ``seek(offset: int, whence:
+                 int = 0) -> int``.
+             format (str): The format of the encoded samples, e.g. "mp3", "wav"
+                 or "flac".
+             bit_rate (int, optional): The output bit rate. Encoders typically
+                 support a finite set of bit rate values, so ``bit_rate`` will be
+                 matched to one of those supported values. The default is chosen
+                 by FFmpeg.
+             num_channels (int, optional): The number of channels of the encoded
+                 output samples. By default, the number of channels of the input
+                 ``samples`` is used.
+             sample_rate (int, optional): The sample rate of the encoded output.
+                 By default, the sample rate of the input ``samples`` is used.
+         """
+         _core.encode_audio_to_file_like(
+             samples=self._samples,
+             sample_rate=self._sample_rate,
+             format=format,
+             file_like=file_like,
+             bit_rate=bit_rate,
+             num_channels=num_channels,
+             desired_sample_rate=sample_rate,
+         )
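
A minimal AudioEncoder sketch, assuming one second of synthetic mono audio; the sine tone and the output path tone.mp3 are illustrative only:

    import torch
    from torchcodec.encoders import AudioEncoder

    sample_rate = 16_000
    t = torch.arange(sample_rate, dtype=torch.float32) / sample_rate
    samples = torch.sin(2 * torch.pi * 440 * t)    # 1D float32 in [-1, 1], treated as one channel

    encoder = AudioEncoder(samples, sample_rate=sample_rate)
    encoder.to_file("tone.mp3")                    # format and container inferred from the extension
    wav_bytes = encoder.to_tensor(format="wav")    # 1D uint8 tensor of encoded bytes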
Binary files: no textual diff shown.