pixeltable 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (49) hide show
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/__init__.py +1 -1
  3. pixeltable/catalog/dir.py +6 -0
  4. pixeltable/catalog/globals.py +13 -0
  5. pixeltable/catalog/named_function.py +4 -0
  6. pixeltable/catalog/path_dict.py +37 -11
  7. pixeltable/catalog/schema_object.py +6 -0
  8. pixeltable/catalog/table.py +22 -5
  9. pixeltable/catalog/table_version.py +22 -8
  10. pixeltable/dataframe.py +201 -3
  11. pixeltable/env.py +9 -3
  12. pixeltable/exec/expr_eval_node.py +1 -1
  13. pixeltable/exec/sql_node.py +2 -2
  14. pixeltable/exprs/function_call.py +134 -24
  15. pixeltable/exprs/inline_expr.py +22 -2
  16. pixeltable/exprs/row_builder.py +1 -1
  17. pixeltable/exprs/similarity_expr.py +9 -2
  18. pixeltable/func/aggregate_function.py +148 -68
  19. pixeltable/func/callable_function.py +49 -13
  20. pixeltable/func/expr_template_function.py +55 -24
  21. pixeltable/func/function.py +183 -22
  22. pixeltable/func/function_registry.py +2 -1
  23. pixeltable/func/query_template_function.py +11 -6
  24. pixeltable/func/signature.py +64 -7
  25. pixeltable/func/udf.py +57 -35
  26. pixeltable/functions/globals.py +54 -34
  27. pixeltable/functions/json.py +3 -8
  28. pixeltable/functions/ollama.py +4 -4
  29. pixeltable/functions/timestamp.py +1 -1
  30. pixeltable/functions/video.py +2 -8
  31. pixeltable/functions/vision.py +1 -1
  32. pixeltable/globals.py +218 -59
  33. pixeltable/index/embedding_index.py +44 -24
  34. pixeltable/metadata/__init__.py +1 -1
  35. pixeltable/metadata/converters/convert_16.py +2 -1
  36. pixeltable/metadata/converters/convert_17.py +2 -1
  37. pixeltable/metadata/converters/convert_23.py +35 -0
  38. pixeltable/metadata/converters/convert_24.py +47 -0
  39. pixeltable/metadata/converters/util.py +4 -2
  40. pixeltable/metadata/notes.py +2 -0
  41. pixeltable/metadata/schema.py +1 -0
  42. pixeltable/tool/create_test_db_dump.py +11 -0
  43. pixeltable/tool/doc_plugins/griffe.py +4 -3
  44. pixeltable/type_system.py +180 -45
  45. {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/METADATA +3 -2
  46. {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/RECORD +49 -47
  47. {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/LICENSE +0 -0
  48. {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/WHEEL +0 -0
  49. {pixeltable-0.2.28.dist-info → pixeltable-0.2.29.dist-info}/entry_points.txt +0 -0
@@ -15,6 +15,7 @@ import pixeltable.type_system as ts
15
15
  from .data_row import DataRow
16
16
  from .expr import Expr
17
17
  from .inline_expr import InlineDict, InlineList
18
+ from .literal import Literal
18
19
  from .row_builder import RowBuilder
19
20
  from .rowid_ref import RowidRef
20
21
  from .sql_element_cache import SqlElementCache
@@ -34,6 +35,7 @@ class FunctionCall(Expr):
34
35
 
35
36
  arg_types: list[ts.ColumnType]
36
37
  kwarg_types: dict[str, ts.ColumnType]
38
+ return_type: ts.ColumnType
37
39
  group_by_start_idx: int
38
40
  group_by_stop_idx: int
39
41
  fn_expr_idx: int
@@ -43,17 +45,25 @@ class FunctionCall(Expr):
43
45
  current_partition_vals: Optional[list[Any]]
44
46
 
45
47
  def __init__(
46
- self, fn: func.Function, bound_args: dict[str, Any], order_by_clause: Optional[list[Any]] = None,
47
- group_by_clause: Optional[list[Any]] = None, is_method_call: bool = False):
48
+ self,
49
+ fn: func.Function,
50
+ bound_args: dict[str, Any],
51
+ return_type: ts.ColumnType,
52
+ order_by_clause: Optional[list[Any]] = None,
53
+ group_by_clause: Optional[list[Any]] = None,
54
+ is_method_call: bool = False
55
+ ):
48
56
  if order_by_clause is None:
49
57
  order_by_clause = []
50
58
  if group_by_clause is None:
51
59
  group_by_clause = []
52
- signature = fn.signature
53
- return_type = fn.call_return_type(bound_args)
60
+
61
+ assert not fn.is_polymorphic
62
+
54
63
  self.fn = fn
55
64
  self.is_method_call = is_method_call
56
- self.normalize_args(fn.name, signature, bound_args)
65
+
66
+ signature = fn.signature
57
67
 
58
68
  # If `return_type` is non-nullable, but the function call has a nullable input to any of its non-nullable
59
69
  # parameters, then we need to make it nullable. This is because Pixeltable defaults a function output to
@@ -67,6 +77,8 @@ class FunctionCall(Expr):
67
77
  return_type = return_type.copy(nullable=True)
68
78
  break
69
79
 
80
+ self.return_type = return_type
81
+
70
82
  super().__init__(return_type)
71
83
 
72
84
  self.agg_init_args = {}
@@ -74,9 +86,9 @@ class FunctionCall(Expr):
74
86
  # we separate out the init args for the aggregator
75
87
  assert isinstance(fn, func.AggregateFunction)
76
88
  self.agg_init_args = {
77
- arg_name: arg for arg_name, arg in bound_args.items() if arg_name in fn.init_param_names
89
+ arg_name: arg for arg_name, arg in bound_args.items() if arg_name in fn.init_param_names[0]
78
90
  }
79
- bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names}
91
+ bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names[0]}
80
92
 
81
93
  # construct components, args, kwargs
82
94
  self.args = []
@@ -88,7 +100,7 @@ class FunctionCall(Expr):
88
100
 
89
101
  # the prefix of parameters that are bound can be passed by position
90
102
  processed_args: set[str] = set()
91
- for py_param in fn.signature.py_signature.parameters.values():
103
+ for py_param in signature.py_signature.parameters.values():
92
104
  if py_param.name not in bound_args or py_param.kind == inspect.Parameter.KEYWORD_ONLY:
93
105
  break
94
106
  arg = bound_args[py_param.name]
@@ -110,7 +122,7 @@ class FunctionCall(Expr):
110
122
  self.components.append(arg.copy())
111
123
  else:
112
124
  self.kwargs[param_name] = (None, arg)
113
- if fn.signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
125
+ if signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
114
126
  self.kwarg_types[param_name] = signature.parameters[param_name].col_type
115
127
 
116
128
  # window function state:
@@ -129,7 +141,7 @@ class FunctionCall(Expr):
129
141
 
130
142
  if isinstance(self.fn, func.ExprTemplateFunction):
131
143
  # we instantiate the template to create an Expr that can be evaluated and record that as a component
132
- fn_expr = self.fn.instantiate(**bound_args)
144
+ fn_expr = self.fn.instantiate([], bound_args)
133
145
  self.components.append(fn_expr)
134
146
  self.fn_expr_idx = len(self.components) - 1
135
147
  else:
@@ -360,7 +372,7 @@ class FunctionCall(Expr):
360
372
  """
361
373
  assert self.is_agg_fn_call
362
374
  assert isinstance(self.fn, func.AggregateFunction)
363
- self.aggregator = self.fn.agg_cls(**self.agg_init_args)
375
+ self.aggregator = self.fn.agg_class(**self.agg_init_args)
364
376
 
365
377
  def update(self, data_row: DataRow) -> None:
366
378
  """
@@ -432,27 +444,32 @@ class FunctionCall(Expr):
432
444
  data_row[self.slot_idx] = self.fn.py_fn(*args, **kwargs)
433
445
  elif self.is_window_fn_call:
434
446
  assert isinstance(self.fn, func.AggregateFunction)
447
+ agg_cls = self.fn.agg_class
435
448
  if self.has_group_by():
436
449
  if self.current_partition_vals is None:
437
450
  self.current_partition_vals = [None] * len(self.group_by)
438
451
  partition_vals = [data_row[e.slot_idx] for e in self.group_by]
439
452
  if partition_vals != self.current_partition_vals:
440
453
  # new partition
441
- self.aggregator = self.fn.agg_cls(**self.agg_init_args)
454
+ self.aggregator = agg_cls(**self.agg_init_args)
442
455
  self.current_partition_vals = partition_vals
443
456
  elif self.aggregator is None:
444
- self.aggregator = self.fn.agg_cls(**self.agg_init_args)
457
+ self.aggregator = agg_cls(**self.agg_init_args)
445
458
  self.aggregator.update(*args)
446
459
  data_row[self.slot_idx] = self.aggregator.value()
447
460
  else:
448
- data_row[self.slot_idx] = self.fn.exec(*args, **kwargs)
461
+ data_row[self.slot_idx] = self.fn.exec(args, kwargs)
449
462
 
450
463
  def _as_dict(self) -> dict:
451
464
  result = {
452
- 'fn': self.fn.as_dict(), 'args': self.args, 'kwargs': self.kwargs,
453
- 'group_by_start_idx': self.group_by_start_idx, 'group_by_stop_idx': self.group_by_stop_idx,
465
+ 'fn': self.fn.as_dict(),
466
+ 'args': self.args,
467
+ 'kwargs': self.kwargs,
468
+ 'return_type': self.return_type.as_dict(),
469
+ 'group_by_start_idx': self.group_by_start_idx,
470
+ 'group_by_stop_idx': self.group_by_stop_idx,
454
471
  'order_by_start_idx': self.order_by_start_idx,
455
- **super()._as_dict()
472
+ **super()._as_dict(),
456
473
  }
457
474
  return result
458
475
 
@@ -461,15 +478,108 @@ class FunctionCall(Expr):
461
478
  assert 'fn' in d
462
479
  assert 'args' in d
463
480
  assert 'kwargs' in d
464
- # reassemble bound args
481
+
465
482
  fn = func.Function.from_dict(d['fn'])
466
- param_names = list(fn.signature.parameters.keys())
467
- bound_args = {param_names[i]: arg if idx is None else components[idx] for i, (idx, arg) in enumerate(d['args'])}
468
- bound_args.update(
469
- {param_name: val if idx is None else components[idx] for param_name, (idx, val) in d['kwargs'].items()})
483
+ assert not fn.is_polymorphic
484
+ return_type = ts.ColumnType.from_dict(d['return_type']) if 'return_type' in d else None
470
485
  group_by_exprs = components[d['group_by_start_idx']:d['group_by_stop_idx']]
471
486
  order_by_exprs = components[d['order_by_start_idx']:]
487
+
488
+ args = [
489
+ expr if idx is None else components[idx]
490
+ for idx, expr in d['args']
491
+ ]
492
+ kwargs = {
493
+ param_name: (expr if idx is None else components[idx])
494
+ for param_name, (idx, expr) in d['kwargs'].items()
495
+ }
496
+
497
+ # `Function.from_dict()` does signature matching, so it is safe to assume that `args` and `kwargs` are
498
+ # consistent with its signature.
499
+
500
+ # Reassemble bound_args. Note that args and kwargs represent "already bound arguments": they are not bindable
501
+ # in the Python sense, because variable args (such as *args and **kwargs) have already been condensed.
502
+ param_names = list(fn.signature.parameters.keys())
503
+ bound_args = {param_names[i]: arg for i, arg in enumerate(args)}
504
+ bound_args.update(kwargs.items())
505
+
506
+ # TODO: In order to properly invoke call_return_type, we need to ensure that any InlineLists or InlineDicts
507
+ # in bound_args are unpacked into Python lists/dicts. There is an open task to ensure this is true in general;
508
+ # for now, as a hack, we do the unpacking here for the specific case of an InlineList of Literals (the only
509
+ # case where this is necessary to support existing conditional_return_type implementations). Once the general
510
+ # pattern is implemented, we can remove this hack.
511
+ unpacked_bound_args = {
512
+ param_name: cls.__unpack_bound_arg(arg) for param_name, arg in bound_args.items()
513
+ }
514
+
515
+ # Evaluate the call_return_type as defined in the current codebase.
516
+ call_return_type = fn.call_return_type([], unpacked_bound_args)
517
+
518
+ if return_type is None:
519
+ # Schema versions prior to 25 did not store the return_type in metadata, and there is no obvious way to
520
+ # infer it during DB migration, so we might encounter a stored return_type of None. In that case, we use
521
+ # the call_return_type that we just inferred (which matches the deserialization behavior prior to
522
+ # version 25).
523
+ return_type = call_return_type
524
+ else:
525
+ # There is a return_type stored in metadata (schema version >= 25).
526
+ # Check that the stored return_type of the UDF call matches the column type of the FunctionCall, and
527
+ # fail-fast if it doesn't (otherwise we risk getting downstream database errors).
528
+ # TODO: Handle this more gracefully (instead of failing the DB load, allow the DB load to succeed, but
529
+ # mark this FunctionCall as unusable). It's the same issue as dealing with a renamed UDF or Function
530
+ # signature mismatch.
531
+ if not return_type.is_supertype_of(call_return_type, ignore_nullable=True):
532
+ raise excs.Error(
533
+ f'The return type stored in the database for a UDF call to `{fn.self_path}` no longer matches the '
534
+ f'return type of the UDF as currently defined in the code.\nThis probably means that the code for '
535
+ f'`{fn.self_path}` has changed in a backward-incompatible way.\n'
536
+ f'Return type in database: `{return_type}`\n'
537
+ f'Return type as currently defined: `{call_return_type}`'
538
+ )
539
+
472
540
  fn_call = cls(
473
- func.Function.from_dict(d['fn']), bound_args, group_by_clause=group_by_exprs,
474
- order_by_clause=order_by_exprs)
541
+ fn,
542
+ bound_args,
543
+ return_type,
544
+ group_by_clause=group_by_exprs,
545
+ order_by_clause=order_by_exprs
546
+ )
475
547
  return fn_call
