torchcodec 0.10.0__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. torchcodec/__init__.py +27 -0
  2. torchcodec/_core/AVIOContextHolder.cpp +60 -0
  3. torchcodec/_core/AVIOContextHolder.h +64 -0
  4. torchcodec/_core/AVIOFileLikeContext.cpp +98 -0
  5. torchcodec/_core/AVIOFileLikeContext.h +55 -0
  6. torchcodec/_core/AVIOTensorContext.cpp +130 -0
  7. torchcodec/_core/AVIOTensorContext.h +44 -0
  8. torchcodec/_core/BetaCudaDeviceInterface.cpp +849 -0
  9. torchcodec/_core/BetaCudaDeviceInterface.h +196 -0
  10. torchcodec/_core/CMakeLists.txt +295 -0
  11. torchcodec/_core/CUDACommon.cpp +330 -0
  12. torchcodec/_core/CUDACommon.h +51 -0
  13. torchcodec/_core/Cache.h +124 -0
  14. torchcodec/_core/CpuDeviceInterface.cpp +509 -0
  15. torchcodec/_core/CpuDeviceInterface.h +141 -0
  16. torchcodec/_core/CudaDeviceInterface.cpp +602 -0
  17. torchcodec/_core/CudaDeviceInterface.h +79 -0
  18. torchcodec/_core/DeviceInterface.cpp +117 -0
  19. torchcodec/_core/DeviceInterface.h +191 -0
  20. torchcodec/_core/Encoder.cpp +1054 -0
  21. torchcodec/_core/Encoder.h +192 -0
  22. torchcodec/_core/FFMPEGCommon.cpp +684 -0
  23. torchcodec/_core/FFMPEGCommon.h +314 -0
  24. torchcodec/_core/FilterGraph.cpp +159 -0
  25. torchcodec/_core/FilterGraph.h +59 -0
  26. torchcodec/_core/Frame.cpp +47 -0
  27. torchcodec/_core/Frame.h +72 -0
  28. torchcodec/_core/Metadata.cpp +124 -0
  29. torchcodec/_core/Metadata.h +92 -0
  30. torchcodec/_core/NVCUVIDRuntimeLoader.cpp +320 -0
  31. torchcodec/_core/NVCUVIDRuntimeLoader.h +14 -0
  32. torchcodec/_core/NVDECCache.cpp +60 -0
  33. torchcodec/_core/NVDECCache.h +102 -0
  34. torchcodec/_core/SingleStreamDecoder.cpp +1586 -0
  35. torchcodec/_core/SingleStreamDecoder.h +391 -0
  36. torchcodec/_core/StreamOptions.h +70 -0
  37. torchcodec/_core/Transform.cpp +128 -0
  38. torchcodec/_core/Transform.h +86 -0
  39. torchcodec/_core/ValidationUtils.cpp +35 -0
  40. torchcodec/_core/ValidationUtils.h +21 -0
  41. torchcodec/_core/__init__.py +46 -0
  42. torchcodec/_core/_metadata.py +262 -0
  43. torchcodec/_core/custom_ops.cpp +1090 -0
  44. torchcodec/_core/fetch_and_expose_non_gpl_ffmpeg_libs.cmake +169 -0
  45. torchcodec/_core/nvcuvid_include/cuviddec.h +1374 -0
  46. torchcodec/_core/nvcuvid_include/nvcuvid.h +610 -0
  47. torchcodec/_core/ops.py +605 -0
  48. torchcodec/_core/pybind_ops.cpp +50 -0
  49. torchcodec/_frame.py +146 -0
  50. torchcodec/_internally_replaced_utils.py +68 -0
  51. torchcodec/_samplers/__init__.py +7 -0
  52. torchcodec/_samplers/video_clip_sampler.py +419 -0
  53. torchcodec/decoders/__init__.py +12 -0
  54. torchcodec/decoders/_audio_decoder.py +185 -0
  55. torchcodec/decoders/_decoder_utils.py +113 -0
  56. torchcodec/decoders/_video_decoder.py +601 -0
  57. torchcodec/encoders/__init__.py +2 -0
  58. torchcodec/encoders/_audio_encoder.py +149 -0
  59. torchcodec/encoders/_video_encoder.py +196 -0
  60. torchcodec/libtorchcodec_core4.so +0 -0
  61. torchcodec/libtorchcodec_core5.so +0 -0
  62. torchcodec/libtorchcodec_core6.so +0 -0
  63. torchcodec/libtorchcodec_core7.so +0 -0
  64. torchcodec/libtorchcodec_core8.so +0 -0
  65. torchcodec/libtorchcodec_custom_ops4.so +0 -0
  66. torchcodec/libtorchcodec_custom_ops5.so +0 -0
  67. torchcodec/libtorchcodec_custom_ops6.so +0 -0
  68. torchcodec/libtorchcodec_custom_ops7.so +0 -0
  69. torchcodec/libtorchcodec_custom_ops8.so +0 -0
  70. torchcodec/libtorchcodec_pybind_ops4.so +0 -0
  71. torchcodec/libtorchcodec_pybind_ops5.so +0 -0
  72. torchcodec/libtorchcodec_pybind_ops6.so +0 -0
  73. torchcodec/libtorchcodec_pybind_ops7.so +0 -0
  74. torchcodec/libtorchcodec_pybind_ops8.so +0 -0
  75. torchcodec/samplers/__init__.py +2 -0
  76. torchcodec/samplers/_common.py +84 -0
  77. torchcodec/samplers/_index_based.py +287 -0
  78. torchcodec/samplers/_time_based.py +358 -0
  79. torchcodec/share/cmake/TorchCodec/TorchCodecConfig.cmake +76 -0
  80. torchcodec/share/cmake/TorchCodec/ffmpeg_versions.cmake +122 -0
  81. torchcodec/transforms/__init__.py +12 -0
  82. torchcodec/transforms/_decoder_transforms.py +375 -0
  83. torchcodec/version.py +2 -0
  84. torchcodec-0.10.0.dist-info/METADATA +286 -0
  85. torchcodec-0.10.0.dist-info/RECORD +88 -0
  86. torchcodec-0.10.0.dist-info/WHEEL +5 -0
  87. torchcodec-0.10.0.dist-info/licenses/LICENSE +28 -0
  88. torchcodec-0.10.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,287 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+
