pixeltable 0.4.8__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Optional, Sequence
 
 import pixeltable as pxt
 from pixeltable.env import Env
+from pixeltable.utils.code import local_public_names
 
 if TYPE_CHECKING:
     from whisper import Whisper  # type: ignore[import-untyped]
@@ -90,3 +91,10 @@ def _lookup_model(model_id: str, device: str) -> 'Whisper':
 
 
 _model_cache: dict[tuple[str, str], 'Whisper'] = {}
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__() -> list[str]:
+    return __all__
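
The `__all__`/`__dir__` pairing added here is the standard way to limit a module's `import *` surface and tab-completion to its public API. A toy sketch of the pattern in isolation (hypothetical names; `local_public_names` presumably collects the public top-level names defined in the module, but its body is not shown in this diff):

```python
# Toy module illustrating the pattern (not pixeltable's helper itself)
def transcribe() -> None: ...     # public: should be exported
def _lookup_model() -> None: ...  # private: should stay hidden

# In pixeltable this list is computed by local_public_names(__name__)
__all__ = ['transcribe']

def __dir__() -> list[str]:
    # dir(module) and IDE completion now show only the public names
    return __all__
```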
@@ -0,0 +1,177 @@
+from typing import TYPE_CHECKING, Any, Optional
+
+import numpy as np
+
+import pixeltable as pxt
+from pixeltable.config import Config
+from pixeltable.functions.util import resolve_torch_device
+from pixeltable.utils.code import local_public_names
+
+if TYPE_CHECKING:
+    from transformers import Wav2Vec2Model
+    from whisperx.asr import FasterWhisperPipeline  # type: ignore[import-untyped]
+    from whisperx.diarize import DiarizationPipeline  # type: ignore[import-untyped]
+
+
+@pxt.udf
+def transcribe(
+    audio: pxt.Audio,
+    *,
+    model: str,
+    diarize: bool = False,
+    compute_type: Optional[str] = None,
+    language: Optional[str] = None,
+    task: Optional[str] = None,
+    chunk_size: Optional[int] = None,
+    alignment_model_name: Optional[str] = None,
+    interpolate_method: Optional[str] = None,
+    return_char_alignments: Optional[bool] = None,
+    diarization_model_name: Optional[str] = None,
+    num_speakers: Optional[int] = None,
+    min_speakers: Optional[int] = None,
+    max_speakers: Optional[int] = None,
+) -> dict:
+    """
+    Transcribe an audio file using WhisperX.
+
+    This UDF runs a transcription model _locally_ using the WhisperX library,
+    equivalent to the WhisperX `transcribe` function, as described in the
+    [WhisperX library documentation](https://github.com/m-bain/whisperX).
+
+    If `diarize=True`, then speaker diarization will also be performed. Several of the UDF parameters are only
+    valid if `diarize=True`, as documented in the parameters list below.
+
+    __Requirements:__
+
+    - `pip install whisperx`
+
+    Args:
+        audio: The audio file to transcribe.
+        model: The name of the model to use for transcription.
+        diarize: Whether to perform speaker diarization.
+        compute_type: The compute type to use for the model (e.g., `'int8'`, `'float16'`). If `None`,
+            defaults to `'float16'` on CUDA devices and `'int8'` otherwise.
+        language: The language code for the transcription (e.g., `'en'` for English).
+        task: The task to perform (e.g., `'transcribe'` or `'translate'`). Defaults to `'transcribe'`.
+        chunk_size: The size of the audio chunks to process, in seconds. Defaults to `30`.
+        alignment_model_name: The name of the alignment model to use. If `None`, uses the default model for the
+            given language. Only valid if `diarize=True`.
+        interpolate_method: The method to use for interpolation of the alignment results. If not specified, uses
+            the WhisperX default (`'nearest'`). Only valid if `diarize=True`.
+        return_char_alignments: Whether to return character-level alignments. Defaults to `False`.
+            Only valid if `diarize=True`.
+        diarization_model_name: The name of the diarization model to use. Defaults to
+            `pyannote/speaker-diarization-3.1`. Only valid if `diarize=True`.
+        num_speakers: The number of speakers to expect in the audio. By default, the model will try to detect the
+            number of speakers. Only valid if `diarize=True`.
+        min_speakers: If specified, the minimum number of speakers to expect in the audio.
+            Only valid if `diarize=True`.
+        max_speakers: If specified, the maximum number of speakers to expect in the audio.
+            Only valid if `diarize=True`.
+
+    Returns:
+        A dictionary containing the audio transcription, diarization (if enabled), and various other metadata.
+
+    Examples:
+        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
+        of the table `tbl`:
+
+        >>> tbl.add_computed_column(result=transcribe(tbl.audio, model='tiny.en'))
+
+        Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
+        of the table `tbl`, with speaker diarization enabled, expecting at least 2 speakers:
+
+        >>> tbl.add_computed_column(
+        ...     result=transcribe(
+        ...         tbl.audio, model='tiny.en', diarize=True, min_speakers=2
+        ...     )
+        ... )
+    """
+    import whisperx  # type: ignore[import-untyped]
+
+    if not diarize:
+        args = locals()
+        for param in (
+            'alignment_model_name',
+            'interpolate_method',
+            'return_char_alignments',
+            'diarization_model_name',
+            'num_speakers',
+            'min_speakers',
+            'max_speakers',
+        ):
+            if args[param] is not None:
+                raise pxt.Error(f'`{param}` can only be set if `diarize=True`')
+
+    device = resolve_torch_device('auto', allow_mps=False)
+    compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
+    transcription_model = _lookup_transcription_model(model, device, compute_type)
+    audio_array: np.ndarray = whisperx.load_audio(audio)
+    kwargs: dict[str, Any] = {'language': language, 'task': task}
+    if chunk_size is not None:
+        kwargs['chunk_size'] = chunk_size
+    result: dict[str, Any] = transcription_model.transcribe(audio_array, batch_size=16, **kwargs)
+
+    if diarize:
+        # Alignment
+        alignment_model, metadata = _lookup_alignment_model(result['language'], device, alignment_model_name)
+        kwargs = {}
+        if interpolate_method is not None:
+            kwargs['interpolate_method'] = interpolate_method
+        if return_char_alignments is not None:
+            kwargs['return_char_alignments'] = return_char_alignments
+        result = whisperx.align(result['segments'], alignment_model, metadata, audio_array, device, **kwargs)
+
+        # Diarization
+        diarization_model = _lookup_diarization_model(device, diarization_model_name)
+        diarization_segments = diarization_model(
+            audio_array, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers
+        )
+        result = whisperx.assign_word_speakers(diarization_segments, result)
+
+    return result
+
+
+def _lookup_transcription_model(model: str, device: str, compute_type: str) -> 'FasterWhisperPipeline':
+    import whisperx
+
+    key = (model, device, compute_type)
+    if key not in _model_cache:
+        transcription_model = whisperx.load_model(model, device, compute_type=compute_type)
+        _model_cache[key] = transcription_model
+    return _model_cache[key]
+
+
+def _lookup_alignment_model(language_code: str, device: str, model_name: Optional[str]) -> tuple['Wav2Vec2Model', dict]:
+    import whisperx
+
+    key = (language_code, device, model_name)
+    if key not in _alignment_model_cache:
+        model, metadata = whisperx.load_align_model(language_code=language_code, device=device, model_name=model_name)
+        _alignment_model_cache[key] = (model, metadata)
+    return _alignment_model_cache[key]
+
+
+def _lookup_diarization_model(device: str, model_name: Optional[str]) -> 'DiarizationPipeline':
+    from whisperx.diarize import DiarizationPipeline
+
+    key = (device, model_name)
+    if key not in _diarization_model_cache:
+        auth_token = Config.get().get_string_value('auth_token', section='hf')
+        kwargs: dict[str, Any] = {'device': device, 'use_auth_token': auth_token}
+        if model_name is not None:
+            kwargs['model_name'] = model_name
+        _diarization_model_cache[key] = DiarizationPipeline(**kwargs)
+    return _diarization_model_cache[key]
+
+
+_model_cache: dict[tuple[str, str, str], 'FasterWhisperPipeline'] = {}
+_alignment_model_cache: dict[tuple[str, str, Optional[str]], tuple['Wav2Vec2Model', dict]] = {}
+_diarization_model_cache: dict[tuple[str, Optional[str]], 'DiarizationPipeline'] = {}
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__() -> list[str]:
+    return __all__
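
An end-to-end usage sketch of the new UDF, assuming it ships as `pixeltable.functions.whisperx` (the diff does not show the file path) and using a hypothetical table and audio file:

```python
import pixeltable as pxt
from pixeltable.functions.whisperx import transcribe  # assumed module path

tbl = pxt.create_table('audio_demo', {'audio': pxt.Audio})
tbl.insert([{'audio': '/path/to/interview.mp3'}])

# Diarized transcription; each row stores the full WhisperX result dict
tbl.add_computed_column(
    result=transcribe(tbl.audio, model='tiny.en', diarize=True, min_speakers=2)
)

# The result is JSON-typed, so pieces can be pulled out with path expressions
tbl.select(tbl.result['segments']).collect()
```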
@@ -20,8 +20,6 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
     Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
     defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
 
-    YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
-
     __Requirements__:
 
     - `pip install pixeltable-yolox`
@@ -55,8 +53,6 @@ def yolo_to_coco(detections: dict) -> list:
     """
     Converts the output of a YOLOX object detection model to COCO format.
 
-    YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
-
    Args:
        detections: The output of a YOLOX object detection model, as returned by `yolox`.
 
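For context, a usage sketch chaining the two UDFs (the import path is an assumption, since the file header is not shown in this diff; table and column names are hypothetical):

```python
import pixeltable as pxt
from pixeltable.ext.functions.yolox import yolox, yolo_to_coco  # assumed path

tbl = pxt.create_table('images_demo', {'image': pxt.Image})
tbl.add_computed_column(detections=yolox(tbl.image, model_id='yolox_s', threshold=0.5))
tbl.add_computed_column(coco=yolo_to_coco(tbl.detections))
```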
pixeltable/globals.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
 import logging
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, Optional, Sequence, Union
 
 import pandas as pd
+import pydantic
 from pandas.io.formats.style import Styler
 
 from pixeltable import DataFrame, catalog, exceptions as excs, exprs, func, share, type_system as ts
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
     Path,  # OS paths, filenames, URLs
     Iterator[dict[str, Any]],  # iterator producing dictionaries of values
     RowData,  # list of dictionaries
+    Sequence[pydantic.BaseModel],  # list of Pydantic models
     DataFrame,  # Pixeltable DataFrame
     pd.DataFrame,  # pandas DataFrame
     datasets.Dataset,
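
The new `Sequence[pydantic.BaseModel]` entry means rows can now be inserted directly from Pydantic models. A minimal sketch with a hypothetical table and model (the assumption, consistent with the `type_system.py` changes below, is that fields map to columns via `model_dump(mode='json')` semantics):

```python
import pydantic
import pixeltable as pxt

class Review(pydantic.BaseModel):
    author: str
    rating: int
    text: str

tbl = pxt.create_table(
    'reviews', {'author': pxt.String, 'rating': pxt.Int, 'text': pxt.String}
)

# A list of Pydantic models is now an accepted rows source
tbl.insert([
    Review(author='alice', rating=5, text='Great!'),
    Review(author='bob', rating=3, text='Okay.'),
])
```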
@@ -1,5 +1,7 @@
 import logging
 import math
+import shutil
+import subprocess
 from fractions import Fraction
 from pathlib import Path
 from typing import Any, Optional
@@ -8,8 +10,11 @@ import av
 import pandas as pd
 import PIL.Image
 
+import pixeltable as pxt
 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
+import pixeltable.utils.av as av_utils
+from pixeltable.utils.media_store import TempStore
 
 from .base import ComponentIterator
 
@@ -224,3 +229,136 @@ class FrameIterator(ComponentIterator):
         # then the iterator will step forward to the desired frame on the subsequent call to next().
         self.container.seek(seek_pos, backward=True, stream=self.container.streams.video[0])
         self.next_pos = pos
+
+
+class VideoSplitter(ComponentIterator):
+    """
+    Iterator over segments of a video file, which is split into fixed-size segments of length `segment_duration`
+    seconds.
+
+    Args:
+        segment_duration: Video segment duration in seconds
+        overlap: Overlap between consecutive segments in seconds.
+        min_segment_duration: Drop the last segment if it is smaller than min_segment_duration
+    """
+
+    # Input parameters
+    video_path: Path
+    segment_duration: float
+    overlap: float
+    min_segment_duration: float
+
+    # Video metadata
+    video_duration: float
+    video_time_base: Fraction
+    video_start_time: int
+
+    # position tracking
+    next_segment_start: float
+    next_segment_start_pts: int
+
+    def __init__(self, video: str, segment_duration: float, *, overlap: float = 0.0, min_segment_duration: float = 0.0):
+        assert segment_duration > 0.0
+        assert segment_duration >= min_segment_duration
+        assert overlap < segment_duration
+
+        video_path = Path(video)
+        assert video_path.exists() and video_path.is_file()
+
+        if not shutil.which('ffmpeg'):
+            raise pxt.Error('ffmpeg is not installed or not in PATH. Please install ffmpeg to use VideoSplitter.')
+
+        self.video_path = video_path
+        self.segment_duration = segment_duration
+        self.overlap = overlap
+        self.min_segment_duration = min_segment_duration
+
+        with av.open(str(video_path)) as container:
+            video_stream = container.streams.video[0]
+            self.video_time_base = video_stream.time_base
+            self.video_start_time = video_stream.start_time or 0
+
+        self.next_segment_start = float(self.video_start_time * self.video_time_base)
+        self.next_segment_start_pts = self.video_start_time
+
+    @classmethod
+    def input_schema(cls) -> dict[str, ts.ColumnType]:
+        return {
+            'video': ts.VideoType(nullable=False),
+            'segment_duration': ts.FloatType(nullable=False),
+            'overlap': ts.FloatType(nullable=True),
+            'min_segment_duration': ts.FloatType(nullable=True),
+        }
+
+    @classmethod
+    def output_schema(cls, *args: Any, **kwargs: Any) -> tuple[dict[str, ts.ColumnType], list[str]]:
+        param_names = ['segment_duration', 'overlap', 'min_segment_duration']
+        params = dict(zip(param_names, args))
+        params.update(kwargs)
+
+        segment_duration = params['segment_duration']
+        min_segment_duration = params.get('min_segment_duration', 0.0)
+        overlap = params.get('overlap', 0.0)
+
+        if segment_duration <= 0.0:
+            raise excs.Error('segment_duration must be a positive number')
+        if segment_duration < min_segment_duration:
+            raise excs.Error('segment_duration must be at least min_segment_duration')
+        if overlap >= segment_duration:
+            raise excs.Error('overlap must be less than segment_duration')
+
+        return {
+            'segment_start': ts.FloatType(nullable=False),
+            'segment_start_pts': ts.IntType(nullable=False),
+            'segment_end': ts.FloatType(nullable=False),
+            'segment_end_pts': ts.IntType(nullable=False),
+            'video_segment': ts.VideoType(nullable=False),
+        }, []
+
+    def __next__(self) -> dict[str, Any]:
+        segment_path = str(TempStore.create_path(extension='.mp4'))
+        try:
+            cmd = av_utils.ffmpeg_clip_cmd(
+                str(self.video_path), segment_path, self.next_segment_start, self.segment_duration
+            )
+            _ = subprocess.run(cmd, capture_output=True, text=True, check=True)
+
+            # use the actual duration
+            segment_duration = av_utils.get_video_duration(segment_path)
+            if segment_duration - self.overlap == 0.0:
+                # we're done
+                Path(segment_path).unlink()
+                raise StopIteration
+
+            if segment_duration < self.min_segment_duration:
+                Path(segment_path).unlink()
+                raise StopIteration
+
+            segment_end = self.next_segment_start + segment_duration
+            segment_end_pts = self.next_segment_start_pts + round(segment_duration / self.video_time_base)
+
+            result = {
+                'segment_start': self.next_segment_start,
+                'segment_start_pts': self.next_segment_start_pts,
+                'segment_end': segment_end,
+                'segment_end_pts': segment_end_pts,
+                'video_segment': segment_path,
+            }
+            self.next_segment_start = segment_end - self.overlap
+            self.next_segment_start_pts = segment_end_pts - round(self.overlap / self.video_time_base)
+
+            return result
+
+        except subprocess.CalledProcessError as e:
+            if Path(segment_path).exists():
+                Path(segment_path).unlink()
+            error_msg = f'ffmpeg failed with return code {e.returncode}'
+            if e.stderr:
+                error_msg += f': {e.stderr.strip()}'
+            raise pxt.Error(error_msg) from e
+
+    def close(self) -> None:
+        pass
+
+    def set_pos(self, pos: int) -> None:
+        pass
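
A usage sketch for the new iterator, assuming the standard `ComponentIterator.create(...)` construction pattern that `FrameIterator` uses (table, view, and file names are hypothetical):

```python
import pixeltable as pxt
from pixeltable.iterators import VideoSplitter  # assumed import path

videos = pxt.create_table('videos', {'video': pxt.Video})
videos.insert([{'video': '/path/to/lecture.mp4'}])

# One output row per 30s segment, with 5s of overlap between segments
segments = pxt.create_view(
    'segments',
    videos,
    iterator=VideoSplitter.create(video=videos.video, segment_duration=30.0, overlap=5.0),
)
segments.select(
    segments.segment_start, segments.segment_end, segments.video_segment
).collect()
```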
@@ -25,6 +25,7 @@ def create_system_info(engine: sql.engine.Engine) -> None:
     """Create the system metadata record"""
     system_md = SystemInfoMd(schema_version=VERSION)
     record = SystemInfo(md=dataclasses.asdict(system_md))
+    _logger.debug(f'Creating pixeltable system info record {record}')
     with orm.Session(engine, future=True) as session:
         # Write system metadata only once for idempotency
         if session.query(SystemInfo).count() == 0:
@@ -54,7 +55,8 @@ for _, modname, _ in pkgutil.iter_modules([os.path.dirname(__file__) + '/convert
 def upgrade_md(engine: sql.engine.Engine) -> None:
     """Upgrade the metadata schema to the current version"""
     with orm.Session(engine) as session:
-        system_info = session.query(SystemInfo).one().md
+        # Get exclusive lock on SystemInfo row
+        system_info = session.query(SystemInfo).with_for_update().one().md
         md_version = system_info['schema_version']
         assert isinstance(md_version, int)
         _logger.info(f'Current database version: {md_version}, installed version: {VERSION}')
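
For reference, `with_for_update()` makes SQLAlchemy append `FOR UPDATE` to the generated `SELECT`, so two processes running `upgrade_md()` concurrently serialize on the metadata row instead of racing the upgrade. A standalone sketch with a toy table definition:

```python
import sqlalchemy as sql

metadata = sql.MetaData()
systeminfo = sql.Table('systeminfo', metadata, sql.Column('md', sql.JSON))

stmt = sql.select(systeminfo).with_for_update()
# Prints: SELECT systeminfo.md FROM systeminfo FOR UPDATE
print(stmt)
```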
@@ -0,0 +1,3 @@
+from .mypy_plugin import plugin
+
+__all__ = ['plugin']
@@ -0,0 +1,123 @@
+from typing import Callable, ClassVar, Optional
+
+from mypy import nodes
+from mypy.plugin import AnalyzeTypeContext, ClassDefContext, FunctionContext, MethodSigContext, Plugin
+from mypy.plugins.common import add_attribute_to_class, add_method_to_class
+from mypy.types import AnyType, FunctionLike, Instance, NoneType, Type, TypeOfAny
+
+import pixeltable as pxt
+from pixeltable import exprs
+
+
+class PxtPlugin(Plugin):
+    __UDA_FULLNAME = f'{pxt.uda.__module__}.{pxt.uda.__name__}'
+    __ARRAY_GETITEM_FULLNAME = f'{pxt.Array.__module__}.{pxt.Array.__name__}.__class_getitem__'
+    __ADD_COLUMN_FULLNAME = f'{pxt.Table.__module__}.{pxt.Table.__name__}.{pxt.Table.add_column.__name__}'
+    __ADD_COMPUTED_COLUMN_FULLNAME = (
+        f'{pxt.Table.__module__}.{pxt.Table.__name__}.{pxt.Table.add_computed_column.__name__}'
+    )
+    __TYPE_MAP: ClassVar[dict] = {
+        pxt.Json: 'typing.Any',
+        pxt.Array: 'numpy.ndarray',
+        pxt.Image: 'PIL.Image.Image',
+        pxt.Video: 'builtins.str',
+        pxt.Audio: 'builtins.str',
+        pxt.Document: 'builtins.str',
+    }
+    __FULLNAME_MAP: ClassVar[dict] = {f'{k.__module__}.{k.__name__}': v for k, v in __TYPE_MAP.items()}
+
+    def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]:
+        return adjust_uda_type
+
+    def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
+        if fullname in self.__FULLNAME_MAP:
+            subst_name = self.__FULLNAME_MAP[fullname]
+            return lambda ctx: adjust_pxt_type(ctx, subst_name)
+        return None
+
+    def get_method_signature_hook(self, fullname: str) -> Optional[Callable[[MethodSigContext], FunctionLike]]:
+        if fullname in (self.__ADD_COLUMN_FULLNAME, self.__ADD_COMPUTED_COLUMN_FULLNAME):
+            return adjust_kwargs
+        return None
+
+    def get_class_decorator_hook_2(self, fullname: str) -> Optional[Callable[[ClassDefContext], bool]]:
+        if fullname == self.__UDA_FULLNAME:
+            return adjust_uda_methods
+        return None
+
+
+def plugin(version: str) -> type:
+    return PxtPlugin
+
+
+_AGGREGATOR_FULLNAME = f'{pxt.Aggregator.__module__}.{pxt.Aggregator.__name__}'
+_FN_CALL_FULLNAME = f'{exprs.Expr.__module__}.{exprs.Expr.__name__}'
+
+
+def adjust_uda_type(ctx: FunctionContext) -> Type:
+    """
+    Mypy doesn't understand that a class with a @uda decorator isn't actually a class, so it assumes
+    that sum(expr), for example, actually returns an instance of sum. We correct this by changing the
+    return type of any subclass of `Aggregator` to `FunctionCall`.
+    """
+    ret_type = ctx.default_return_type
+    if isinstance(ret_type, Instance) and (
+        ret_type.type.fullname == _AGGREGATOR_FULLNAME
+        or any(base.type.fullname == _AGGREGATOR_FULLNAME for base in ret_type.type.bases)
+    ):
+        ret_type = AnyType(TypeOfAny.special_form)
+    return ret_type
+
+
+def adjust_pxt_type(ctx: AnalyzeTypeContext, subst_name: str) -> Type:
+    """
+    Replaces the special Pixeltable classes (such as pxt.Array) with their standard equivalents (such as np.ndarray).
+    """
+    if subst_name == 'typing.Any':
+        return AnyType(TypeOfAny.special_form)
+    return ctx.api.named_type(subst_name, [])
+
+
+def adjust_kwargs(ctx: MethodSigContext) -> FunctionLike:
+    """
+    Mypy has a "feature" where it will spit out multiple warnings if a method with signature
+    ```
+    def my_func(*, arg1: int, arg2: str, **kwargs: Expr)
+    ```
+    (for example) is called with bare kwargs:
+    ```
+    my_func(my_kwarg=value)
+    ```
+    This is a disaster for type-checking of add_column and add_computed_column. Here we adjust the signature so
+    that mypy thinks it is simply
+    ```
+    def my_func(**kwargs: Any)
+    ```
+    thereby avoiding any type-checking errors. For details, see: <https://github.com/python/mypy/issues/18481>
+    """
+    sig = ctx.default_signature
+    new_arg_names = sig.arg_names[-1:]
+    new_arg_types = [AnyType(TypeOfAny.special_form)]
+    new_arg_kinds = sig.arg_kinds[-1:]
+    return sig.copy_modified(arg_names=new_arg_names, arg_types=new_arg_types, arg_kinds=new_arg_kinds)
+
+
+def adjust_uda_methods(ctx: ClassDefContext) -> bool:
+    """
+    Mypy does not handle the `@pxt.uda` aggregator well; it continues to treat the decorated class as a class,
+    even though it has been replaced by an `AggregateFunction`. Here we add static methods to the class that
+    imitate various (instance) methods of `AggregateFunction` so that they can be properly type-checked.
+    """
+    list_type = ctx.api.named_type('builtins.list', [AnyType(TypeOfAny.special_form)])
+    fn_arg = nodes.Argument(nodes.Var('fn'), AnyType(TypeOfAny.special_form), None, nodes.ARG_POS)
+    args_arg = nodes.Argument(nodes.Var('args'), AnyType(TypeOfAny.special_form), None, nodes.ARG_STAR)
+    kwargs_arg = nodes.Argument(nodes.Var('kwargs'), AnyType(TypeOfAny.special_form), None, nodes.ARG_STAR2)
+    add_method_to_class(ctx.api, ctx.cls, '__init__', args=[args_arg, kwargs_arg], return_type=NoneType())
+    add_method_to_class(
+        ctx.api, ctx.cls, 'to_sql', args=[fn_arg], return_type=AnyType(TypeOfAny.special_form), is_staticmethod=True
+    )
+    add_method_to_class(
+        ctx.api, ctx.cls, 'overload', args=[fn_arg], return_type=AnyType(TypeOfAny.special_form), is_staticmethod=True
+    )
+    add_attribute_to_class(ctx.api, ctx.cls, 'signatures', typ=list_type, is_classvar=True)
+    return True
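
To see what the plugin buys, here is a sketch of code that the hooks above are designed to make type-check cleanly; the registration module path (`pixeltable.mypy`) is an assumption based on the new package's `__init__` above, and the table and column names are hypothetical:

```python
# Enabled via mypy config (module path assumed):
#   [mypy]
#   plugins = pixeltable.mypy

import pixeltable as pxt

t = pxt.create_table('films', {'name': pxt.String, 'revenue': pxt.Float})

# get_method_signature_hook: bare kwargs to add_computed_column no longer
# trigger spurious call-arg errors (mypy sees **kwargs: Any)
t.add_computed_column(revenue_m=t.revenue / 1e6)

# get_type_analyze_hook: pxt.Image in a UDF signature is analyzed as
# PIL.Image.Image, so .width type-checks
@pxt.udf
def width(img: pxt.Image) -> int:
    return img.width

# get_function_hook: calling a @pxt.uda aggregator such as pxt.functions.sum
# yields an expression type, not an instance of the decorated class
total = t.select(pxt.functions.sum(t.revenue)).collect()
```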
pixeltable/type_system.py CHANGED
@@ -9,8 +9,11 @@ import types
 import typing
 import urllib.parse
 import urllib.request
+from pathlib import Path
 from typing import Any, ClassVar, Iterable, Literal, Mapping, Optional, Sequence, Union
 
+from typing import _GenericAlias  # type: ignore[attr-defined] # isort: skip
+
 import av
 import jsonschema
 import jsonschema.protocols
@@ -24,8 +27,6 @@ from typing_extensions import _AnnotatedAlias
 import pixeltable.exceptions as excs
 from pixeltable.utils import parse_local_file_path
 
-from typing import _GenericAlias  # type: ignore[attr-defined] # isort: skip
-
 
 class ColumnType:
     @enum.unique
@@ -292,7 +293,11 @@ class ColumnType:
 
     @classmethod
     def from_python_type(
-        cls, t: type | _GenericAlias, nullable_default: bool = False, allow_builtin_types: bool = True
+        cls,
+        t: type | _GenericAlias,
+        nullable_default: bool = False,
+        allow_builtin_types: bool = True,
+        infer_pydantic_json: bool = False,
     ) -> Optional[ColumnType]:
         """
         Convert a Python type into a Pixeltable `ColumnType` instance.
@@ -305,6 +310,8 @@ class ColumnType:
                 allowed (as in UDF definitions). If False, then only Pixeltable types such as `pxt.String`,
                 `pxt.Int`, etc., will be allowed (as in schema definitions). `Optional` and `Required`
                 designations will be allowed regardless.
+            infer_pydantic_json: If True, accepts an extended set of built-ins (e.g., Enum, Path) and returns
+                the type to which pydantic.BaseModel.model_dump(mode='json') serializes it.
         """
         origin = typing.get_origin(t)
         type_args = typing.get_args(t)
@@ -314,7 +321,9 @@ class ColumnType:
             # `t` is a type of the form Optional[T] (equivalently, T | None or None | T).
             # We treat it as the underlying type but with nullable=True.
             underlying_py_type = type_args[0] if type_args[1] is type(None) else type_args[1]
-            underlying = cls.from_python_type(underlying_py_type, allow_builtin_types=allow_builtin_types)
+            underlying = cls.from_python_type(
+                underlying_py_type, allow_builtin_types=allow_builtin_types, infer_pydantic_json=infer_pydantic_json
+            )
            if underlying is not None:
                return underlying.copy(nullable=True)
        elif origin is Required:
@@ -341,6 +350,13 @@ class ColumnType:
             if literal_type is None:
                 return None
             return literal_type.copy(nullable=(literal_type.nullable or nullable_default))
+        if infer_pydantic_json and isinstance(t, type) and issubclass(t, enum.Enum):
+            literal_type = cls.infer_common_literal_type(member.value for member in t)
+            if literal_type is None:
+                return None
+            return literal_type.copy(nullable=(literal_type.nullable or nullable_default))
+        if infer_pydantic_json and t is Path:
+            return StringType(nullable=nullable_default)
        if t is str:
            return StringType(nullable=nullable_default)
        if t is int:
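
A small sketch of what the two new `infer_pydantic_json` branches return, following directly from the code above (enum types are inferred from the common literal type of their member values; `Path` maps to a string, matching pydantic's JSON serialization):

```python
import enum
from pathlib import Path

from pixeltable.type_system import ColumnType

class Color(enum.Enum):
    RED = 'red'
    GREEN = 'green'

# Common literal type of the member values: should yield a String type
print(ColumnType.from_python_type(Color, infer_pydantic_json=True))

# Path serializes to a string under model_dump(mode='json')
print(ColumnType.from_python_type(Path, infer_pydantic_json=True))

# Without the flag, neither type is accepted and None is returned
print(ColumnType.from_python_type(Color))
```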