pixeltable 0.4.8__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +1 -1
- pixeltable/catalog/insertable_table.py +125 -28
- pixeltable/catalog/table.py +10 -1
- pixeltable/config.py +1 -0
- pixeltable/env.py +57 -4
- pixeltable/functions/__init__.py +2 -0
- pixeltable/functions/audio.py +2 -1
- pixeltable/functions/gemini.py +8 -0
- pixeltable/functions/video.py +534 -81
- pixeltable/functions/whisper.py +8 -0
- pixeltable/functions/whisperx.py +177 -0
- pixeltable/{ext/functions → functions}/yolox.py +0 -4
- pixeltable/globals.py +3 -1
- pixeltable/iterators/video.py +138 -0
- pixeltable/metadata/__init__.py +3 -1
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/type_system.py +20 -4
- pixeltable/utils/av.py +111 -0
- pixeltable/utils/code.py +2 -1
- pixeltable/utils/pydantic.py +60 -0
- {pixeltable-0.4.8.dist-info → pixeltable-0.4.10.dist-info}/METADATA +1 -1
- {pixeltable-0.4.8.dist-info → pixeltable-0.4.10.dist-info}/RECORD +26 -24
- pixeltable/ext/__init__.py +0 -17
- pixeltable/ext/functions/__init__.py +0 -11
- pixeltable/ext/functions/whisperx.py +0 -77
- {pixeltable-0.4.8.dist-info → pixeltable-0.4.10.dist-info}/WHEEL +0 -0
- {pixeltable-0.4.8.dist-info → pixeltable-0.4.10.dist-info}/entry_points.txt +0 -0
- {pixeltable-0.4.8.dist-info → pixeltable-0.4.10.dist-info}/licenses/LICENSE +0 -0
pixeltable/functions/whisper.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Optional, Sequence
|
|
|
10
10
|
|
|
11
11
|
import pixeltable as pxt
|
|
12
12
|
from pixeltable.env import Env
|
|
13
|
+
from pixeltable.utils.code import local_public_names
|
|
13
14
|
|
|
14
15
|
if TYPE_CHECKING:
|
|
15
16
|
from whisper import Whisper # type: ignore[import-untyped]
|
|
@@ -90,3 +91,10 @@ def _lookup_model(model_id: str, device: str) -> 'Whisper':
|
|
|
90
91
|
|
|
91
92
|
|
|
92
93
|
# Cache of loaded Whisper models used by `_lookup_model` (presumably keyed by (model_id, device) -- see that function).
_model_cache: dict[tuple[str, str], 'Whisper'] = dict()


# Public API of this module, as computed by `local_public_names`.
__all__ = local_public_names(__name__)