5
+ from torchcodec import FrameBatch
6
+ from torchcodec.decoders import VideoDecoder
7
+ from torchcodec.samplers._common import (
8
+ _FRAMEBATCH_RETURN_DOCS,
9
+ _POLICY_FUNCTION_TYPE,
10
+ _POLICY_FUNCTIONS,
11
+ _reshape_4d_framebatch_into_5d,
12
+ _validate_common_params,
13
+ )
14
+
15
+
16
+ def _validate_params_index_based(*, num_clips, num_indices_between_frames):
17
+ if num_clips <= 0:
18
+ raise ValueError(f"num_clips ({num_clips}) must be > 0")
19
+
20
+ if num_indices_between_frames <= 0:
21
+ raise ValueError(
22
+ f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive"
23
+ )
24
+
25
+
26
+ def _validate_sampling_range_index_based(
27
+ *,
28
+ num_indices_between_frames,
29
+ num_frames_per_clip,
30
+ sampling_range_start,
31
+ sampling_range_end,
32
+ num_frames_in_video,
33
+ ):
34
+ if sampling_range_start < 0:
35
+ sampling_range_start = num_frames_in_video + sampling_range_start
36
+
37
+ if sampling_range_start >= num_frames_in_video:
38
+ raise ValueError(
39
+ f"sampling_range_start ({sampling_range_start}) must be smaller than "
40
+ f"the number of frames ({num_frames_in_video})."
41
+ )
42
+
43
+ clip_span = _get_clip_span(
44
+ num_indices_between_frames=num_indices_between_frames,
45
+ num_frames_per_clip=num_frames_per_clip,
46
+ )
47
+
48
+ if sampling_range_end is None:
49
+ sampling_range_end = max(num_frames_in_video - clip_span + 1, 1)
50
+ if sampling_range_start >= sampling_range_end:
51
+ raise ValueError(
52
+ f"We determined that sampling_range_end should be {sampling_range_end}, "
53
+ "but it is smaller than or equal to sampling_range_start "
54
+ f"({sampling_range_start})."
55
+ )
56
+ else:
57
+ if sampling_range_end < 0:
58
+ # Support negative values so that -1 means last frame.
59
+ sampling_range_end = num_frames_in_video + sampling_range_end
60
+ sampling_range_end = min(sampling_range_end, num_frames_in_video)
61
+ if sampling_range_start >= sampling_range_end:
62
+ raise ValueError(
63
+ f"sampling_range_start ({sampling_range_start}) must be smaller than "
64
+ f"sampling_range_end ({sampling_range_end})."
65
+ )
66
+
67
+ return sampling_range_start, sampling_range_end
68
+
69
+
70
+ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip):
71
+ """Return the span of a clip, i.e. the number of frames (or indices)
72
+ between the first and last frame in the clip, both included.
73
+
74
+ This isn't the same as the number of frames in a clip!
75
+ Example: f means a frame in the clip, x means a frame excluded from the clip
76
+ num_frames_per_clip = 4
77
+ num_indices_between_frames = 1, clip = ffff , span = 4
78
+ num_indices_between_frames = 2, clip = fxfxfxf , span = 7
79
+ num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10
80
+ """
81
+ return num_indices_between_frames * (num_frames_per_clip - 1) + 1
82
+
83
+
84
def _build_all_clips_indices(
    *,
    clip_start_indices: torch.Tensor,  # 1D int tensor
    num_frames_per_clip: int,
    num_indices_between_frames: int,
    num_frames_in_video: int,
    policy_fun: _POLICY_FUNCTION_TYPE,
) -> list[int]:
    """Expand clip start indices into the flat list of every clip's frame indices.

    Given starts ``[f_00, f_10, f_20, ...]``, return
    ``[f_00, f_01, ..., f_10, f_11, ...]`` where ``f_ij`` is the index of frame
    ``j`` in clip ``i``. Frames within a clip are ``num_indices_between_frames``
    apart. Indices at or past ``num_frames_in_video`` are dropped; a clip left
    with fewer than ``num_frames_per_clip`` frames is completed by
    ``policy_fun(frame_indices, num_frames_per_clip)`` (the user's policy).
    """
    span = _get_clip_span(
        num_indices_between_frames=num_indices_between_frames,
        num_frames_per_clip=num_frames_per_clip,
    )

    flat_indices: list[int] = []
    # .tolist() gives plain Python ints, same values that range() would obtain
    # from the 0-d tensors via __index__.
    for first in clip_start_indices.tolist():
        stop = min(first + span, num_frames_in_video)
        clip = list(range(first, stop, num_indices_between_frames))
        if len(clip) < num_frames_per_clip:
            clip = policy_fun(clip, num_frames_per_clip)  # type: ignore[assignment]
        flat_indices.extend(clip)
    return flat_indices