548
+
549
+ @classmethod
550
+ def __find_matching_signature(cls, fn: func.Function, args: list[Any], kwargs: dict[str, Any]) -> Optional[int]:
551
+ for idx, sig in enumerate(fn.signatures):
552
+ if cls.__signature_matches(sig, args, kwargs):
553
+ return idx
554
+ return None
555
+
556
+ @classmethod
557
+ def __signature_matches(cls, sig: func.Signature, args: list[Any], kwargs: dict[str, Any]) -> bool:
558
+ unbound_parameters = set(sig.parameters.keys())
559
+ for i, arg in enumerate(args):
560
+ if i >= len(sig.parameters_by_pos):
561
+ return False
562
+ param = sig.parameters_by_pos[i]
563
+ arg_type = arg.col_type if isinstance(arg, Expr) else ts.ColumnType.infer_literal_type(arg)
564
+ if param.col_type is not None and not param.col_type.is_supertype_of(arg_type, ignore_nullable=True):
565
+ return False
566
+ unbound_parameters.remove(param.name)
567
+ for param_name, arg in kwargs.items():
568
+ if param_name not in unbound_parameters:
569
+ return False
570
+ param = sig.parameters[param_name]
571
+ arg_type = arg.col_type if isinstance(arg, Expr) else ts.ColumnType.infer_literal_type(arg)
572
+ if param.col_type is not None and not param.col_type.is_supertype_of(arg_type, ignore_nullable=True):
573
+ return False
574
+ unbound_parameters.remove(param_name)
575
+ for param_name in unbound_parameters:
576
+ param = sig.parameters[param_name]
577
+ if not param.has_default:
578
+ return False
579
+ return True
580
+
581
+ @classmethod
582
+ def __unpack_bound_arg(cls, arg: Any) -> Any:
583
+ if isinstance(arg, InlineList) and all(isinstance(el, Literal) for el in arg.components):
584
+ return [el.val for el in arg.components]
585
+ return arg
@@ -101,7 +101,13 @@ class InlineList(Expr):
101
101
  else:
102
102
  exprs.append(Literal(el))
103
103
 
104
- super().__init__(ts.JsonType())
104
+ json_schema = {
105
+ 'type': 'array',
106
+ 'prefixItems': [expr.col_type.to_json_schema() for expr in exprs],
107
+ 'items': False # No additional items (fixed length)
108
+ }
109
+
110
+ super().__init__(ts.JsonType(json_schema))
105
111
  self.components.extend(exprs)
106
112
  self.id = self._create_id()
107
113
 
@@ -149,7 +155,21 @@ class InlineDict(Expr):
149
155
  else:
150
156
  exprs.append(Literal(val))
151
157
 
152
- super().__init__(ts.JsonType())
158
+ json_schema: Optional[dict[str, Any]]
159
+ try:
160
+ json_schema = {
161
+ 'type': 'object',
162
+ 'properties': {
163
+ key: expr.col_type.to_json_schema()
164
+ for key, expr in zip(self.keys, exprs)
165
+ },
166
+ }
167
+ except excs.Error:
168
+ # InlineDicts are used to store iterator arguments, which are not required to be valid JSON types,
169
+ # so we can't always construct a valid schema.
170
+ json_schema = None
171
+
172
+ super().__init__(ts.JsonType(json_schema))
153
173
  self.components.extend(exprs)