def __dir__() -> list[str]:
    """Restrict `dir()` on this module to the names listed in `__all__`."""
    public_names: list[str] = __all__
    return public_names
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import pixeltable as pxt
|
|
6
|
+
from pixeltable.config import Config
|
|
7
|
+
from pixeltable.functions.util import resolve_torch_device
|
|
8
|
+
from pixeltable.utils.code import local_public_names
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from transformers import Wav2Vec2Model
|
|
12
|
+
from whisperx.asr import FasterWhisperPipeline # type: ignore[import-untyped]
|
|
13
|
+
from whisperx.diarize import DiarizationPipeline # type: ignore[import-untyped]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pxt.udf
def transcribe(
    audio: pxt.Audio,
    *,
    model: str,
    diarize: bool = False,
    compute_type: Optional[str] = None,
    language: Optional[str] = None,
    task: Optional[str] = None,
    chunk_size: Optional[int] = None,
    alignment_model_name: Optional[str] = None,
    interpolate_method: Optional[str] = None,
    return_char_alignments: Optional[bool] = None,
    diarization_model_name: Optional[str] = None,
    num_speakers: Optional[int] = None,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
) -> dict:
    """
    Transcribe an audio file using WhisperX.

    This UDF runs a transcription model _locally_ using the WhisperX library,
    equivalent to the WhisperX `transcribe` function, as described in the
    [WhisperX library documentation](https://github.com/m-bain/whisperX).

    If `diarize=True`, then speaker diarization will also be performed. Several of the UDF parameters are only valid if
    `diarize=True`, as documented in the parameters list below.

    __Requirements:__

    - `pip install whisperx`

    Args:
        audio: The audio file to transcribe.
        model: The name of the model to use for transcription.
        diarize: Whether to perform speaker diarization.
        compute_type: The compute type to use for the model (e.g., `'int8'`, `'float16'`). If `None`,
            defaults to `'float16'` on CUDA devices and `'int8'` otherwise.
        language: The language code for the transcription (e.g., `'en'` for English).
        task: The task to perform (e.g., `'transcribe'` or `'translate'`). Defaults to `'transcribe'`.
        chunk_size: The size of the audio chunks to process, in seconds. Defaults to `30`.
        alignment_model_name: The name of the alignment model to use. If `None`, uses the default model for the given
            language. Only valid if `diarize=True`.
        interpolate_method: The method to use for interpolation of the alignment results. If not specified, uses the
            WhisperX default (`'nearest'`). Only valid if `diarize=True`.
        return_char_alignments: Whether to return character-level alignments. Defaults to `False`.
            Only valid if `diarize=True`.
        diarization_model_name: The name of the diarization model to use. Defaults to
            `pyannote/speaker-diarization-3.1`. Only valid if `diarize=True`.
        num_speakers: The number of speakers to expect in the audio. By default, the model will try to detect the
            number of speakers. Only valid if `diarize=True`.
        min_speakers: If specified, the minimum number of speakers to expect in the audio.
            Only valid if `diarize=True`.
        max_speakers: If specified, the maximum number of speakers to expect in the audio.
            Only valid if `diarize=True`.

    Returns:
        A dictionary containing the audio transcription, the detected `language`, diarization (if enabled), and
        various other metadata.

    Raises:
        pxt.Error: If a diarization-only parameter is set while `diarize=False`.

    Examples:
        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
        of the table `tbl`:

        >>> tbl.add_computed_column(result=transcribe(tbl.audio, model='tiny.en'))

        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
        of the table `tbl`, with speaker diarization enabled, expecting at least 2 speakers:

        >>> tbl.add_computed_column(
        ...     result=transcribe(
        ...         tbl.audio, model='tiny.en', diarize=True, min_speakers=2
        ...     )
        ... )
    """
    import whisperx  # type: ignore[import-untyped]

    if not diarize:
        # Reject diarization-only parameters explicitly rather than silently ignoring them.
        args = locals()
        for param in (
            'alignment_model_name',
            'interpolate_method',
            'return_char_alignments',
            'diarization_model_name',
            'num_speakers',
            'min_speakers',
            'max_speakers',
        ):
            if args[param] is not None:
                raise pxt.Error(f'`{param}` can only be set if `diarize=True`')

    device = resolve_torch_device('auto', allow_mps=False)
    compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
    transcription_model = _lookup_transcription_model(model, device, compute_type)
    audio_array: np.ndarray = whisperx.load_audio(audio)
    kwargs: dict[str, Any] = {'language': language, 'task': task}
    if chunk_size is not None:
        kwargs['chunk_size'] = chunk_size
    result: dict[str, Any] = transcription_model.transcribe(audio_array, batch_size=16, **kwargs)

    if diarize:
        detected_language = result['language']

        # Alignment
        alignment_model, metadata = _lookup_alignment_model(detected_language, device, alignment_model_name)
        kwargs = {}
        if interpolate_method is not None:
            kwargs['interpolate_method'] = interpolate_method
        if return_char_alignments is not None:
            kwargs['return_char_alignments'] = return_char_alignments
        result = whisperx.align(result['segments'], alignment_model, metadata, audio_array, device, **kwargs)

        # Diarization
        diarization_model = _lookup_diarization_model(device, diarization_model_name)
        diarization_segments = diarization_model(
            audio_array, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers
        )
        result = whisperx.assign_word_speakers(diarization_segments, result)

        # The result of `whisperx.align` replaced the transcription result above and does not carry over the
        # detected language; restore it so diarized output has the same `language` key as non-diarized output.
        result.setdefault('language', detected_language)

    return result
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _lookup_transcription_model(model: str, device: str, compute_type: str) -> 'FasterWhisperPipeline':
    """
    Return the WhisperX transcription pipeline for (`model`, `device`, `compute_type`).

    The pipeline is loaded with `whisperx.load_model` on first request and served from `_model_cache` afterwards.
    """
    import whisperx

    cache_key = (model, device, compute_type)
    if cache_key in _model_cache:
        return _model_cache[cache_key]
    pipeline = whisperx.load_model(model, device, compute_type=compute_type)
    _model_cache[cache_key] = pipeline
    return pipeline
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _lookup_alignment_model(language_code: str, device: str, model_name: Optional[str]) -> tuple['Wav2Vec2Model', dict]:
    """
    Return the (alignment model, metadata) pair for `language_code` on `device`.

    `model_name=None` lets `whisperx.load_align_model` choose its default for the language. Results are cached in
    `_alignment_model_cache`.
    """
    import whisperx

    cache_key = (language_code, device, model_name)
    if cache_key in _alignment_model_cache:
        return _alignment_model_cache[cache_key]
    loaded_model, model_metadata = whisperx.load_align_model(
        language_code=language_code, device=device, model_name=model_name
    )
    entry = (loaded_model, model_metadata)
    _alignment_model_cache[cache_key] = entry
    return entry
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _lookup_diarization_model(device: str, model_name: Optional[str]) -> 'DiarizationPipeline':
    """
    Return a WhisperX `DiarizationPipeline` for `device`, constructing and caching it on first use.

    The Hugging Face token is read from the Pixeltable config (section `hf`, key `auth_token`). `model_name` is passed
    through only when given, so that `DiarizationPipeline` otherwise falls back to its own default.
    """
    from whisperx.diarize import DiarizationPipeline

    cache_key = (device, model_name)
    if cache_key in _diarization_model_cache:
        return _diarization_model_cache[cache_key]

    hf_token = Config.get().get_string_value('auth_token', section='hf')
    pipeline_kwargs: dict[str, Any] = dict(device=device, use_auth_token=hf_token)
    if model_name is not None:
        pipeline_kwargs['model_name'] = model_name
    pipeline = DiarizationPipeline(**pipeline_kwargs)
    _diarization_model_cache[cache_key] = pipeline
    return pipeline
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
_model_cache: dict[tuple[str, str, str], 'FasterWhisperPipeline'] = {}
|
|
169
|
+
_alignment_model_cache: dict[tuple[str, str, Optional[str]], tuple['Wav2Vec2Model', dict]] = {}
|
|
170
|
+
_diarization_model_cache: dict[tuple[str, Optional[str]], 'DiarizationPipeline'] = {}
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# Public API of this module, as computed by `local_public_names`.
__all__ = local_public_names(__name__)