118
+
119
+
120
def _generic_index_based_sampler(
    kind: Literal["random", "regular"],
    decoder: VideoDecoder,
    *,
    num_clips: int,
    num_frames_per_clip: int,
    num_indices_between_frames: int,
    sampling_range_start: int,
    sampling_range_end: int | None,  # interval is [start, end).
    # Important note: sampling_range_end defines the upper bound of where a clip
    # can *start*, not where a clip can end.
    policy: Literal["repeat_last", "wrap", "error"],
) -> FrameBatch:
    """Shared implementation of the index-based clip samplers.

    Validates the arguments, then picks ``num_clips`` clip start indices in
    ``[sampling_range_start, sampling_range_end)``: uniformly at random for
    ``kind="random"``, evenly spaced (via ``torch.linspace``) otherwise. Each
    start is expanded into ``num_frames_per_clip`` frame indices, all frames
    are fetched with one ``decoder.get_frames_at`` call, and the result is
    regrouped per clip with ``_reshape_4d_framebatch_into_5d``.
    """
    _validate_common_params(
        decoder=decoder,
        num_frames_per_clip=num_frames_per_clip,
        policy=policy,
    )
    _validate_params_index_based(
        num_clips=num_clips,
        num_indices_between_frames=num_indices_between_frames,
    )

    num_frames_in_video = len(decoder)
    range_start, range_end = _validate_sampling_range_index_based(
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        num_frames_in_video=num_frames_in_video,
    )

    if kind != "random":
        # Note [num clips larger than sampling range]
        # Asking for more clips than there are frames in the sampling range
        # (or in the video) relies on torch.linspace returning duplicated
        # indices, e.g. torch.linspace(0, 10, steps=20, dtype=torch.int) gives
        # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
        # Wrapping around would be an alternative, but duplicates stay closer
        # to the expected "equally spaced indices" sampling.
        starts = torch.linspace(
            range_start,
            range_end - 1,
            steps=num_clips,
            dtype=torch.int,
        )
    else:
        starts = torch.randint(low=range_start, high=range_end, size=(num_clips,))

    frame_indices = _build_all_clips_indices(
        clip_start_indices=starts,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        num_frames_in_video=num_frames_in_video,
        policy_fun=_POLICY_FUNCTIONS[policy],
    )

    return _reshape_4d_framebatch_into_5d(
        frames=decoder.get_frames_at(indices=frame_indices),
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
    )