154
174
  self.id = self._create_id()
155
175
 
@@ -368,7 +368,7 @@ class RowBuilder:
368
368
  if not ignore_errors:
369
369
  input_vals = [data_row[d.slot_idx] for d in expr.dependencies()]
370
370
  raise excs.ExprEvalError(
371
- expr, f'expression {expr}', data_row.get_exc(expr.slot_idx), exc_tb, input_vals, 0)
371
+ expr, f'expression {expr}', data_row.get_exc(expr.slot_idx), exc_tb, input_vals, 0) from exc
372
372
 
373
373
  def create_table_row(self, data_row: DataRow, exc_col_ids: set[int]) -> tuple[dict[str, Any], int]:
374
374
  """Create a table row from the slots that have an output column assigned
@@ -23,7 +23,6 @@ class SimilarityExpr(Expr):
23
23
  assert item_expr.col_type.is_string_type() or item_expr.col_type.is_image_type()
24
24
 
25
25
  self.components = [col_ref, item_expr]
26
- self.id = self._create_id()
27
26
 
28
27
  # determine index to use
29
28
  idx_info = col_ref.col.get_idx_info()
@@ -54,10 +53,14 @@ class SimilarityExpr(Expr):
54
53
  raise excs.Error(
55
54
  f'Embedding index {self.idx_info.name!r} on column {self.idx_info.col.name!r} was created without the '
56
55
  f"'image_embed' parameter and does not support image queries")
56
+ self.id = self._create_id()
57
57
 
58
58
  def __repr__(self) -> str:
59
59
  return f'{self.components[0]}.similarity({self.components[1]})'
60
60
 
61
+ def _id_attrs(self):
62
+ return super()._id_attrs() + [('idx_name', self.idx_info.name)]
63
+
61
64
  def default_column_name(self) -> str:
62
65
  return 'similarity'
63
66
 
@@ -81,8 +84,12 @@ class SimilarityExpr(Expr):
81
84
  # this should never get called
82
85
  assert False
83
86
 
87
+ def _as_dict(self) -> dict:
88
+ return {'idx_name': self.idx_info.name, **super()._as_dict()}
89
+
84
90
  @classmethod
85
91
  def _from_dict(cls, d: dict, components: list[Expr]) -> 'SimilarityExpr':
92
+ iname = d['idx_name'] if 'idx_name' in d else None
86
93
  assert len(components) == 2
87
94
  assert isinstance(components[0], ColumnRef)
88
- return cls(components[0], components[1])
95
+ return cls(components[0], components[1], idx_name=iname)
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import inspect
5
- from typing import TYPE_CHECKING, Any, Callable, Optional
5
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, overload
6
6
 
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
@@ -35,26 +35,73 @@ class AggregateFunction(Function):
35
35
  GROUP_BY_PARAM = 'group_by'
36
36
  RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}
37
37
 
38
+ agg_classes: list[type[Aggregator]] # classes for each signature, in signature order
39
+ init_param_names: list[list[str]] # names of the __init__ parameters for each signature
40
+
38
41
  def __init__(
39
- self, aggregator_class: type[Aggregator], self_path: str,
40
- init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
41
- requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
42
- self.agg_cls = aggregator_class
42
+ self,
43
+ agg_class: type[Aggregator],
44
+ type_substitutions: Optional[Sequence[dict]],
45
+ self_path: str,
46
+ requires_order_by: bool,
47
+ allows_std_agg: bool,
48
+ allows_window: bool
49
+ ) -> None:
50
+ if type_substitutions is None:
51
+ type_substitutions = [None] # single signature with no substitutions
52
+ self.agg_classes = [agg_class]
53
+ else:
54
+ self.agg_classes = [agg_class] * len(type_substitutions)
55
+ self.init_param_names = []
43
56
  self.requires_order_by = requires_order_by
44
57
  self.allows_std_agg = allows_std_agg
45
58
  self.allows_window = allows_window
46
- self.__doc__ = aggregator_class.__doc__
59
+ self.__doc__ = agg_class.__doc__
60
+
61
+ signatures: list[Signature] = []
62
+
63
+ # If no type_substitutions were provided, construct a single signature for the class.
64
+ # Otherwise, construct one signature for each type substitution instance.
65
+ for subst in type_substitutions:
66
+ signature, init_param_names = self.__cls_to_signature(agg_class, subst)
67
+ signatures.append(signature)
68
+ self.init_param_names.append(init_param_names)
69
+
70
+ super().__init__(signatures, self_path=self_path)
71
+
72
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
73
+ self.agg_classes = [self.agg_classes[signature_idx]]
74
+ self.init_param_names = [self.init_param_names[signature_idx]]
75
+
76
+ def __cls_to_signature(
77
+ self, cls: type[Aggregator], type_substitutions: Optional[dict] = None
78
+ ) -> tuple[Signature, list[str]]:
79
+ """Inspects the Aggregator class to infer the corresponding function signature. Returns the
80
+ inferred signature along with the list of init_param_names (for downstream error handling).
81
+ """
82
+ # infer type parameters; set return_type=InvalidType() because it has no meaning here
83
+ init_sig = Signature.create(py_fn=cls.__init__, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions)
84
+ update_sig = Signature.create(py_fn=cls.update, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions)
85
+ value_sig = Signature.create(py_fn=cls.value, is_cls_method=True, type_substitutions=type_substitutions)
86
+
87
+ init_types = [p.col_type for p in init_sig.parameters.values()]
88
+ update_types = [p.col_type for p in update_sig.parameters.values()]
89
+ value_type = value_sig.return_type
90
+ assert value_type is not None
91
+
92
+ if len(update_types) == 0:
93
+ raise excs.Error('update() must have at least one parameter')
47
94
 
48
95
  # our signature is the signature of 'update', but without self,
49
96
  # plus the parameters of 'init' as keyword-only parameters
50
- py_update_params = list(inspect.signature(self.agg_cls.update).parameters.values())[1:] # leave out self
97
+ py_update_params = list(inspect.signature(cls.update).parameters.values())[1:] # leave out self
51
98
  assert len(py_update_params) == len(update_types)
52
99
  update_params = [
53
100
  Parameter(p.name, col_type=update_types[i], kind=p.kind, default=p.default)
54
101
  for i, p in enumerate(py_update_params)
55
102
  ]
56
103
  # starting at 1: leave out self
57
- py_init_params = list(inspect.signature(self.agg_cls.__init__).parameters.values())[1:]
104
+ py_init_params = list(inspect.signature(cls.__init__).parameters.values())[1:]
58
105
  assert len(py_init_params) == len(init_types)
59
106
  init_params = [
60
107
  Parameter(p.name, col_type=init_types[i], kind=inspect.Parameter.KEYWORD_ONLY, default=p.default)
@@ -67,23 +114,36 @@ class AggregateFunction(Function):
67
114
  f'{", ".join(duplicate_params)}'
68
115
  )
69
116
  params = update_params + init_params # init_params are keyword-only and come last
117
+ init_param_names = [p.name for p in init_params]
70
118
 
71
- signature = Signature(value_type, params)
72
- super().__init__(signature, self_path=self_path)
73
- self.init_param_names = [p.name for p in init_params]
119
+ return Signature(value_type, params), init_param_names
74
120
 
75
- # make sure the signature doesn't contain reserved parameter names;
76
- # do this after super().__init__(), otherwise self.name is invalid
77
- for param in signature.parameters:
78
- if param.lower() in self.RESERVED_PARAMS:
79
- raise excs.Error(f'{self.name}(): parameter name {param} is reserved')
121
+ @property
122
+ def agg_class(self) -> type[Aggregator]:
123
+ assert not self.is_polymorphic
124
+ return self.agg_classes[0]
80
125
 
81
- def exec(self, *args: Any, **kwargs: Any) -> Any:
126
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
82
127
  raise NotImplementedError
83
128
 
129
+ def overload(self, cls: type[Aggregator]) -> AggregateFunction:
130
+ if not isinstance(cls, type) or not issubclass(cls, Aggregator):
131
+ raise excs.Error(f'Invalid argument to @overload decorator: {cls}')
132
+ if self._has_resolved_fns:
133
+ raise excs.Error('New `overload` not allowed after the UDF has already been called')
134
+ if self._conditional_return_type is not None:
135
+ raise excs.Error('New `overload` not allowed after a conditional return type has been specified')
136
+ sig, init_param_names = self.__cls_to_signature(cls)
137
+ self.signatures.append(sig)
138
+ self.agg_classes.append(cls)
139
+ self.init_param_names.append(init_param_names)
140
+ return self
141
+
84
142
  def help_str(self) -> str:
85
143
  res = super().help_str()
86
- res += '\n\n' + inspect.getdoc(self.agg_cls.update)
144
+ # We need to reference agg_classes[0] rather than agg_class here, because we want this to work even if the
145
+ # aggregator is polymorphic (in which case we use the docstring of the originally decorated UDA).
146
+ res += '\n\n' + inspect.getdoc(self.agg_classes[0].update)
87
147
  return res
88
148
 
89
149
  def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
@@ -121,18 +181,24 @@ class AggregateFunction(Function):
121
181
  f'{self.display_name}(): group_by invalid with an aggregate function that does not allow windows')
122
182
  group_by_clause = kwargs.pop(self.GROUP_BY_PARAM)
123
183
 
124
- bound_args = self.signature.py_signature.bind(*args, **kwargs)
125
- self.validate_call(bound_args.arguments)
184
+ resolved_fn, bound_args = self._bind_to_matching_signature(args, kwargs)
185
+ return_type = resolved_fn.call_return_type(args, kwargs)
126
186
  return exprs.FunctionCall(
127
- self, bound_args.arguments,
187
+ resolved_fn,
188
+ bound_args,
189
+ return_type,
128
190
  order_by_clause=[order_by_clause] if order_by_clause is not None else [],
129
- group_by_clause=[group_by_clause] if group_by_clause is not None else [])
191
+ group_by_clause=[group_by_clause] if group_by_clause is not None else []
192
+ )
130
193
 
131
194
  def validate_call(self, bound_args: dict[str, Any]) -> None:
132
195
  # check that init parameters are not Exprs
133
196
  # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
134
- import pixeltable.exprs as exprs
135
- for param_name in self.init_param_names:
197
+ from pixeltable import exprs
198
+
199
+ assert not self.is_polymorphic
200
+
201
+ for param_name in self.init_param_names[0]:
136
202
  if param_name in bound_args and isinstance(bound_args[param_name], exprs.Expr):
137
203
  raise excs.Error(
138
204
  f'{self.display_name}(): init() parameter {param_name} needs to be a constant, not a Pixeltable '
@@ -143,13 +209,23 @@ class AggregateFunction(Function):
143
209
  return f'<Pixeltable Aggregator {self.name}>'
144
210
 
145
211
 
212
+ # Decorator invoked without parentheses: @pxt.uda
213
+ @overload
214
+ def uda(decorated_fn: Callable) -> AggregateFunction: ...
215
+
216
+
217
+ # Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
218
+ @overload
146
219
  def uda(
147
- *,
148
- value_type: ts.ColumnType,
149
- update_types: list[ts.ColumnType],
150
- init_types: Optional[list[ts.ColumnType]] = None,
151
- requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
152
- ) -> Callable[[type[Aggregator]], AggregateFunction]:
220
+ *,
221
+ requires_order_by: bool = False,
222
+ allows_std_agg: bool = True,
223
+ allows_window: bool = False,
224
+ type_substitutions: Optional[Sequence[dict]] = None
225
+ ) -> Callable[[type[Aggregator]], AggregateFunction]: ...
226
+
227
+
228
+ def uda(*args, **kwargs):
153
229
  """Decorator for user-defined aggregate functions.
