pixeltable 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +1 -1
- pixeltable/catalog/dir.py +6 -0
- pixeltable/catalog/globals.py +13 -0
- pixeltable/catalog/named_function.py +4 -0
- pixeltable/catalog/path_dict.py +37 -11
- pixeltable/catalog/schema_object.py +6 -0
- pixeltable/catalog/table.py +22 -5
- pixeltable/catalog/table_version.py +22 -8
- pixeltable/dataframe.py +201 -3
- pixeltable/env.py +9 -3
- pixeltable/exec/expr_eval_node.py +1 -1
- pixeltable/exec/sql_node.py +2 -2
- pixeltable/exprs/function_call.py +134 -24
- pixeltable/exprs/inline_expr.py +22 -2
- pixeltable/exprs/row_builder.py +1 -1
- pixeltable/exprs/similarity_expr.py +9 -2
- pixeltable/func/aggregate_function.py +148 -68
- pixeltable/func/callable_function.py +49 -13
- pixeltable/func/expr_template_function.py +55 -24
- pixeltable/func/function.py +183 -22
- pixeltable/func/function_registry.py +2 -1
- pixeltable/func/query_template_function.py +11 -6
- pixeltable/func/signature.py +64 -7
- pixeltable/func/udf.py +57 -35
- pixeltable/functions/globals.py +54 -34
- pixeltable/functions/json.py +3 -8
- pixeltable/functions/ollama.py +4 -4
- pixeltable/functions/timestamp.py +1 -1
- pixeltable/functions/video.py +2 -8
- pixeltable/functions/vision.py +1 -1
- pixeltable/globals.py +218 -59
- pixeltable/index/embedding_index.py +44 -24
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_16.py +2 -1
- pixeltable/metadata/converters/convert_17.py +2 -1
- pixeltable/metadata/converters/convert_23.py +35 -0
- pixeltable/metadata/converters/convert_24.py +47 -0
- pixeltable/metadata/converters/util.py +4 -2
- pixeltable/metadata/notes.py +2 -0
- pixeltable/metadata/schema.py +1 -0
- pixeltable/tool/create_test_db_dump.py +11 -0
- pixeltable/tool/doc_plugins/griffe.py +4 -3
- pixeltable/type_system.py +180 -45
- {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/METADATA +3 -2
- {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/RECORD +49 -47
- {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/entry_points.txt +0 -0
|
@@ -15,6 +15,7 @@ import pixeltable.type_system as ts
|
|
|
15
15
|
from .data_row import DataRow
|
|
16
16
|
from .expr import Expr
|
|
17
17
|
from .inline_expr import InlineDict, InlineList
|
|
18
|
+
from .literal import Literal
|
|
18
19
|
from .row_builder import RowBuilder
|
|
19
20
|
from .rowid_ref import RowidRef
|
|
20
21
|
from .sql_element_cache import SqlElementCache
|
|
@@ -34,6 +35,7 @@ class FunctionCall(Expr):
|
|
|
34
35
|
|
|
35
36
|
arg_types: list[ts.ColumnType]
|
|
36
37
|
kwarg_types: dict[str, ts.ColumnType]
|
|
38
|
+
return_type: ts.ColumnType
|
|
37
39
|
group_by_start_idx: int
|
|
38
40
|
group_by_stop_idx: int
|
|
39
41
|
fn_expr_idx: int
|
|
@@ -43,17 +45,25 @@ class FunctionCall(Expr):
|
|
|
43
45
|
current_partition_vals: Optional[list[Any]]
|
|
44
46
|
|
|
45
47
|
def __init__(
|
|
46
|
-
|
|
47
|
-
|
|
48
|
+
self,
|
|
49
|
+
fn: func.Function,
|
|
50
|
+
bound_args: dict[str, Any],
|
|
51
|
+
return_type: ts.ColumnType,
|
|
52
|
+
order_by_clause: Optional[list[Any]] = None,
|
|
53
|
+
group_by_clause: Optional[list[Any]] = None,
|
|
54
|
+
is_method_call: bool = False
|
|
55
|
+
):
|
|
48
56
|
if order_by_clause is None:
|
|
49
57
|
order_by_clause = []
|
|
50
58
|
if group_by_clause is None:
|
|
51
59
|
group_by_clause = []
|
|
52
|
-
|
|
53
|
-
|
|
60
|
+
|
|
61
|
+
assert not fn.is_polymorphic
|
|
62
|
+
|
|
54
63
|
self.fn = fn
|
|
55
64
|
self.is_method_call = is_method_call
|
|
56
|
-
|
|
65
|
+
|
|
66
|
+
signature = fn.signature
|
|
57
67
|
|
|
58
68
|
# If `return_type` is non-nullable, but the function call has a nullable input to any of its non-nullable
|
|
59
69
|
# parameters, then we need to make it nullable. This is because Pixeltable defaults a function output to
|
|
@@ -67,6 +77,8 @@ class FunctionCall(Expr):
|
|
|
67
77
|
return_type = return_type.copy(nullable=True)
|
|
68
78
|
break
|
|
69
79
|
|
|
80
|
+
self.return_type = return_type
|
|
81
|
+
|
|
70
82
|
super().__init__(return_type)
|
|
71
83
|
|
|
72
84
|
self.agg_init_args = {}
|
|
@@ -74,9 +86,9 @@ class FunctionCall(Expr):
|
|
|
74
86
|
# we separate out the init args for the aggregator
|
|
75
87
|
assert isinstance(fn, func.AggregateFunction)
|
|
76
88
|
self.agg_init_args = {
|
|
77
|
-
arg_name: arg for arg_name, arg in bound_args.items() if arg_name in fn.init_param_names
|
|
89
|
+
arg_name: arg for arg_name, arg in bound_args.items() if arg_name in fn.init_param_names[0]
|
|
78
90
|
}
|
|
79
|
-
bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names}
|
|
91
|
+
bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names[0]}
|
|
80
92
|
|
|
81
93
|
# construct components, args, kwargs
|
|
82
94
|
self.args = []
|
|
@@ -88,7 +100,7 @@ class FunctionCall(Expr):
|
|
|
88
100
|
|
|
89
101
|
# the prefix of parameters that are bound can be passed by position
|
|
90
102
|
processed_args: set[str] = set()
|
|
91
|
-
for py_param in
|
|
103
|
+
for py_param in signature.py_signature.parameters.values():
|
|
92
104
|
if py_param.name not in bound_args or py_param.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
93
105
|
break
|
|
94
106
|
arg = bound_args[py_param.name]
|
|
@@ -110,7 +122,7 @@ class FunctionCall(Expr):
|
|
|
110
122
|
self.components.append(arg.copy())
|
|
111
123
|
else:
|
|
112
124
|
self.kwargs[param_name] = (None, arg)
|
|
113
|
-
if
|
|
125
|
+
if signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
|
|
114
126
|
self.kwarg_types[param_name] = signature.parameters[param_name].col_type
|
|
115
127
|
|
|
116
128
|
# window function state:
|
|
@@ -129,7 +141,7 @@ class FunctionCall(Expr):
|
|
|
129
141
|
|
|
130
142
|
if isinstance(self.fn, func.ExprTemplateFunction):
|
|
131
143
|
# we instantiate the template to create an Expr that can be evaluated and record that as a component
|
|
132
|
-
fn_expr = self.fn.instantiate(
|
|
144
|
+
fn_expr = self.fn.instantiate([], bound_args)
|
|
133
145
|
self.components.append(fn_expr)
|
|
134
146
|
self.fn_expr_idx = len(self.components) - 1
|
|
135
147
|
else:
|
|
@@ -360,7 +372,7 @@ class FunctionCall(Expr):
|
|
|
360
372
|
"""
|
|
361
373
|
assert self.is_agg_fn_call
|
|
362
374
|
assert isinstance(self.fn, func.AggregateFunction)
|
|
363
|
-
self.aggregator = self.fn.
|
|
375
|
+
self.aggregator = self.fn.agg_class(**self.agg_init_args)
|
|
364
376
|
|
|
365
377
|
def update(self, data_row: DataRow) -> None:
|
|
366
378
|
"""
|
|
@@ -432,27 +444,32 @@ class FunctionCall(Expr):
|
|
|
432
444
|
data_row[self.slot_idx] = self.fn.py_fn(*args, **kwargs)
|
|
433
445
|
elif self.is_window_fn_call:
|
|
434
446
|
assert isinstance(self.fn, func.AggregateFunction)
|
|
447
|
+
agg_cls = self.fn.agg_class
|
|
435
448
|
if self.has_group_by():
|
|
436
449
|
if self.current_partition_vals is None:
|
|
437
450
|
self.current_partition_vals = [None] * len(self.group_by)
|
|
438
451
|
partition_vals = [data_row[e.slot_idx] for e in self.group_by]
|
|
439
452
|
if partition_vals != self.current_partition_vals:
|
|
440
453
|
# new partition
|
|
441
|
-
self.aggregator =
|
|
454
|
+
self.aggregator = agg_cls(**self.agg_init_args)
|
|
442
455
|
self.current_partition_vals = partition_vals
|
|
443
456
|
elif self.aggregator is None:
|
|
444
|
-
self.aggregator =
|
|
457
|
+
self.aggregator = agg_cls(**self.agg_init_args)
|
|
445
458
|
self.aggregator.update(*args)
|
|
446
459
|
data_row[self.slot_idx] = self.aggregator.value()
|
|
447
460
|
else:
|
|
448
|
-
data_row[self.slot_idx] = self.fn.exec(
|
|
461
|
+
data_row[self.slot_idx] = self.fn.exec(args, kwargs)
|
|
449
462
|
|
|
450
463
|
def _as_dict(self) -> dict:
|
|
451
464
|
result = {
|
|
452
|
-
'fn': self.fn.as_dict(),
|
|
453
|
-
'
|
|
465
|
+
'fn': self.fn.as_dict(),
|
|
466
|
+
'args': self.args,
|
|
467
|
+
'kwargs': self.kwargs,
|
|
468
|
+
'return_type': self.return_type.as_dict(),
|
|
469
|
+
'group_by_start_idx': self.group_by_start_idx,
|
|
470
|
+
'group_by_stop_idx': self.group_by_stop_idx,
|
|
454
471
|
'order_by_start_idx': self.order_by_start_idx,
|
|
455
|
-
**super()._as_dict()
|
|
472
|
+
**super()._as_dict(),
|
|
456
473
|
}
|
|
457
474
|
return result
|
|
458
475
|
|
|
@@ -461,15 +478,108 @@ class FunctionCall(Expr):
|
|
|
461
478
|
assert 'fn' in d
|
|
462
479
|
assert 'args' in d
|
|
463
480
|
assert 'kwargs' in d
|
|
464
|
-
|
|
481
|
+
|
|
465
482
|
fn = func.Function.from_dict(d['fn'])
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
bound_args.update(
|
|
469
|
-
{param_name: val if idx is None else components[idx] for param_name, (idx, val) in d['kwargs'].items()})
|
|
483
|
+
assert not fn.is_polymorphic
|
|
484
|
+
return_type = ts.ColumnType.from_dict(d['return_type']) if 'return_type' in d else None
|
|
470
485
|
group_by_exprs = components[d['group_by_start_idx']:d['group_by_stop_idx']]
|
|
471
486
|
order_by_exprs = components[d['order_by_start_idx']:]
|
|
487
|
+
|
|
488
|
+
args = [
|
|
489
|
+
expr if idx is None else components[idx]
|
|
490
|
+
for idx, expr in d['args']
|
|
491
|
+
]
|
|
492
|
+
kwargs = {
|
|
493
|
+
param_name: (expr if idx is None else components[idx])
|
|
494
|
+
for param_name, (idx, expr) in d['kwargs'].items()
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
# `Function.from_dict()` does signature matching, so it is safe to assume that `args` and `kwargs` are
|
|
498
|
+
# consistent with its signature.
|
|
499
|
+
|
|
500
|
+
# Reassemble bound_args. Note that args and kwargs represent "already bound arguments": they are not bindable
|
|
501
|
+
# in the Python sense, because variable args (such as *args and **kwargs) have already been condensed.
|
|
502
|
+
param_names = list(fn.signature.parameters.keys())
|
|
503
|
+
bound_args = {param_names[i]: arg for i, arg in enumerate(args)}
|
|
504
|
+
bound_args.update(kwargs.items())
|
|
505
|
+
|
|
506
|
+
# TODO: In order to properly invoke call_return_type, we need to ensure that any InlineLists or InlineDicts
|
|
507
|
+
# in bound_args are unpacked into Python lists/dicts. There is an open task to ensure this is true in general;
|
|
508
|
+
# for now, as a hack, we do the unpacking here for the specific case of an InlineList of Literals (the only
|
|
509
|
+
# case where this is necessary to support existing conditional_return_type implementations). Once the general
|
|
510
|
+
# pattern is implemented, we can remove this hack.
|
|
511
|
+
unpacked_bound_args = {
|
|
512
|
+
param_name: cls.__unpack_bound_arg(arg) for param_name, arg in bound_args.items()
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
# Evaluate the call_return_type as defined in the current codebase.
|
|
516
|
+
call_return_type = fn.call_return_type([], unpacked_bound_args)
|
|
517
|
+
|
|
518
|
+
if return_type is None:
|
|
519
|
+
# Schema versions prior to 25 did not store the return_type in metadata, and there is no obvious way to
|
|
520
|
+
# infer it during DB migration, so we might encounter a stored return_type of None. In that case, we use
|
|
521
|
+
# the call_return_type that we just inferred (which matches the deserialization behavior prior to
|
|
522
|
+
# version 25).
|
|
523
|
+
return_type = call_return_type
|
|
524
|
+
else:
|
|
525
|
+
# There is a return_type stored in metadata (schema version >= 25).
|
|
526
|
+
# Check that the stored return_type of the UDF call matches the column type of the FunctionCall, and
|
|
527
|
+
# fail-fast if it doesn't (otherwise we risk getting downstream database errors).
|
|
528
|
+
# TODO: Handle this more gracefully (instead of failing the DB load, allow the DB load to succeed, but
|
|
529
|
+
# mark this FunctionCall as unusable). It's the same issue as dealing with a renamed UDF or Function
|
|
530
|
+
# signature mismatch.
|
|
531
|
+
if not return_type.is_supertype_of(call_return_type, ignore_nullable=True):
|
|
532
|
+
raise excs.Error(
|
|
533
|
+
f'The return type stored in the database for a UDF call to `{fn.self_path}` no longer matches the '
|
|
534
|
+
f'return type of the UDF as currently defined in the code.\nThis probably means that the code for '
|
|
535
|
+
f'`{fn.self_path}` has changed in a backward-incompatible way.\n'
|
|
536
|
+
f'Return type in database: `{return_type}`\n'
|
|
537
|
+
f'Return type as currently defined: `{call_return_type}`'
|
|
538
|
+
)
|
|
539
|
+
|
|
472
540
|
fn_call = cls(
|
|
473
|
-
|
|
474
|
-
|
|
541
|
+
fn,
|
|
542
|
+
bound_args,
|
|
543
|
+
return_type,
|
|
544
|
+
group_by_clause=group_by_exprs,
|
|
545
|
+
order_by_clause=order_by_exprs
|
|
546
|
+
)
|
|
475
547
|
return fn_call
|
|
548
|
+
|
|
549
|
+
@classmethod
|
|
550
|
+
def __find_matching_signature(cls, fn: func.Function, args: list[Any], kwargs: dict[str, Any]) -> Optional[int]:
|
|
551
|
+
for idx, sig in enumerate(fn.signatures):
|
|
552
|
+
if cls.__signature_matches(sig, args, kwargs):
|
|
553
|
+
return idx
|
|
554
|
+
return None
|
|
555
|
+
|
|
556
|
+
@classmethod
|
|
557
|
+
def __signature_matches(cls, sig: func.Signature, args: list[Any], kwargs: dict[str, Any]) -> bool:
|
|
558
|
+
unbound_parameters = set(sig.parameters.keys())
|
|
559
|
+
for i, arg in enumerate(args):
|
|
560
|
+
if i >= len(sig.parameters_by_pos):
|
|
561
|
+
return False
|
|
562
|
+
param = sig.parameters_by_pos[i]
|
|
563
|
+
arg_type = arg.col_type if isinstance(arg, Expr) else ts.ColumnType.infer_literal_type(arg)
|
|
564
|
+
if param.col_type is not None and not param.col_type.is_supertype_of(arg_type, ignore_nullable=True):
|
|
565
|
+
return False
|
|
566
|
+
unbound_parameters.remove(param.name)
|
|
567
|
+
for param_name, arg in kwargs.items():
|
|
568
|
+
if param_name not in unbound_parameters:
|
|
569
|
+
return False
|
|
570
|
+
param = sig.parameters[param_name]
|
|
571
|
+
arg_type = arg.col_type if isinstance(arg, Expr) else ts.ColumnType.infer_literal_type(arg)
|
|
572
|
+
if param.col_type is not None and not param.col_type.is_supertype_of(arg_type, ignore_nullable=True):
|
|
573
|
+
return False
|
|
574
|
+
unbound_parameters.remove(param_name)
|
|
575
|
+
for param_name in unbound_parameters:
|
|
576
|
+
param = sig.parameters[param_name]
|
|
577
|
+
if not param.has_default:
|
|
578
|
+
return False
|
|
579
|
+
return True
|
|
580
|
+
|
|
581
|
+
@classmethod
|
|
582
|
+
def __unpack_bound_arg(cls, arg: Any) -> Any:
|
|
583
|
+
if isinstance(arg, InlineList) and all(isinstance(el, Literal) for el in arg.components):
|
|
584
|
+
return [el.val for el in arg.components]
|
|
585
|
+
return arg
|
pixeltable/exprs/inline_expr.py
CHANGED
|
@@ -101,7 +101,13 @@ class InlineList(Expr):
|
|
|
101
101
|
else:
|
|
102
102
|
exprs.append(Literal(el))
|
|
103
103
|
|
|
104
|
-
|
|
104
|
+
json_schema = {
|
|
105
|
+
'type': 'array',
|
|
106
|
+
'prefixItems': [expr.col_type.to_json_schema() for expr in exprs],
|
|
107
|
+
'items': False # No additional items (fixed length)
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
super().__init__(ts.JsonType(json_schema))
|
|
105
111
|
self.components.extend(exprs)
|
|
106
112
|
self.id = self._create_id()
|
|
107
113
|
|
|
@@ -149,7 +155,21 @@ class InlineDict(Expr):
|
|
|
149
155
|
else:
|
|
150
156
|
exprs.append(Literal(val))
|
|
151
157
|
|
|
152
|
-
|
|
158
|
+
json_schema: Optional[dict[str, Any]]
|
|
159
|
+
try:
|
|
160
|
+
json_schema = {
|
|
161
|
+
'type': 'object',
|
|
162
|
+
'properties': {
|
|
163
|
+
key: expr.col_type.to_json_schema()
|
|
164
|
+
for key, expr in zip(self.keys, exprs)
|
|
165
|
+
},
|
|
166
|
+
}
|
|
167
|
+
except excs.Error:
|
|
168
|
+
# InlineDicts are used to store iterator arguments, which are not required to be valid JSON types,
|
|
169
|
+
# so we can't always construct a valid schema.
|
|
170
|
+
json_schema = None
|
|
171
|
+
|
|
172
|
+
super().__init__(ts.JsonType(json_schema))
|
|
153
173
|
self.components.extend(exprs)
|
|
154
174
|
self.id = self._create_id()
|
|
155
175
|
|
pixeltable/exprs/row_builder.py
CHANGED
|
@@ -368,7 +368,7 @@ class RowBuilder:
|
|
|
368
368
|
if not ignore_errors:
|
|
369
369
|
input_vals = [data_row[d.slot_idx] for d in expr.dependencies()]
|
|
370
370
|
raise excs.ExprEvalError(
|
|
371
|
-
expr, f'expression {expr}', data_row.get_exc(expr.slot_idx), exc_tb, input_vals, 0)
|
|
371
|
+
expr, f'expression {expr}', data_row.get_exc(expr.slot_idx), exc_tb, input_vals, 0) from exc
|
|
372
372
|
|
|
373
373
|
def create_table_row(self, data_row: DataRow, exc_col_ids: set[int]) -> tuple[dict[str, Any], int]:
|
|
374
374
|
"""Create a table row from the slots that have an output column assigned
|
|
@@ -23,7 +23,6 @@ class SimilarityExpr(Expr):
|
|
|
23
23
|
assert item_expr.col_type.is_string_type() or item_expr.col_type.is_image_type()
|
|
24
24
|
|
|
25
25
|
self.components = [col_ref, item_expr]
|
|
26
|
-
self.id = self._create_id()
|
|
27
26
|
|
|
28
27
|
# determine index to use
|
|
29
28
|
idx_info = col_ref.col.get_idx_info()
|
|
@@ -54,10 +53,14 @@ class SimilarityExpr(Expr):
|
|
|
54
53
|
raise excs.Error(
|
|
55
54
|
f'Embedding index {self.idx_info.name!r} on column {self.idx_info.col.name!r} was created without the '
|
|
56
55
|
f"'image_embed' parameter and does not support image queries")
|
|
56
|
+
self.id = self._create_id()
|
|
57
57
|
|
|
58
58
|
def __repr__(self) -> str:
|
|
59
59
|
return f'{self.components[0]}.similarity({self.components[1]})'
|
|
60
60
|
|
|
61
|
+
def _id_attrs(self):
|
|
62
|
+
return super()._id_attrs() + [('idx_name', self.idx_info.name)]
|
|
63
|
+
|
|
61
64
|
def default_column_name(self) -> str:
|
|
62
65
|
return 'similarity'
|
|
63
66
|
|
|
@@ -81,8 +84,12 @@ class SimilarityExpr(Expr):
|
|
|
81
84
|
# this should never get called
|
|
82
85
|
assert False
|
|
83
86
|
|
|
87
|
+
def _as_dict(self) -> dict:
|
|
88
|
+
return {'idx_name': self.idx_info.name, **super()._as_dict()}
|
|
89
|
+
|
|
84
90
|
@classmethod
|
|
85
91
|
def _from_dict(cls, d: dict, components: list[Expr]) -> 'SimilarityExpr':
|
|
92
|
+
iname = d['idx_name'] if 'idx_name' in d else None
|
|
86
93
|
assert len(components) == 2
|
|
87
94
|
assert isinstance(components[0], ColumnRef)
|
|
88
|
-
return cls(components[0], components[1])
|
|
95
|
+
return cls(components[0], components[1], idx_name=iname)
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
4
|
import inspect
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Callable, Optional
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, overload
|
|
6
6
|
|
|
7
7
|
import pixeltable.exceptions as excs
|
|
8
8
|
import pixeltable.type_system as ts
|
|
@@ -35,26 +35,73 @@ class AggregateFunction(Function):
|
|
|
35
35
|
GROUP_BY_PARAM = 'group_by'
|
|
36
36
|
RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}
|
|
37
37
|
|
|
38
|
+
agg_classes: list[type[Aggregator]] # classes for each signature, in signature order
|
|
39
|
+
init_param_names: list[list[str]] # names of the __init__ parameters for each signature
|
|
40
|
+
|
|
38
41
|
def __init__(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
self,
|
|
43
|
+
agg_class: type[Aggregator],
|
|
44
|
+
type_substitutions: Optional[Sequence[dict]],
|
|
45
|
+
self_path: str,
|
|
46
|
+
requires_order_by: bool,
|
|
47
|
+
allows_std_agg: bool,
|
|
48
|
+
allows_window: bool
|
|
49
|
+
) -> None:
|
|
50
|
+
if type_substitutions is None:
|
|
51
|
+
type_substitutions = [None] # single signature with no substitutions
|
|
52
|
+
self.agg_classes = [agg_class]
|
|
53
|
+
else:
|
|
54
|
+
self.agg_classes = [agg_class] * len(type_substitutions)
|
|
55
|
+
self.init_param_names = []
|
|
43
56
|
self.requires_order_by = requires_order_by
|
|
44
57
|
self.allows_std_agg = allows_std_agg
|
|
45
58
|
self.allows_window = allows_window
|
|
46
|
-
self.__doc__ =
|
|
59
|
+
self.__doc__ = agg_class.__doc__
|
|
60
|
+
|
|
61
|
+
signatures: list[Signature] = []
|
|
62
|
+
|
|
63
|
+
# If no type_substitutions were provided, construct a single signature for the class.
|
|
64
|
+
# Otherwise, construct one signature for each type substitution instance.
|
|
65
|
+
for subst in type_substitutions:
|
|
66
|
+
signature, init_param_names = self.__cls_to_signature(agg_class, subst)
|
|
67
|
+
signatures.append(signature)
|
|
68
|
+
self.init_param_names.append(init_param_names)
|
|
69
|
+
|
|
70
|
+
super().__init__(signatures, self_path=self_path)
|
|
71
|
+
|
|
72
|
+
def _update_as_overload_resolution(self, signature_idx: int) -> None:
|
|
73
|
+
self.agg_classes = [self.agg_classes[signature_idx]]
|
|
74
|
+
self.init_param_names = [self.init_param_names[signature_idx]]
|
|
75
|
+
|
|
76
|
+
def __cls_to_signature(
|
|
77
|
+
self, cls: type[Aggregator], type_substitutions: Optional[dict] = None
|
|
78
|
+
) -> tuple[Signature, list[str]]:
|
|
79
|
+
"""Inspects the Aggregator class to infer the corresponding function signature. Returns the
|
|
80
|
+
inferred signature along with the list of init_param_names (for downstream error handling).
|
|
81
|
+
"""
|
|
82
|
+
# infer type parameters; set return_type=InvalidType() because it has no meaning here
|
|
83
|
+
init_sig = Signature.create(py_fn=cls.__init__, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions)
|
|
84
|
+
update_sig = Signature.create(py_fn=cls.update, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions)
|
|
85
|
+
value_sig = Signature.create(py_fn=cls.value, is_cls_method=True, type_substitutions=type_substitutions)
|
|
86
|
+
|
|
87
|
+
init_types = [p.col_type for p in init_sig.parameters.values()]
|
|
88
|
+
update_types = [p.col_type for p in update_sig.parameters.values()]
|
|
89
|
+
value_type = value_sig.return_type
|
|
90
|
+
assert value_type is not None
|
|
91
|
+
|
|
92
|
+
if len(update_types) == 0:
|
|
93
|
+
raise excs.Error('update() must have at least one parameter')
|
|
47
94
|
|
|
48
95
|
# our signature is the signature of 'update', but without self,
|
|
49
96
|
# plus the parameters of 'init' as keyword-only parameters
|
|
50
|
-
py_update_params = list(inspect.signature(
|
|
97
|
+
py_update_params = list(inspect.signature(cls.update).parameters.values())[1:] # leave out self
|
|
51
98
|
assert len(py_update_params) == len(update_types)
|
|
52
99
|
update_params = [
|
|
53
100
|
Parameter(p.name, col_type=update_types[i], kind=p.kind, default=p.default)
|
|
54
101
|
for i, p in enumerate(py_update_params)
|
|
55
102
|
]
|
|
56
103
|
# starting at 1: leave out self
|
|
57
|
-
py_init_params = list(inspect.signature(
|
|
104
|
+
py_init_params = list(inspect.signature(cls.__init__).parameters.values())[1:]
|
|
58
105
|
assert len(py_init_params) == len(init_types)
|
|
59
106
|
init_params = [
|
|
60
107
|
Parameter(p.name, col_type=init_types[i], kind=inspect.Parameter.KEYWORD_ONLY, default=p.default)
|
|
@@ -67,23 +114,36 @@ class AggregateFunction(Function):
|
|
|
67
114
|
f'{", ".join(duplicate_params)}'
|
|
68
115
|
)
|
|
69
116
|
params = update_params + init_params # init_params are keyword-only and come last
|
|
117
|
+
init_param_names = [p.name for p in init_params]
|
|
70
118
|
|
|
71
|
-
|
|
72
|
-
super().__init__(signature, self_path=self_path)
|
|
73
|
-
self.init_param_names = [p.name for p in init_params]
|
|
119
|
+
return Signature(value_type, params), init_param_names
|
|
74
120
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
raise excs.Error(f'{self.name}(): parameter name {param} is reserved')
|
|
121
|
+
@property
|
|
122
|
+
def agg_class(self) -> type[Aggregator]:
|
|
123
|
+
assert not self.is_polymorphic
|
|
124
|
+
return self.agg_classes[0]
|
|
80
125
|
|
|
81
|
-
def exec(self,
|
|
126
|
+
def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
|
|
82
127
|
raise NotImplementedError
|
|
83
128
|
|
|
129
|
+
def overload(self, cls: type[Aggregator]) -> AggregateFunction:
|
|
130
|
+
if not isinstance(cls, type) or not issubclass(cls, Aggregator):
|
|
131
|
+
raise excs.Error(f'Invalid argument to @overload decorator: {cls}')
|
|
132
|
+
if self._has_resolved_fns:
|
|
133
|
+
raise excs.Error('New `overload` not allowed after the UDF has already been called')
|
|
134
|
+
if self._conditional_return_type is not None:
|
|
135
|
+
raise excs.Error('New `overload` not allowed after a conditional return type has been specified')
|
|
136
|
+
sig, init_param_names = self.__cls_to_signature(cls)
|
|
137
|
+
self.signatures.append(sig)
|
|
138
|
+
self.agg_classes.append(cls)
|
|
139
|
+
self.init_param_names.append(init_param_names)
|
|
140
|
+
return self
|
|
141
|
+
|
|
84
142
|
def help_str(self) -> str:
|
|
85
143
|
res = super().help_str()
|
|
86
|
-
|
|
144
|
+
# We need to reference agg_classes[0] rather than agg_class here, because we want this to work even if the
|
|
145
|
+
# aggregator is polymorphic (in which case we use the docstring of the originally decorated UDA).
|
|
146
|
+
res += '\n\n' + inspect.getdoc(self.agg_classes[0].update)
|
|
87
147
|
return res
|
|
88
148
|
|
|
89
149
|
def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
|
|
@@ -121,18 +181,24 @@ class AggregateFunction(Function):
|
|
|
121
181
|
f'{self.display_name}(): group_by invalid with an aggregate function that does not allow windows')
|
|
122
182
|
group_by_clause = kwargs.pop(self.GROUP_BY_PARAM)
|
|
123
183
|
|
|
124
|
-
bound_args = self.
|
|
125
|
-
|
|
184
|
+
resolved_fn, bound_args = self._bind_to_matching_signature(args, kwargs)
|
|
185
|
+
return_type = resolved_fn.call_return_type(args, kwargs)
|
|
126
186
|
return exprs.FunctionCall(
|
|
127
|
-
|
|
187
|
+
resolved_fn,
|
|
188
|
+
bound_args,
|
|
189
|
+
return_type,
|
|
128
190
|
order_by_clause=[order_by_clause] if order_by_clause is not None else [],
|
|
129
|
-
group_by_clause=[group_by_clause] if group_by_clause is not None else []
|
|
191
|
+
group_by_clause=[group_by_clause] if group_by_clause is not None else []
|
|
192
|
+
)
|
|
130
193
|
|
|
131
194
|
def validate_call(self, bound_args: dict[str, Any]) -> None:
|
|
132
195
|
# check that init parameters are not Exprs
|
|
133
196
|
# TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
|
|
134
|
-
|
|
135
|
-
|
|
197
|
+
from pixeltable import exprs
|
|
198
|
+
|
|
199
|
+
assert not self.is_polymorphic
|
|
200
|
+
|
|
201
|
+
for param_name in self.init_param_names[0]:
|
|
136
202
|
if param_name in bound_args and isinstance(bound_args[param_name], exprs.Expr):
|
|
137
203
|
raise excs.Error(
|
|
138
204
|
f'{self.display_name}(): init() parameter {param_name} needs to be a constant, not a Pixeltable '
|
|
@@ -143,13 +209,23 @@ class AggregateFunction(Function):
|
|
|
143
209
|
return f'<Pixeltable Aggregator {self.name}>'
|
|
144
210
|
|
|
145
211
|
|
|
212
|
+
# Decorator invoked without parentheses: @pxt.uda
|
|
213
|
+
@overload
|
|
214
|
+
def uda(decorated_fn: Callable) -> AggregateFunction: ...
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
|
|
218
|
+
@overload
|
|
146
219
|
def uda(
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
) -> Callable[[type[Aggregator]], AggregateFunction]:
|
|
220
|
+
*,
|
|
221
|
+
requires_order_by: bool = False,
|
|
222
|
+
allows_std_agg: bool = True,
|
|
223
|
+
allows_window: bool = False,
|
|
224
|
+
type_substitutions: Optional[Sequence[dict]] = None
|
|
225
|
+
) -> Callable[[type[Aggregator]], AggregateFunction]: ...
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def uda(*args, **kwargs):
|
|
153
229
|
"""Decorator for user-defined aggregate functions.
|
|
154
230
|
|
|
155
231
|
The decorated class must inherit from Aggregator and implement the following methods:
|
|
@@ -161,46 +237,50 @@ def uda(
|
|
|
161
237
|
to the module where the class is defined.
|
|
162
238
|
|
|
163
239
|
Parameters:
|
|
164
|
-
- init_types: list of types for the __init__() parameters; must match the number of parameters
|
|
165
|
-
- update_types: list of types for the update() parameters; must match the number of parameters
|
|
166
|
-
- value_type: return type of the aggregator
|
|
167
240
|
- requires_order_by: if True, the first parameter to the function is the order-by expression
|
|
168
241
|
- allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
|
|
169
242
|
- allows_window: if True, the function can be used with a window
|
|
170
243
|
"""
|
|
171
|
-
if
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
|
|
199
|
-
# do the path validation at the very end, in order to be able to write tests for the other failure cases
|
|
200
|
-
validate_symbol_path(class_path)
|
|
201
|
-
#module = importlib.import_module(cls.__module__)
|
|
202
|
-
#setattr(module, name, instance)
|
|
203
|
-
|
|
204
|
-
return instance
|
|
244
|
+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
|
245
|
+
|
|
246
|
+
# Decorator invoked without parentheses: @pxt.uda
|
|
247
|
+
# Simply call make_aggregator with defaults.
|
|
248
|
+
return make_aggregator(cls=args[0])
|
|
249
|
+
|
|
250
|
+
else:
|
|
251
|
+
|
|
252
|
+
# Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
|
|
253
|
+
# Create a decorator for the specified schema.
|
|
254
|
+
requires_order_by = kwargs.pop('requires_order_by', False)
|
|
255
|
+
allows_std_agg = kwargs.pop('allows_std_agg', True)
|
|
256
|
+
allows_window = kwargs.pop('allows_window', False)
|
|
257
|
+
type_substitutions = kwargs.pop('type_substitutions', None)
|
|
258
|
+
if len(kwargs) > 0:
|
|
259
|
+
raise excs.Error(f'Invalid @uda decorator kwargs: {", ".join(kwargs.keys())}')
|
|
260
|
+
if len(args) > 0:
|
|
261
|
+
raise excs.Error('Unexpected @uda decorator arguments.')
|
|
262
|
+
|
|
263
|
+
def decorator(cls: type[Aggregator]) -> AggregateFunction:
|
|
264
|
+
return make_aggregator(
|
|
265
|
+
cls,
|
|
266
|
+
requires_order_by=requires_order_by,
|
|
267
|
+
allows_std_agg=allows_std_agg,
|
|
268
|
+
allows_window=allows_window,
|
|
269
|
+
type_substitutions=type_substitutions
|
|
270
|
+
)
|
|
205
271
|
|
|
206
|
-
|
|
272
|
+
return decorator
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def make_aggregator(
|
|
276
|
+
cls: type[Aggregator],
|
|
277
|
+
requires_order_by: bool = False,
|
|
278
|
+
allows_std_agg: bool = True,
|
|
279
|
+
allows_window: bool = False,
|
|
280
|
+
type_substitutions: Optional[Sequence[dict]] = None
|
|
281
|
+
) -> AggregateFunction:
|
|
282
|
+
class_path = f'{cls.__module__}.{cls.__qualname__}'
|
|
283
|
+
instance = AggregateFunction(cls, type_substitutions, class_path, requires_order_by, allows_std_agg, allows_window)
|
|
284
|
+
# do the path validation at the very end, in order to be able to write tests for the other failure cases
|
|
285
|
+
validate_symbol_path(class_path)
|
|
286
|
+
return instance
|