186
+
187
+
188
def clips_at_random_indices(
    decoder: VideoDecoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    num_indices_between_frames: int = 1,
    sampling_range_start: int = 0,
    sampling_range_end: int | None = None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # The public docstring is assigned to __doc__ later in this module.
    torch._C._log_api_usage_once("torchcodec.samplers.clips_at_random_indices")
    sampler_options = dict(
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
    return _generic_index_based_sampler("random", decoder, **sampler_options)
210
+
211
+
212
def clips_at_regular_indices(
    decoder: VideoDecoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    num_indices_between_frames: int = 1,
    sampling_range_start: int = 0,
    sampling_range_end: int | None = None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # The public docstring is assigned to __doc__ later in this module.
    torch._C._log_api_usage_once("torchcodec.samplers.clips_at_regular_indices")
    sampler_options = dict(
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
    return _generic_index_based_sampler("regular", decoder, **sampler_options)
234
+
235
+
236
# Argument/return documentation shared by both index-based samplers.
_COMMON_DOCS = f"""
    Args:
        decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
            instance to sample clips from.
        num_clips (int, optional): The number of clips to return. Default: 1.
        num_frames_per_clip (int, optional): The number of frames per clip. Default: 1.
        num_indices_between_frames (int, optional): The number of indices between
            the frames *within* a clip. Default: 1, which means frames are
            consecutive. This is sometimes referred to as "dilation".
        sampling_range_start (int, optional): The start of the sampling range,
            which defines the first index that a clip may *start* at. Default:
            0, i.e. the start of the video. Negative values are accepted and
            are equivalent to ``len(video) + val``.
        sampling_range_end (int or None, optional): The end of the sampling
            range, which defines the last index that a clip may *start* at. This
            value is exclusive, i.e. a clip may only start within
            [``sampling_range_start``, ``sampling_range_end``). If None
            (default), the value is set automatically such that the clips never
            span beyond the end of the video. For example, if a video has 100
            frames (the last valid index is 99) and the clips span 10 frames,
            this value is set to 100 - 10 + 1 = 91, so the last clip may start
            at index 90 and end at index 99. Negative values are accepted and
            are equivalent to ``len(video) + val``. When a clip spans beyond the
            end of the video, the ``policy`` parameter defines how to construct
            such a clip.
        policy (str, optional): Defines how to construct clips that span beyond
            the end of the video. This is best described with an example:
            assuming the last valid index in a video is 99, and a clip was
            sampled to start at index 95, with ``num_frames_per_clip=5`` and
            ``num_indices_between_frames=2``, the indices of the frames in the
            clip are supposed to be [95, 97, 99, 101, 103]. But 101 and 103 are
            invalid indices, so the ``policy`` parameter defines how to replace
            those frames with valid indices:

            - "repeat_last": repeats the last valid frame of the clip. We would
              get [95, 97, 99, 99, 99].
            - "wrap": wraps around to the beginning of the clip. We would get
              [95, 97, 99, 95, 97].
            - "error": raises an error.

            Default is "repeat_last". Note that when ``sampling_range_end=None``
            (default), this policy parameter is unlikely to be relevant.

    {_FRAMEBATCH_RETURN_DOCS}
"""

clips_at_random_indices.__doc__ = f"""Sample :term:`clips` at random indices.
{_COMMON_DOCS}
"""


clips_at_regular_indices.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) indices.
{_COMMON_DOCS}
"""
@@ -0,0 +1,358 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+
5
+ from torchcodec import FrameBatch
6
+ from torchcodec.samplers._common import (
7
+ _FRAMEBATCH_RETURN_DOCS,
8
+ _POLICY_FUNCTION_TYPE,
9
+ _POLICY_FUNCTIONS,
10
+ _reshape_4d_framebatch_into_5d,
11
+ _validate_common_params,
12
+ )
13
+
14
+
15
+ def _validate_params_time_based(
16
+ *,
17
+ decoder,
18
+ num_clips,
19
+ seconds_between_clip_starts,
20
+ seconds_between_frames,
21
+ ):
22
+
23
+ if (num_clips is None and seconds_between_clip_starts is None) or (
24
+ num_clips is not None and seconds_between_clip_starts is not None
25
+ ):
26
+ raise ValueError("This is internal only and should never happen.")
27
+
28
+ if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0:
29
+ raise ValueError(
30
+ f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0"
31
+ )
32
+
33
+ if num_clips is not None and num_clips <= 0:
34
+ raise ValueError(f"num_clips ({num_clips}) must be > 0")
35
+
36
+ if decoder.metadata.average_fps is None:
37
+ raise ValueError(
38
+ "Could not infer average fps from video metadata. "
39
+ "Try using an index-based sampler instead."
40
+ )
41
+
42
+ # Note that metadata.begin_stream_seconds is a property that will always yield a valid
43
+ # value; if it is not present in the actual metadata, the metadata object will return 0.
44
+ # Hence, we do not test for it here and only test metadata.end_stream_seconds.
45
+ if decoder.metadata.end_stream_seconds is None:
46
+ raise ValueError(
47
+ "Could not infer stream end from video metadata. "
48
+ "Try using an index-based sampler instead."
49
+ )
50
+
51
+ average_frame_duration_seconds = 1 / decoder.metadata.average_fps
52
+ if seconds_between_frames is None:
53
+ seconds_between_frames = average_frame_duration_seconds
54
+ elif seconds_between_frames <= 0:
55
+ raise ValueError(
56
+ f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0, got"
57
+ )
58
+
59
+ return seconds_between_frames
60
+
61
+
62
+ def _validate_sampling_range_time_based(
63
+ *,
64
+ num_frames_per_clip,
65
+ seconds_between_frames,
66
+ sampling_range_start,
67
+ sampling_range_end,
68
+ begin_stream_seconds,
69
+ end_stream_seconds,
70
+ ):
71
+
72
+ if sampling_range_start is None:
73
+ sampling_range_start = begin_stream_seconds
74
+ else:
75
+ if sampling_range_start < begin_stream_seconds:
76
+ raise ValueError(
77
+ f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
78
+ )
79
+ if sampling_range_start >= end_stream_seconds:
80
+ raise ValueError(
81
+ f"sampling_range_start ({sampling_range_start}) must be smaller than {end_stream_seconds}"
82
+ )
83
+
84
+ if sampling_range_end is None:
85
+ # We allow a clip to start anywhere within
86
+ # [sampling_range_start, sampling_range_end)
87
+ # When sampling_range_end is None, we want to automatically set it to
88
+ # the largest possible value such that the sampled frames in any clip
89
+ # are within the bounds of the video duration (in other words, we don't
90
+ # want to have to resort to the `policy`).
91
+ # I.e. we want to guarantee that for all frames in any clip we have
92
+ # pts < end_stream_seconds.
93
+ #
94
+ # The frames of a clip will be sampled at the following pts:
95
+ # clip_timestamps = [
96
+ # clip_start + 0 * seconds_between_frames,
97
+ # clip_start + 1 * seconds_between_frames,
98
+ # clip_start + 2 * seconds_between_frames,
99
+ # ...
100
+ # clip_start + (num_frames_per_clip - 1) * seconds_between_frames,
101
+ # ]
102
+ # To guarantee that any such value is < end_stream_seconds, we only need
103
+ # to guarantee that
104
+ # clip_start < end_stream_seconds - (num_frames_per_clip - 1) * seconds_between_frames
105
+ #
106
+ # So that's the value of sampling_range_end we want to use.
107
+ sampling_range_end = (
108
+ end_stream_seconds - (num_frames_per_clip - 1) * seconds_between_frames
109
+ )
110
+ elif sampling_range_end <= begin_stream_seconds:
111
+ raise ValueError(
112
+ f"sampling_range_end ({sampling_range_end}) must be at least {begin_stream_seconds}"
113
+ )
114
+
115
+ if sampling_range_start >= sampling_range_end:
116
+ raise ValueError(
117
+ f"sampling_range_start ({sampling_range_start}) must be smaller than sampling_range_end ({sampling_range_end})"
118
+ )
119
+
120
+ sampling_range_end = min(sampling_range_end, end_stream_seconds)
121
+
122
+ return sampling_range_start, sampling_range_end
123
+
124
+
125
def _build_all_clips_timestamps(
    *,
    clip_start_seconds: torch.Tensor,  # 1D float tensor
    num_frames_per_clip: int,
    seconds_between_frames: float,
    end_stream_seconds: float,
    policy_fun: _POLICY_FUNCTION_TYPE,
) -> list[float]:
    """Expand clip start times into the flat list of every clip's sampling timestamps.

    For each clip start ``s``, the candidate timestamps are
    ``s + i * seconds_between_frames`` for ``i`` in
    ``range(num_frames_per_clip)``. Candidates ``>= end_stream_seconds`` are
    dropped. A clip left with fewer than ``num_frames_per_clip`` timestamps is
    completed by ``policy_fun(clip_timestamps, num_frames_per_clip)``.

    Returns:
        The timestamps of all clips, concatenated in clip order.
    """
    all_clips_timestamps: list[float] = []
    # Iterate plain Python floats: iterating the tensor itself would yield 0-d
    # tensors, so the returned list would not really be a list[float], and
    # every per-frame addition and comparison would be a tensor op.
    for start_seconds in clip_start_seconds.tolist():
        clip_timestamps = [
            timestamp
            for i in range(num_frames_per_clip)
            if (timestamp := start_seconds + i * seconds_between_frames)
            < end_stream_seconds
        ]

        if len(clip_timestamps) < num_frames_per_clip:
            clip_timestamps = policy_fun(clip_timestamps, num_frames_per_clip)
        all_clips_timestamps += clip_timestamps

    return all_clips_timestamps
148
+
149
+
150
def _generic_time_based_sampler(
    kind: Literal["random", "regular"],
    decoder,
    *,
    num_clips: int | None,  # mutually exclusive with seconds_between_clip_starts
    seconds_between_clip_starts: float | None,
    num_frames_per_clip: int,
    seconds_between_frames: float | None,
    # None means "beginning", which may not always be 0
    sampling_range_start: float | None,
    sampling_range_end: float | None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    """Shared implementation of the time-based clip samplers.

    ``kind="random"`` draws ``num_clips`` clip starts uniformly from the
    sampling range. Any other kind (``"regular"``) places a clip start every
    ``seconds_between_clip_starts`` seconds, and the number of clips follows
    from that. Everywhere, ``sampling_range_end`` is an *open* upper bound on
    where a clip may start: no clip starts at or above it.
    """
    _validate_common_params(
        decoder=decoder,
        num_frames_per_clip=num_frames_per_clip,
        policy=policy,
    )

    seconds_between_frames = _validate_params_time_based(
        decoder=decoder,
        num_clips=num_clips,
        seconds_between_clip_starts=seconds_between_clip_starts,
        seconds_between_frames=seconds_between_frames,
    )

    range_start, range_end = _validate_sampling_range_time_based(
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        begin_stream_seconds=decoder.metadata.begin_stream_seconds,
        end_stream_seconds=decoder.metadata.end_stream_seconds,
    )

    if kind != "random":
        assert seconds_between_clip_starts is not None  # appease type-checker
        starts = torch.arange(range_start, range_end, seconds_between_clip_starts)
        # torch.arange excludes `end`, but floating-point error can still make
        # it emit values at or above it; keep only starts strictly below
        # range_end.
        if starts[-1] >= range_end:
            starts = starts[starts < range_end]
        num_clips = len(starts)
    else:
        assert num_clips is not None  # appease type-checker
        # torch.rand samples from [0, 1), so every start stays below range_end.
        starts = range_start + torch.rand(num_clips) * (range_end - range_start)

    timestamps = _build_all_clips_timestamps(
        clip_start_seconds=starts,
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        end_stream_seconds=decoder.metadata.end_stream_seconds,
        policy_fun=_POLICY_FUNCTIONS[policy],
    )

    return _reshape_4d_framebatch_into_5d(
        frames=decoder.get_frames_played_at(seconds=timestamps),
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
    )
228
+
229
+
230
def clips_at_random_timestamps(
    decoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    seconds_between_frames: float | None = None,
    # None means "beginning", which may not always be 0
    sampling_range_start: float | None = None,
    sampling_range_end: float | None = None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # The public docstring is assigned to __doc__ later in this module.
    torch._C._log_api_usage_once("torchcodec.samplers.clips_at_random_timestamps")
    sampler_options = dict(
        num_clips=num_clips,
        # The random sampler is driven by num_clips, never by a fixed stride.
        seconds_between_clip_starts=None,
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
    return _generic_time_based_sampler("random", decoder, **sampler_options)
254
+
255
+
256
def clips_at_regular_timestamps(
    decoder,
    *,
    seconds_between_clip_starts: float,
    num_frames_per_clip: int = 1,
    seconds_between_frames: float | None = None,
    # None means "beginning", which may not always be 0
    sampling_range_start: float | None = None,
    sampling_range_end: float | None = None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # The public docstring is assigned to __doc__ later in this module.
    torch._C._log_api_usage_once("torchcodec.samplers.clips_at_regular_timestamps")
    sampler_options = dict(
        # The regular sampler derives the clip count from the stride.
        num_clips=None,
        seconds_between_clip_starts=seconds_between_clip_starts,
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
    return _generic_time_based_sampler("regular", decoder, **sampler_options)
280
+
281
+
282
# Docstring template shared by both time-based samplers; filled with
# str.format(maybe_note=..., num_clips_or_seconds_between_clip_starts=..., return_docs=...).
_COMMON_DOCS = """
    {maybe_note}

    Args:
        decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
            instance to sample clips from.
        {num_clips_or_seconds_between_clip_starts}
        num_frames_per_clip (int, optional): The number of frames per clip. Default: 1.
        seconds_between_frames (float or None, optional): The time (in seconds)
            between each frame within a clip. More accurately, this defines the
            time between the *frame sampling points*, i.e. the timestamps at
            which we sample the frames. Because frames span intervals in time,
            the resulting start of frames within a clip may not be exactly
            spaced by ``seconds_between_frames`` - but on average, they will be.
            Default is None, which is set to the average frame duration
            (``1/average_fps``).
        sampling_range_start (float or None, optional): The start of the
            sampling range, which defines the first timestamp (in seconds) that
            a clip may *start* at. Default: None, which corresponds to the start
            of the video. (Note: some videos start at negative values, which is
            why the default is not 0).
        sampling_range_end (float or None, optional): The end of the sampling
            range, which defines the last timestamp (in seconds) that a clip may
            *start* at. This value is exclusive, i.e. a clip may only start within
            [``sampling_range_start``, ``sampling_range_end``). If None
            (default), the value is set automatically such that the clips never
            span beyond the end of the video, i.e. it is set to
            ``end_stream_seconds - (num_frames_per_clip - 1) *
            seconds_between_frames``. When a clip spans beyond the end of the
            video, the ``policy`` parameter defines how to construct such a clip.
        policy (str, optional): Defines how to construct clips that span beyond
            the end of the video. This is best described with an example:
            assuming the last valid (seekable) timestamp in a video is 10.9, and
            a clip was sampled to start at timestamp 10.5, with
            ``num_frames_per_clip=5`` and ``seconds_between_frames=0.2``, the
            sampling timestamps of the frames in the clip are supposed to be
            [10.5, 10.7, 10.9, 11.1, 11.3]. But 11.1 and 11.3 are invalid
            timestamps, so the ``policy`` parameter defines how to replace those
            frames with valid sampling timestamps:

            - "repeat_last": repeats the last valid frame of the clip. We would
              get frames sampled at timestamps [10.5, 10.7, 10.9, 10.9, 10.9].
            - "wrap": wraps around to the beginning of the clip. We would get
              frames sampled at timestamps [10.5, 10.7, 10.9, 10.5, 10.7].
            - "error": raises an error.

            Default is "repeat_last". Note that when ``sampling_range_end=None``
            (default), this policy parameter is unlikely to be relevant.

    {return_docs}
"""


_NUM_CLIPS_DOCS = """
        num_clips (int, optional): The number of clips to return. Default: 1.
"""
clips_at_random_timestamps.__doc__ = f"""Sample :term:`clips` at random timestamps.
{_COMMON_DOCS.format(maybe_note="", num_clips_or_seconds_between_clip_starts=_NUM_CLIPS_DOCS, return_docs=_FRAMEBATCH_RETURN_DOCS)}
"""


_SECONDS_BETWEEN_CLIP_STARTS = """
        seconds_between_clip_starts (float): The space (in seconds) between each
            clip start.
"""

_NOTE_DOCS = """
    .. note::
        For consistency with existing sampling APIs (such as torchvision), this
        sampler takes a ``seconds_between_clip_starts`` parameter instead of
        ``num_clips``. If you find that supporting ``num_clips`` would be
        useful, please let us know by `opening a feature request
        <https://github.com/pytorch/torchcodec/issues?q=is:open+is:issue>`_.
"""
clips_at_regular_timestamps.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) timestamps.
{_COMMON_DOCS.format(maybe_note=_NOTE_DOCS, num_clips_or_seconds_between_clip_starts=_SECONDS_BETWEEN_CLIP_STARTS, return_docs=_FRAMEBATCH_RETURN_DOCS)}
"""