154
230
 
155
231
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -161,46 +237,50 @@ def uda(
161
237
  to the module where the class is defined.
162
238
 
163
239
  Parameters:
164
- - init_types: list of types for the __init__() parameters; must match the number of parameters
165
- - update_types: list of types for the update() parameters; must match the number of parameters
166
- - value_type: return type of the aggregator
167
240
  - requires_order_by: if True, the first parameter to the function is the order-by expression
168
241
  - allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
169
242
  - allows_window: if True, the function can be used with a window
170
243
  """
171
- if init_types is None:
172
- init_types = []
173
-
174
- def decorator(cls: type[Aggregator]) -> AggregateFunction:
175
- # validate type parameters
176
- num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
177
- if num_init_params > 0:
178
- if len(init_types) != num_init_params:
179
- raise excs.Error(
180
- f'init_types must be a list of {num_init_params} types, one for each parameter of __init__()')
181
- num_update_params = len(inspect.signature(cls.update).parameters) - 1
182
- if num_update_params == 0:
183
- raise excs.Error('update() must have at least one parameter')
184
- if len(update_types) != num_update_params:
185
- raise excs.Error(
186
- f'update_types must be a list of {num_update_params} types, one for each parameter of update()')
187
- assert value_type is not None
188
-
189
- # the AggregateFunction instance resides in the same module as cls
190
- class_path = f'{cls.__module__}.{cls.__qualname__}'
191
- # nonlocal name
192
- # name = name or cls.__name__
193
- # instance_path_elements = class_path.split('.')[:-1] + [name]
194
- # instance_path = '.'.join(instance_path_elements)
195
-
196
- # create the corresponding AggregateFunction instance
197
- instance = AggregateFunction(
198
- cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
199
- # do the path validation at the very end, in order to be able to write tests for the other failure cases
200
- validate_symbol_path(class_path)
201
- #module = importlib.import_module(cls.__module__)
202
- #setattr(module, name, instance)
203
-
204
- return instance
244
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
245
+
246
+ # Decorator invoked without parentheses: @pxt.uda
247
+ # Simply call make_aggregator with defaults.
248
+ return make_aggregator(cls=args[0])
249
+
250
+ else:
251
+
252
+ # Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
253
+ # Create a decorator for the specified schema.
254
+ requires_order_by = kwargs.pop('requires_order_by', False)
255
+ allows_std_agg = kwargs.pop('allows_std_agg', True)
256
+ allows_window = kwargs.pop('allows_window', False)
257
+ type_substitutions = kwargs.pop('type_substitutions', None)
258
+ if len(kwargs) > 0:
259
+ raise excs.Error(f'Invalid @uda decorator kwargs: {", ".join(kwargs.keys())}')
260
+ if len(args) > 0:
261
+ raise excs.Error('Unexpected @uda decorator arguments.')
262
+
263
+ def decorator(cls: type[Aggregator]) -> AggregateFunction:
264
+ return make_aggregator(
265
+ cls,
266
+ requires_order_by=requires_order_by,
267
+ allows_std_agg=allows_std_agg,
268
+ allows_window=allows_window,
269
+ type_substitutions=type_substitutions
270
+ )
205
271
 
206
- return decorator
272
+ return decorator
273
+
274
+
275
+ def make_aggregator(
276
+ cls: type[Aggregator],
277
+ requires_order_by: bool = False,
278
+ allows_std_agg: bool = True,
279
+ allows_window: bool = False,
280
+ type_substitutions: Optional[Sequence[dict]] = None
281
+ ) -> AggregateFunction:
282
+ class_path = f'{cls.__module__}.{cls.__qualname__}'
283
+ instance = AggregateFunction(cls, type_substitutions, class_path, requires_order_by, allows_std_agg, allows_window)
284
+ # do the path validation at the very end, in order to be able to write tests for the other failure cases
285
+ validate_symbol_path(class_path)
286
+ return instance