def __dir__() -> list[str]:
    """Restrict `dir()` on this module to the names listed in `__all__`."""
    public_names: list[str] = __all__
    return public_names
|
|
@@ -20,8 +20,6 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
|
|
|
20
20
|
Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
|
|
21
21
|
defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
|
|
22
22
|
|
|
23
|
-
YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
|
|
24
|
-
|
|
25
23
|
__Requirements__:
|
|
26
24
|
|
|
27
25
|
- `pip install pixeltable-yolox`
|
|
@@ -55,8 +53,6 @@ def yolo_to_coco(detections: dict) -> list:
|
|
|
55
53
|
"""
|
|
56
54
|
Converts the output of a YOLOX object detection model to COCO format.
|
|
57
55
|
|
|
58
|
-
YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
|
|
59
|
-
|
|
60
56
|
Args:
|
|
61
57
|
detections: The output of a YOLOX object detection model, as returned by `yolox`.
|
|
62
58
|
|
pixeltable/globals.py
CHANGED
|
@@ -3,9 +3,10 @@ from __future__ import annotations
|
|
|
3
3
|
import logging
|
|
4
4
|
import os
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, Optional, Union
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, Optional, Sequence, Union
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
|
+
import pydantic
|
|
9
10
|
from pandas.io.formats.style import Styler
|
|
10
11
|
|
|
11
12
|
from pixeltable import DataFrame, catalog, exceptions as excs, exprs, func, share, type_system as ts
|
|
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
|
|
25
26
|
Path, # OS paths, filenames, URLs
|
|
26
27
|
Iterator[dict[str, Any]], # iterator producing dictionaries of values
|
|
27
28
|
RowData, # list of dictionaries
|
|
29
|
+
Sequence[pydantic.BaseModel], # list of Pydantic models
|
|
28
30
|
DataFrame, # Pixeltable DataFrame
|
|
29
31
|
pd.DataFrame, # pandas DataFrame
|
|
30
32
|
datasets.Dataset,
|
pixeltable/iterators/video.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import math
|
|
3
|
+
import shutil
|
|
4
|
+
import subprocess
|
|
3
5
|
from fractions import Fraction
|
|
4
6
|
from pathlib import Path
|
|
5
7
|
from typing import Any, Optional
|
|
@@ -8,8 +10,11 @@ import av
|
|
|
8
10
|
import pandas as pd
|
|
9
11
|
import PIL.Image
|
|
10
12
|
|
|
13
|
+
import pixeltable as pxt
|
|
11
14
|
import pixeltable.exceptions as excs
|
|
12
15
|
import pixeltable.type_system as ts
|
|
16
|
+
import pixeltable.utils.av as av_utils
|
|
17
|
+
from pixeltable.utils.media_store import TempStore
|
|
13
18
|
|
|
14
19
|
from .base import ComponentIterator
|
|
15
20
|
|
|
@@ -224,3 +229,136 @@ class FrameIterator(ComponentIterator):
|
|
|
224
229
|
# then the iterator will step forward to the desired frame on the subsequent call to next().
|
|
225
230
|
self.container.seek(seek_pos, backward=True, stream=self.container.streams.video[0])
|
|
226
231
|
self.next_pos = pos
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class VideoSplitter(ComponentIterator):
    """
    Iterator over segments of a video file, which is split into fixed-size segments of length `segment_duration`
    seconds.

    Each segment is produced by running `ffmpeg` (which must be on the PATH) and written to a temporary `.mp4` file.
    Iteration stops once a clip would contain no footage beyond what earlier segments already cover.

    Args:
        segment_duration: Video segment duration in seconds. Must be positive.
        overlap: Overlap between consecutive segments in seconds. Must be less than `segment_duration`.
        min_segment_duration: Drop the last segment if it is shorter than `min_segment_duration` seconds.
    """

    # Durations within this many seconds are treated as equal, to absorb floating-point noise in measured clip
    # durations (an exact float comparison can fail to detect the end of the video).
    _DURATION_TOLERANCE = 1e-6

    # Input parameters
    video_path: Path
    segment_duration: float
    overlap: float
    min_segment_duration: float

    # Video metadata
    video_duration: float  # NOTE(review): declared but never assigned by this class
    video_time_base: Fraction
    video_start_time: int

    # position tracking
    next_segment_start: float
    next_segment_start_pts: int
    is_first_segment: bool  # True until the first segment has been returned
    is_exhausted: bool  # set once the whole video is known to be covered

    def __init__(
        self,
        video: str,
        segment_duration: float,
        *,
        overlap: Optional[float] = 0.0,
        min_segment_duration: Optional[float] = 0.0,
    ):
        # Both are nullable in `input_schema()`; an explicit None means "use the default".
        if overlap is None:
            overlap = 0.0
        if min_segment_duration is None:
            min_segment_duration = 0.0

        # Same checks as `output_schema()`; raised as errors (not asserts) so they are not stripped under `python -O`.
        if segment_duration <= 0.0:
            raise excs.Error('segment_duration must be a positive number')
        if segment_duration < min_segment_duration:
            raise excs.Error('segment_duration must be at least min_segment_duration')
        if overlap >= segment_duration:
            raise excs.Error('overlap must be less than segment_duration')

        video_path = Path(video)
        if not video_path.is_file():
            raise excs.Error(f'Video file not found: {video}')

        if not shutil.which('ffmpeg'):
            raise pxt.Error('ffmpeg is not installed or not in PATH. Please install ffmpeg to use VideoSplitter.')

        self.video_path = video_path
        self.segment_duration = segment_duration
        self.overlap = overlap
        self.min_segment_duration = min_segment_duration

        with av.open(str(video_path)) as container:
            if not container.streams.video:
                raise excs.Error(f'No video stream found in {video}')
            video_stream = container.streams.video[0]
            self.video_time_base = video_stream.time_base
            self.video_start_time = video_stream.start_time or 0

        self.next_segment_start = float(self.video_start_time * self.video_time_base)
        self.next_segment_start_pts = self.video_start_time
        self.is_first_segment = True
        self.is_exhausted = False

    @classmethod
    def input_schema(cls) -> dict[str, ts.ColumnType]:
        return {
            'video': ts.VideoType(nullable=False),
            'segment_duration': ts.FloatType(nullable=False),
            'overlap': ts.FloatType(nullable=True),
            'min_segment_duration': ts.FloatType(nullable=True),
        }

    @classmethod
    def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ts.ColumnType], list[str]]:
        """Validates the iterator parameters and returns the schema of the generated segment rows."""
        param_names = ['segment_duration', 'overlap', 'min_segment_duration']
        params = dict(zip(param_names, args))
        params.update(kwargs)

        segment_duration = params['segment_duration']
        # Both are nullable in `input_schema()`; an explicit None means "use the default".
        min_segment_duration = params.get('min_segment_duration')
        if min_segment_duration is None:
            min_segment_duration = 0.0
        overlap = params.get('overlap')
        if overlap is None:
            overlap = 0.0

        if segment_duration <= 0.0:
            raise excs.Error('segment_duration must be a positive number')
        if segment_duration < min_segment_duration:
            raise excs.Error('segment_duration must be at least min_segment_duration')
        if overlap >= segment_duration:
            raise excs.Error('overlap must be less than segment_duration')

        return {
            'segment_start': ts.FloatType(nullable=False),
            'segment_start_pts': ts.IntType(nullable=False),
            'segment_end': ts.FloatType(nullable=False),
            'segment_end_pts': ts.IntType(nullable=False),
            'video_segment': ts.VideoType(nullable=False),
        }, []

    def _adds_no_new_footage(self, clip_duration: float) -> bool:
        """
        Returns True if a clip of `clip_duration` seconds starting at `next_segment_start` contains no footage beyond
        what previously returned segments already cover, i.e., the end of the video has been reached.
        """
        if clip_duration <= self._DURATION_TOLERANCE:
            return True
        # Every segment after the first starts `overlap` seconds before the end of its predecessor, so it contributes
        # new footage only if it is longer than `overlap`. Using `<=` (rather than `==`) also prevents a clip shorter
        # than `overlap` from moving the next start position backwards.
        return not self.is_first_segment and clip_duration <= self.overlap + self._DURATION_TOLERANCE

    def __next__(self) -> dict[str, Any]:
        if self.is_exhausted:
            raise StopIteration

        segment_path = str(TempStore.create_path(extension='.mp4'))
        try:
            cmd = av_utils.ffmpeg_clip_cmd(
                str(self.video_path), segment_path, self.next_segment_start, self.segment_duration
            )
            _ = subprocess.run(cmd, capture_output=True, text=True, check=True)

            # use the actual duration, which may differ from the requested one (e.g., at the end of the video)
            segment_duration = av_utils.get_video_duration(segment_path)
            if self._adds_no_new_footage(segment_duration):
                # we're done
                Path(segment_path).unlink()
                raise StopIteration

            if segment_duration < self.min_segment_duration:
                Path(segment_path).unlink()
                raise StopIteration

            segment_end = self.next_segment_start + segment_duration
            segment_end_pts = self.next_segment_start_pts + round(segment_duration / self.video_time_base)

            result = {
                'segment_start': self.next_segment_start,
                'segment_start_pts': self.next_segment_start_pts,
                'segment_end': segment_end,
                'segment_end_pts': segment_end_pts,
                'video_segment': segment_path,
            }

            if segment_duration <= self.overlap + self._DURATION_TOLERANCE:
                # Only reachable for the first segment, when the entire video is shorter than `overlap`: the video is
                # fully covered, and rewinding by `overlap` would move before the start of this segment.
                self.is_exhausted = True
            self.is_first_segment = False
            self.next_segment_start = segment_end - self.overlap
            self.next_segment_start_pts = segment_end_pts - round(self.overlap / self.video_time_base)

            return result

        except subprocess.CalledProcessError as e:
            Path(segment_path).unlink(missing_ok=True)
            error_msg = f'ffmpeg failed with return code {e.returncode}'
            if e.stderr:
                error_msg += f': {e.stderr.strip()}'
            raise pxt.Error(error_msg) from e

    def close(self) -> None:
        # No resources are held between calls: the container opened in `__init__` is closed there.
        pass

    def set_pos(self, pos: int) -> None:
        # This iterator does not implement repositioning; the call is a no-op.
        pass
|
pixeltable/metadata/__init__.py
CHANGED
|
@@ -25,6 +25,7 @@ def create_system_info(engine: sql.engine.Engine) -> None:
|
|
|
25
25
|
"""Create the system metadata record"""
|
|
26
26
|
system_md = SystemInfoMd(schema_version=VERSION)
|
|
27
27
|
record = SystemInfo(md=dataclasses.asdict(system_md))
|
|
28
|
+
_logger.debug(f'Creating pixeltable system info record {record}')
|
|
28
29
|
with orm.Session(engine, future=True) as session:
|
|
29
30
|
# Write system metadata only once for idempotency
|
|
30
31
|
if session.query(SystemInfo).count() == 0:
|
|
@@ -54,7 +55,8 @@ for _, modname, _ in pkgutil.iter_modules([os.path.dirname(__file__) + '/convert
|
|
|
54
55
|
def upgrade_md(engine: sql.engine.Engine) -> None:
|
|
55
56
|
"""Upgrade the metadata schema to the current version"""
|
|
56
57
|
with orm.Session(engine) as session:
|
|
57
|
-
|
|
58
|
+
# Get exclusive lock on SystemInfo row
|
|
59
|
+
system_info = session.query(SystemInfo).with_for_update().one().md
|
|
58
60
|
md_version = system_info['schema_version']
|
|
59
61
|
assert isinstance(md_version, int)
|
|
60
62
|
_logger.info(f'Current database version: {md_version}, installed version: {VERSION}')
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from typing import Callable, ClassVar, Optional
|
|
2
|
+
|
|
3
|
+
from mypy import nodes
|
|
4
|
+
from mypy.plugin import AnalyzeTypeContext, ClassDefContext, FunctionContext, MethodSigContext, Plugin
|
|
5
|
+
from mypy.plugins.common import add_attribute_to_class, add_method_to_class
|
|
6
|
+
from mypy.types import AnyType, FunctionLike, Instance, NoneType, Type, TypeOfAny
|
|
7
|
+
|
|
8
|
+
import pixeltable as pxt
|
|
9
|
+
from pixeltable import exprs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PxtPlugin(Plugin):
    """
    Mypy plugin for Pixeltable. It adjusts:
    - the return type of calls to `@pxt.uda` aggregator classes (see `adjust_uda_type`);
    - Pixeltable type aliases such as `pxt.Image`, replaced by their standard equivalents (see `adjust_pxt_type`);
    - the signatures of `Table.add_column()` / `Table.add_computed_column()` (see `adjust_kwargs`);
    - the members of `@pxt.uda`-decorated classes (see `adjust_uda_methods`).
    """

    __UDA_FULLNAME = f'{pxt.uda.__module__}.{pxt.uda.__name__}'
    # NOTE(review): not referenced anywhere in this module.
    __ARRAY_GETITEM_FULLNAME = f'{pxt.Array.__module__}.{pxt.Array.__name__}.__class_getitem__'
    __ADD_COLUMN_FULLNAME = f'{pxt.Table.__module__}.{pxt.Table.__name__}.{pxt.Table.add_column.__name__}'
    __ADD_COMPUTED_COLUMN_FULLNAME = (
        f'{pxt.Table.__module__}.{pxt.Table.__name__}.{pxt.Table.add_computed_column.__name__}'
    )
    # Pixeltable types -> fully qualified name of the type mypy should see instead.
    __TYPE_MAP: ClassVar[dict] = {
        pxt.Json: 'typing.Any',
        pxt.Array: 'numpy.ndarray',
        pxt.Image: 'PIL.Image.Image',
        pxt.Video: 'builtins.str',
        pxt.Audio: 'builtins.str',
        pxt.Document: 'builtins.str',
    }
    __FULLNAME_MAP: ClassVar[dict] = {f'{k.__module__}.{k.__name__}': v for k, v in __TYPE_MAP.items()}

    def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]:
        """
        Installs `adjust_uda_type` for calls to `Aggregator` or its subclasses (such as `@pxt.uda` classes).

        Mypy chains plugins by using the first non-None hook any plugin returns, and custom plugins take precedence
        over mypy's built-in plugin. Returning a hook for every function would therefore shadow the function hooks of
        all later plugins (including mypy's own handling of e.g. `contextlib.contextmanager`), so the hook is
        returned only when `fullname` names an `Aggregator` class.
        """
        sym = self.lookup_fully_qualified(fullname)
        node = sym.node if sym is not None else None
        if isinstance(node, nodes.TypeInfo) and node.has_base(_AGGREGATOR_FULLNAME):
            return adjust_uda_type
        return None

    def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
        """Substitutes standard types for the Pixeltable type aliases listed in `__TYPE_MAP`."""
        if fullname in self.__FULLNAME_MAP:
            subst_name = self.__FULLNAME_MAP[fullname]
            return lambda ctx: adjust_pxt_type(ctx, subst_name)
        return None

    def get_method_signature_hook(self, fullname: str) -> Optional[Callable[[MethodSigContext], FunctionLike]]:
        """Relaxes the signatures of `Table.add_column()` and `Table.add_computed_column()`."""
        if fullname in (self.__ADD_COLUMN_FULLNAME, self.__ADD_COMPUTED_COLUMN_FULLNAME):
            return adjust_kwargs
        return None

    def get_class_decorator_hook_2(self, fullname: str) -> Optional[Callable[[ClassDefContext], bool]]:
        """Adds `AggregateFunction`-like members to classes decorated with `@pxt.uda`."""
        if fullname == self.__UDA_FULLNAME:
            return adjust_uda_methods
        return None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def plugin(version: str) -> type:
    """Mypy plugin entry point: returns `PxtPlugin` for every mypy `version`."""
    plugin_class: type = PxtPlugin
    return plugin_class
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Fully qualified name of `pxt.Aggregator`, as it appears in mypy's type information.
_AGGREGATOR_FULLNAME = '.'.join((pxt.Aggregator.__module__, pxt.Aggregator.__name__))
# NOTE(review): despite its name, this is the fully qualified name of `exprs.Expr`; it is not referenced in this module.
_FN_CALL_FULLNAME = '.'.join((exprs.Expr.__module__, exprs.Expr.__name__))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def adjust_uda_type(ctx: FunctionContext) -> Type:
    """
    Mypy doesn't understand that a class with a @uda decorator isn't actually a class, so it assumes
    that sum(expr), for example, actually returns an instance of sum. We correct this by changing the
    return type of a call to `Aggregator` or any (direct or indirect) subclass of it to `Any`.
    Other return types are passed through unchanged.
    """
    ret_type = ctx.default_return_type
    # `TypeInfo.has_base` walks the full MRO (which includes the class itself), so indirect subclasses of
    # `Aggregator` are recognized as well, not only direct ones.
    if isinstance(ret_type, Instance) and ret_type.type.has_base(_AGGREGATOR_FULLNAME):
        return AnyType(TypeOfAny.special_form)
    return ret_type
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def adjust_pxt_type(ctx: AnalyzeTypeContext, subst_name: str) -> Type:
    """
    Substitutes a standard type for a special Pixeltable class (e.g., `numpy.ndarray` for `pxt.Array`).

    `subst_name` is the fully qualified name of the replacement type. The value `'typing.Any'` yields `Any`; any
    other name is resolved via `ctx.api.named_type` with no type arguments.
    """
    is_any = subst_name == 'typing.Any'
    return AnyType(TypeOfAny.special_form) if is_any else ctx.api.named_type(subst_name, [])
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def adjust_kwargs(ctx: MethodSigContext) -> FunctionLike:
    """
    Mypy has a "feature" where it will spit out multiple warnings if a method with signature
    ```
    def my_func(*, arg1: int, arg2: str, **kwargs: Expr)
    ```
    (for example) is called with bare kwargs:
    ```
    my_func(my_kwarg=value)
    ```
    This is a disaster for type-checking of add_column and add_computed_column. Here we adjust the signature so
    that mypy thinks it is simply
    ```
    def my_func(**kwargs: Any)
    ```
    thereby avoiding any type-checking errors. For details, see: <https://github.com/python/mypy/issues/18481>
    """
    default_sig = ctx.default_signature
    # Keep only the final parameter (expected to be the `**kwargs` catch-all) and retype it as `Any`.
    return default_sig.copy_modified(
        arg_names=default_sig.arg_names[-1:],
        arg_types=[AnyType(TypeOfAny.special_form)],
        arg_kinds=default_sig.arg_kinds[-1:],
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def adjust_uda_methods(ctx: ClassDefContext) -> bool:
    """
    Mypy does not handle the `@pxt.uda` aggregator well: it keeps treating the decorated class as a class, even
    though the decorator replaces it with an `AggregateFunction`. To let uses of the decorated name type-check, this
    hook adds members to the class that imitate those of `AggregateFunction`:
    - an `__init__` accepting arbitrary positional and keyword arguments;
    - static methods `to_sql` and `overload`, each taking one argument `fn` and returning `Any`;
    - a class variable `signatures` of type `list[Any]`.

    Always returns True.
    """
    api, cls = ctx.api, ctx.cls
    list_of_any = api.named_type('builtins.list', [AnyType(TypeOfAny.special_form)])
    fn_param = nodes.Argument(nodes.Var('fn'), AnyType(TypeOfAny.special_form), None, nodes.ARG_POS)
    star_args_param = nodes.Argument(nodes.Var('args'), AnyType(TypeOfAny.special_form), None, nodes.ARG_STAR)
    star_kwargs_param = nodes.Argument(nodes.Var('kwargs'), AnyType(TypeOfAny.special_form), None, nodes.ARG_STAR2)

    add_method_to_class(api, cls, '__init__', args=[star_args_param, star_kwargs_param], return_type=NoneType())
    for method_name in ('to_sql', 'overload'):
        add_method_to_class(
            api, cls, method_name, args=[fn_param], return_type=AnyType(TypeOfAny.special_form), is_staticmethod=True
        )
    add_attribute_to_class(api, cls, 'signatures', typ=list_of_any, is_classvar=True)
    return True
|
pixeltable/type_system.py
CHANGED
|
@@ -9,8 +9,11 @@ import types
|
|
|
9
9
|
import typing
|
|
10
10
|
import urllib.parse
|
|
11
11
|
import urllib.request
|
|
12
|
+
from pathlib import Path
|
|
12
13
|
from typing import Any, ClassVar, Iterable, Literal, Mapping, Optional, Sequence, Union
|
|
13
14
|
|
|
15
|
+
from typing import _GenericAlias # type: ignore[attr-defined] # isort: skip
|
|
16
|
+
|
|
14
17
|
import av
|
|
15
18
|
import jsonschema
|
|
16
19
|
import jsonschema.protocols
|
|
@@ -24,8 +27,6 @@ from typing_extensions import _AnnotatedAlias
|
|
|
24
27
|
import pixeltable.exceptions as excs
|
|
25
28
|
from pixeltable.utils import parse_local_file_path
|
|
26
29
|
|
|
27
|
-
from typing import _GenericAlias # type: ignore[attr-defined] # isort: skip
|
|
28
|
-
|
|
29
30
|
|
|
30
31
|
class ColumnType:
|
|
31
32
|
@enum.unique
|
|
@@ -292,7 +293,11 @@ class ColumnType:
|
|
|
292
293
|
|
|
293
294
|
@classmethod
|
|
294
295
|
def from_python_type(
|
|
295
|
-
cls,
|
|
296
|
+
cls,
|
|
297
|
+
t: type | _GenericAlias,
|
|
298
|
+
nullable_default: bool = False,
|
|
299
|
+
allow_builtin_types: bool = True,
|
|
300
|
+
infer_pydantic_json: bool = False,
|
|
296
301
|
) -> Optional[ColumnType]:
|
|
297
302
|
"""
|
|
298
303
|
Convert a Python type into a Pixeltable `ColumnType` instance.
|
|
@@ -305,6 +310,8 @@ class ColumnType:
|
|
|
305
310
|
allowed (as in UDF definitions). If False, then only Pixeltable types such as `pxt.String`,
|
|
306
311
|
`pxt.Int`, etc., will be allowed (as in schema definitions). `Optional` and `Required`
|
|
307
312
|
designations will be allowed regardless.
|
|
313
|
+
infer_pydantic_json: If True, accepts an extended set of built-ins (eg, Enum, Path) and returns the type to
|
|
314
|
+
which pydantic.BaseModel.model_dump(mode='json') serializes it.
|
|
308
315
|
"""
|
|
309
316
|
origin = typing.get_origin(t)
|
|
310
317
|
type_args = typing.get_args(t)
|
|
@@ -314,7 +321,9 @@ class ColumnType:
|
|
|
314
321
|
# `t` is a type of the form Optional[T] (equivalently, T | None or None | T).
|
|
315
322
|
# We treat it as the underlying type but with nullable=True.
|
|
316
323
|
underlying_py_type = type_args[0] if type_args[1] is type(None) else type_args[1]
|
|
317
|
-
underlying = cls.from_python_type(
|
|
324
|
+
underlying = cls.from_python_type(
|
|
325
|
+
underlying_py_type, allow_builtin_types=allow_builtin_types, infer_pydantic_json=infer_pydantic_json
|
|
326
|
+
)
|
|
318
327
|
if underlying is not None:
|
|
319
328
|
return underlying.copy(nullable=True)
|
|
320
329
|
elif origin is Required:
|
|
@@ -341,6 +350,13 @@ class ColumnType:
|
|
|
341
350
|
if literal_type is None:
|
|
342
351
|
return None
|
|
343
352
|
return literal_type.copy(nullable=(literal_type.nullable or nullable_default))
|
|
353
|
+
if infer_pydantic_json and isinstance(t, type) and issubclass(t, enum.Enum):
|
|
354
|
+
literal_type = cls.infer_common_literal_type(member.value for member in t)
|
|
355
|
+
if literal_type is None:
|
|
356
|
+
return None
|
|
357
|
+
return literal_type.copy(nullable=(literal_type.nullable or nullable_default))
|
|
358
|
+
if infer_pydantic_json and t is Path:
|
|
359
|
+
return StringType(nullable=nullable_default)
|
|
344
360
|
if t is str:
|
|
345
361
|
return StringType(nullable=nullable_default)
|
|
346
362
|
if t is int:
|