pixeltable 0.2.24__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. See the package registry's release page for more details.

Files changed (101):
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +531 -251
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/catalog/view.py +8 -7
  12. pixeltable/dataframe.py +439 -105
  13. pixeltable/env.py +19 -5
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/exec_node.py +6 -7
  16. pixeltable/exec/expr_eval_node.py +1 -1
  17. pixeltable/exec/sql_node.py +92 -45
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/arithmetic_expr.py +1 -1
  20. pixeltable/exprs/array_slice.py +1 -1
  21. pixeltable/exprs/column_property_ref.py +1 -1
  22. pixeltable/exprs/column_ref.py +29 -2
  23. pixeltable/exprs/comparison.py +1 -1
  24. pixeltable/exprs/compound_predicate.py +1 -1
  25. pixeltable/exprs/expr.py +12 -5
  26. pixeltable/exprs/expr_set.py +8 -0
  27. pixeltable/exprs/function_call.py +147 -39
  28. pixeltable/exprs/in_predicate.py +1 -1
  29. pixeltable/exprs/inline_expr.py +25 -5
  30. pixeltable/exprs/is_null.py +1 -1
  31. pixeltable/exprs/json_mapper.py +1 -1
  32. pixeltable/exprs/json_path.py +1 -1
  33. pixeltable/exprs/method_ref.py +1 -1
  34. pixeltable/exprs/row_builder.py +1 -1
  35. pixeltable/exprs/rowid_ref.py +1 -1
  36. pixeltable/exprs/similarity_expr.py +17 -7
  37. pixeltable/exprs/sql_element_cache.py +4 -0
  38. pixeltable/exprs/type_cast.py +2 -2
  39. pixeltable/exprs/variable.py +3 -0
  40. pixeltable/func/__init__.py +5 -4
  41. pixeltable/func/aggregate_function.py +151 -68
  42. pixeltable/func/callable_function.py +48 -16
  43. pixeltable/func/expr_template_function.py +64 -23
  44. pixeltable/func/function.py +227 -23
  45. pixeltable/func/function_registry.py +2 -1
  46. pixeltable/func/query_template_function.py +51 -9
  47. pixeltable/func/signature.py +65 -7
  48. pixeltable/func/tools.py +153 -0
  49. pixeltable/func/udf.py +57 -35
  50. pixeltable/functions/__init__.py +2 -2
  51. pixeltable/functions/anthropic.py +51 -4
  52. pixeltable/functions/gemini.py +85 -0
  53. pixeltable/functions/globals.py +54 -34
  54. pixeltable/functions/huggingface.py +10 -28
  55. pixeltable/functions/json.py +3 -8
  56. pixeltable/functions/math.py +67 -0
  57. pixeltable/functions/mistralai.py +0 -2
  58. pixeltable/functions/ollama.py +8 -8
  59. pixeltable/functions/openai.py +51 -4
  60. pixeltable/functions/timestamp.py +1 -1
  61. pixeltable/functions/video.py +3 -9
  62. pixeltable/functions/vision.py +1 -1
  63. pixeltable/globals.py +374 -89
  64. pixeltable/index/embedding_index.py +106 -29
  65. pixeltable/io/__init__.py +1 -1
  66. pixeltable/io/label_studio.py +1 -1
  67. pixeltable/io/parquet.py +39 -19
  68. pixeltable/iterators/__init__.py +1 -0
  69. pixeltable/iterators/document.py +12 -0
  70. pixeltable/iterators/image.py +100 -0
  71. pixeltable/iterators/video.py +7 -8
  72. pixeltable/metadata/__init__.py +1 -1
  73. pixeltable/metadata/converters/convert_16.py +2 -1
  74. pixeltable/metadata/converters/convert_17.py +2 -1
  75. pixeltable/metadata/converters/convert_22.py +17 -0
  76. pixeltable/metadata/converters/convert_23.py +35 -0
  77. pixeltable/metadata/converters/convert_24.py +56 -0
  78. pixeltable/metadata/converters/convert_25.py +19 -0
  79. pixeltable/metadata/converters/util.py +4 -2
  80. pixeltable/metadata/notes.py +4 -0
  81. pixeltable/metadata/schema.py +1 -0
  82. pixeltable/plan.py +129 -51
  83. pixeltable/store.py +1 -1
  84. pixeltable/type_system.py +196 -54
  85. pixeltable/utils/arrow.py +8 -3
  86. pixeltable/utils/description_helper.py +89 -0
  87. pixeltable/utils/documents.py +14 -0
  88. {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/METADATA +32 -22
  89. pixeltable-0.3.0.dist-info/RECORD +155 -0
  90. {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
  91. pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
  92. pixeltable/tool/create_test_db_dump.py +0 -308
  93. pixeltable/tool/create_test_video.py +0 -81
  94. pixeltable/tool/doc_plugins/griffe.py +0 -50
  95. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  96. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  97. pixeltable/tool/embed_udf.py +0 -9
  98. pixeltable/tool/mypy_plugin.py +0 -55
  99. pixeltable-0.2.24.dist-info/RECORD +0 -153
  100. pixeltable-0.2.24.dist-info/entry_points.txt +0 -3
  101. {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import inspect
5
- from typing import TYPE_CHECKING, Any, Callable, Optional
5
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, overload
6
6
 
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
@@ -35,26 +35,73 @@ class AggregateFunction(Function):
35
35
  GROUP_BY_PARAM = 'group_by'
36
36
  RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}
37
37
 
38
+ agg_classes: list[type[Aggregator]] # classes for each signature, in signature order
39
+ init_param_names: list[list[str]] # names of the __init__ parameters for each signature
40
+
38
41
  def __init__(
39
- self, aggregator_class: type[Aggregator], self_path: str,
40
- init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
41
- requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
42
- self.agg_cls = aggregator_class
42
+ self,
43
+ agg_class: type[Aggregator],
44
+ type_substitutions: Optional[Sequence[dict]],
45
+ self_path: str,
46
+ requires_order_by: bool,
47
+ allows_std_agg: bool,
48
+ allows_window: bool
49
+ ) -> None:
50
+ if type_substitutions is None:
51
+ type_substitutions = [None] # single signature with no substitutions
52
+ self.agg_classes = [agg_class]
53
+ else:
54
+ self.agg_classes = [agg_class] * len(type_substitutions)
55
+ self.init_param_names = []
43
56
  self.requires_order_by = requires_order_by
44
57
  self.allows_std_agg = allows_std_agg
45
58
  self.allows_window = allows_window
46
- self.__doc__ = aggregator_class.__doc__
59
+ self.__doc__ = agg_class.__doc__
60
+
61
+ signatures: list[Signature] = []
62
+
63
+ # If no type_substitutions were provided, construct a single signature for the class.
64
+ # Otherwise, construct one signature for each type substitution instance.
65
+ for subst in type_substitutions:
66
+ signature, init_param_names = self.__cls_to_signature(agg_class, subst)
67
+ signatures.append(signature)
68
+ self.init_param_names.append(init_param_names)
69
+
70
+ super().__init__(signatures, self_path=self_path)
71
+
72
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
73
+ self.agg_classes = [self.agg_classes[signature_idx]]
74
+ self.init_param_names = [self.init_param_names[signature_idx]]
75
+
76
+ def __cls_to_signature(
77
+ self, cls: type[Aggregator], type_substitutions: Optional[dict] = None
78
+ ) -> tuple[Signature, list[str]]:
79
+ """Inspects the Aggregator class to infer the corresponding function signature. Returns the
80
+ inferred signature along with the list of init_param_names (for downstream error handling).
81
+ """
82
+ # infer type parameters; set return_type=InvalidType() because it has no meaning here
83
+ init_sig = Signature.create(py_fn=cls.__init__, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions)
84
+ update_sig = Signature.create(py_fn=cls.update, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions)
85
+ value_sig = Signature.create(py_fn=cls.value, is_cls_method=True, type_substitutions=type_substitutions)
86
+
87
+ init_types = [p.col_type for p in init_sig.parameters.values()]
88
+ update_types = [p.col_type for p in update_sig.parameters.values()]
89
+ value_type = value_sig.return_type
90
+ assert value_type is not None
91
+
92
+ if len(update_types) == 0:
93
+ raise excs.Error('update() must have at least one parameter')
47
94
 
48
95
  # our signature is the signature of 'update', but without self,
49
96
  # plus the parameters of 'init' as keyword-only parameters
50
- py_update_params = list(inspect.signature(self.agg_cls.update).parameters.values())[1:] # leave out self
97
+ py_update_params = list(inspect.signature(cls.update).parameters.values())[1:] # leave out self
51
98
  assert len(py_update_params) == len(update_types)
52
99
  update_params = [
53
100
  Parameter(p.name, col_type=update_types[i], kind=p.kind, default=p.default)
54
101
  for i, p in enumerate(py_update_params)
55
102
  ]
56
103
  # starting at 1: leave out self
57
- py_init_params = list(inspect.signature(self.agg_cls.__init__).parameters.values())[1:]
104
+ py_init_params = list(inspect.signature(cls.__init__).parameters.values())[1:]
58
105
  assert len(py_init_params) == len(init_types)
59
106
  init_params = [
60
107
  Parameter(p.name, col_type=init_types[i], kind=inspect.Parameter.KEYWORD_ONLY, default=p.default)
@@ -67,23 +114,39 @@ class AggregateFunction(Function):
67
114
  f'{", ".join(duplicate_params)}'
68
115
  )
69
116
  params = update_params + init_params # init_params are keyword-only and come last
117
+ init_param_names = [p.name for p in init_params]
70
118
 
71
- signature = Signature(value_type, params)
72
- super().__init__(signature, self_path=self_path)
73
- self.init_param_names = [p.name for p in init_params]
119
+ return Signature(value_type, params), init_param_names
74
120
 
75
- # make sure the signature doesn't contain reserved parameter names;
76
- # do this after super().__init__(), otherwise self.name is invalid
77
- for param in signature.parameters:
78
- if param.lower() in self.RESERVED_PARAMS:
79
- raise excs.Error(f'{self.name}(): parameter name {param} is reserved')
121
+ @property
122
+ def agg_class(self) -> type[Aggregator]:
123
+ assert not self.is_polymorphic
124
+ return self.agg_classes[0]
80
125
 
81
- def exec(self, *args: Any, **kwargs: Any) -> Any:
126
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
82
127
  raise NotImplementedError
83
128
 
129
+ def overload(self, cls: type[Aggregator]) -> AggregateFunction:
130
+ if not isinstance(cls, type) or not issubclass(cls, Aggregator):
131
+ raise excs.Error(f'Invalid argument to @overload decorator: {cls}')
132
+ if self._has_resolved_fns:
133
+ raise excs.Error('New `overload` not allowed after the UDF has already been called')
134
+ if self._conditional_return_type is not None:
135
+ raise excs.Error('New `overload` not allowed after a conditional return type has been specified')
136
+ sig, init_param_names = self.__cls_to_signature(cls)
137
+ self.signatures.append(sig)
138
+ self.agg_classes.append(cls)
139
+ self.init_param_names.append(init_param_names)
140
+ return self
141
+
142
+ def _docstring(self) -> Optional[str]:
143
+ return inspect.getdoc(self.agg_classes[0])
144
+
84
145
  def help_str(self) -> str:
85
146
  res = super().help_str()
86
- res += '\n\n' + inspect.getdoc(self.agg_cls.update)
147
+ # We need to reference agg_classes[0] rather than agg_class here, because we want this to work even if the
148
+ # aggregator is polymorphic (in which case we use the docstring of the originally decorated UDA).
149
+ res += '\n\n' + inspect.getdoc(self.agg_classes[0].update)
87
150
  return res
88
151
 
89
152
  def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
@@ -121,18 +184,24 @@ class AggregateFunction(Function):
121
184
  f'{self.display_name}(): group_by invalid with an aggregate function that does not allow windows')
122
185
  group_by_clause = kwargs.pop(self.GROUP_BY_PARAM)
123
186
 
124
- bound_args = self.signature.py_signature.bind(*args, **kwargs)
125
- self.validate_call(bound_args.arguments)
187
+ resolved_fn, bound_args = self._bind_to_matching_signature(args, kwargs)
188
+ return_type = resolved_fn.call_return_type(args, kwargs)
126
189
  return exprs.FunctionCall(
127
- self, bound_args.arguments,
190
+ resolved_fn,
191
+ bound_args,
192
+ return_type,
128
193
  order_by_clause=[order_by_clause] if order_by_clause is not None else [],
129
- group_by_clause=[group_by_clause] if group_by_clause is not None else [])
194
+ group_by_clause=[group_by_clause] if group_by_clause is not None else []
195
+ )
130
196
 
131
197
  def validate_call(self, bound_args: dict[str, Any]) -> None:
132
198
  # check that init parameters are not Exprs
133
199
  # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
134
- import pixeltable.exprs as exprs
135
- for param_name in self.init_param_names:
200
+ from pixeltable import exprs
201
+
202
+ assert not self.is_polymorphic
203
+
204
+ for param_name in self.init_param_names[0]:
136
205
  if param_name in bound_args and isinstance(bound_args[param_name], exprs.Expr):
137
206
  raise excs.Error(
138
207
  f'{self.display_name}(): init() parameter {param_name} needs to be a constant, not a Pixeltable '
@@ -143,13 +212,23 @@ class AggregateFunction(Function):
143
212
  return f'<Pixeltable Aggregator {self.name}>'
144
213
 
145
214
 
215
+ # Decorator invoked without parentheses: @pxt.uda
216
+ @overload
217
+ def uda(decorated_fn: Callable) -> AggregateFunction: ...
218
+
219
+
220
+ # Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
221
+ @overload
146
222
  def uda(
147
- *,
148
- value_type: ts.ColumnType,
149
- update_types: list[ts.ColumnType],
150
- init_types: Optional[list[ts.ColumnType]] = None,
151
- requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
152
- ) -> Callable[[type[Aggregator]], AggregateFunction]:
223
+ *,
224
+ requires_order_by: bool = False,
225
+ allows_std_agg: bool = True,
226
+ allows_window: bool = False,
227
+ type_substitutions: Optional[Sequence[dict]] = None
228
+ ) -> Callable[[type[Aggregator]], AggregateFunction]: ...
229
+
230
+
231
+ def uda(*args, **kwargs):
153
232
  """Decorator for user-defined aggregate functions.
154
233
 
155
234
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -161,46 +240,50 @@ def uda(
161
240
  to the module where the class is defined.
162
241
 
163
242
  Parameters:
164
- - init_types: list of types for the __init__() parameters; must match the number of parameters
165
- - update_types: list of types for the update() parameters; must match the number of parameters
166
- - value_type: return type of the aggregator
167
243
  - requires_order_by: if True, the first parameter to the function is the order-by expression
168
244
  - allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
169
245
  - allows_window: if True, the function can be used with a window
170
246
  """
171
- if init_types is None:
172
- init_types = []
173
-
174
- def decorator(cls: type[Aggregator]) -> AggregateFunction:
175
- # validate type parameters
176
- num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
177
- if num_init_params > 0:
178
- if len(init_types) != num_init_params:
179
- raise excs.Error(
180
- f'init_types must be a list of {num_init_params} types, one for each parameter of __init__()')
181
- num_update_params = len(inspect.signature(cls.update).parameters) - 1
182
- if num_update_params == 0:
183
- raise excs.Error('update() must have at least one parameter')
184
- if len(update_types) != num_update_params:
185
- raise excs.Error(
186
- f'update_types must be a list of {num_update_params} types, one for each parameter of update()')
187
- assert value_type is not None
188
-
189
- # the AggregateFunction instance resides in the same module as cls
190
- class_path = f'{cls.__module__}.{cls.__qualname__}'
191
- # nonlocal name
192
- # name = name or cls.__name__
193
- # instance_path_elements = class_path.split('.')[:-1] + [name]
194
- # instance_path = '.'.join(instance_path_elements)
195
-
196
- # create the corresponding AggregateFunction instance
197
- instance = AggregateFunction(
198
- cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
199
- # do the path validation at the very end, in order to be able to write tests for the other failure cases
200
- validate_symbol_path(class_path)
201
- #module = importlib.import_module(cls.__module__)
202
- #setattr(module, name, instance)
203
-
204
- return instance
247
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
248
+
249
+ # Decorator invoked without parentheses: @pxt.uda
250
+ # Simply call make_aggregator with defaults.
251
+ return make_aggregator(cls=args[0])
252
+
253
+ else:
254
+
255
+ # Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
256
+ # Create a decorator for the specified schema.
257
+ requires_order_by = kwargs.pop('requires_order_by', False)
258
+ allows_std_agg = kwargs.pop('allows_std_agg', True)
259
+ allows_window = kwargs.pop('allows_window', False)
260
+ type_substitutions = kwargs.pop('type_substitutions', None)
261
+ if len(kwargs) > 0:
262
+ raise excs.Error(f'Invalid @uda decorator kwargs: {", ".join(kwargs.keys())}')
263
+ if len(args) > 0:
264
+ raise excs.Error('Unexpected @uda decorator arguments.')
265
+
266
+ def decorator(cls: type[Aggregator]) -> AggregateFunction:
267
+ return make_aggregator(
268
+ cls,
269
+ requires_order_by=requires_order_by,
270
+ allows_std_agg=allows_std_agg,
271
+ allows_window=allows_window,
272
+ type_substitutions=type_substitutions
273
+ )
205
274
 
206
- return decorator
275
+ return decorator
276
+
277
+
278
+ def make_aggregator(
279
+ cls: type[Aggregator],
280
+ requires_order_by: bool = False,
281
+ allows_std_agg: bool = True,
282
+ allows_window: bool = False,
283
+ type_substitutions: Optional[Sequence[dict]] = None
284
+ ) -> AggregateFunction:
285
+ class_path = f'{cls.__module__}.{cls.__qualname__}'
286
+ instance = AggregateFunction(cls, type_substitutions, class_path, requires_order_by, allows_std_agg, allows_window)
287
+ # do the path validation at the very end, in order to be able to write tests for the other failure cases
288
+ validate_symbol_path(class_path)
289
+ return instance
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import Any, Callable, Optional
4
+ from typing import Any, Callable, Optional, Sequence
5
5
  from uuid import UUID
6
6
 
7
7
  import cloudpickle # type: ignore[import-untyped]
8
8
 
9
+ import pixeltable.exceptions as excs
10
+
9
11
  from .function import Function
10
12
  from .signature import Signature
11
13
 
@@ -20,26 +22,42 @@ class CallableFunction(Function):
20
22
 
21
23
  def __init__(
22
24
  self,
23
- signature: Signature,
24
- py_fn: Callable,
25
+ signatures: list[Signature],
26
+ py_fns: list[Callable],
25
27
  self_path: Optional[str] = None,
26
28
  self_name: Optional[str] = None,
27
29
  batch_size: Optional[int] = None,
28
30
  is_method: bool = False,
29
31
  is_property: bool = False
30
32
  ):
31
- assert py_fn is not None
32
- self.py_fn = py_fn
33
+ assert len(signatures) > 0
34
+ assert len(signatures) == len(py_fns)
35
+ if self_path is None and len(signatures) > 1:
36
+ raise excs.Error('Multiple signatures are only allowed for module UDFs (not locally defined UDFs)')
37
+ self.py_fns = py_fns
33
38
  self.self_name = self_name
34
39
  self.batch_size = batch_size
35
- self.__doc__ = py_fn.__doc__
36
- super().__init__(signature, self_path=self_path, is_method=is_method, is_property=is_property)
40
+ self.__doc__ = self.py_fns[0].__doc__
41
+ super().__init__(signatures, self_path=self_path, is_method=is_method, is_property=is_property)
42
+
43
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
44
+ assert len(self.py_fns) > signature_idx
45
+ self.py_fns = [self.py_fns[signature_idx]]
37
46
 
38
47
  @property
39
48
  def is_batched(self) -> bool:
40
49
  return self.batch_size is not None
41
50
 
42
- def exec(self, *args: Any, **kwargs: Any) -> Any:
51
+ def _docstring(self) -> Optional[str]:
52
+ return inspect.getdoc(self.py_fns[0])
53
+
54
+ @property
55
+ def py_fn(self) -> Callable:
56
+ assert not self.is_polymorphic
57
+ return self.py_fns[0]
58
+
59
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
60
+ assert not self.is_polymorphic
43
61
  if self.is_batched:
44
62
  # Pack the batched parameters into singleton lists
45
63
  constant_param_names = [p.name for p in self.signature.constant_parameters]
@@ -52,13 +70,14 @@ class CallableFunction(Function):
52
70
  else:
53
71
  return self.py_fn(*args, **kwargs)
54
72
 
55
- def exec_batch(self, *args: Any, **kwargs: Any) -> list:
73
+ def exec_batch(self, args: list[Any], kwargs: dict[str, Any]) -> list:
56
74
  """Execute the function with the given arguments and return the result.
57
75
  The arguments are expected to be batched: if the corresponding parameter has type T,
58
76
  then the argument should have type T if it's a constant parameter, or list[T] if it's
59
77
  a batched parameter.
60
78
  """
61
79
  assert self.is_batched
80
+ assert not self.is_polymorphic
62
81
  # Unpack the constant parameters
63
82
  constant_param_names = [p.name for p in self.signature.constant_parameters]
64
83
  constant_kwargs = {k: v[0] for k, v in kwargs.items() if k in constant_param_names}
@@ -77,10 +96,19 @@ class CallableFunction(Function):
77
96
  def name(self) -> str:
78
97
  return self.self_name
79
98
 
80
- def help_str(self) -> str:
81
- res = super().help_str()
82
- res += '\n\n' + inspect.getdoc(self.py_fn)
83
- return res
99
+ def overload(self, fn: Callable) -> CallableFunction:
100
+ if self.self_path is None:
101
+ raise excs.Error('`overload` can only be used with module UDFs (not locally defined UDFs)')
102
+ if self.is_method or self.is_property:
103
+ raise excs.Error('`overload` cannot be used with `is_method` or `is_property`')
104
+ if self._has_resolved_fns:
105
+ raise excs.Error('New `overload` not allowed after the UDF has already been called')
106
+ if self._conditional_return_type is not None:
107
+ raise excs.Error('New `overload` not allowed after a conditional return type has been specified')
108
+ sig = Signature.create(fn)
109
+ self.signatures.append(sig)
110
+ self.py_fns.append(fn)
111
+ return self
84
112
 
85
113
  def _as_dict(self) -> dict:
86
114
  if self.self_path is None:
@@ -99,6 +127,7 @@ class CallableFunction(Function):
99
127
  return super()._from_dict(d)
100
128
 
101
129
  def to_store(self) -> tuple[dict, bytes]:
130
+ assert not self.is_polymorphic # multi-signature UDFs not allowed for stored fns
102
131
  md = {
103
132
  'signature': self.signature.as_dict(),
104
133
  'batch_size': self.batch_size,
@@ -111,12 +140,15 @@ class CallableFunction(Function):
111
140
  assert callable(py_fn)
112
141
  sig = Signature.from_dict(md['signature'])
113
142
  batch_size = md['batch_size']
114
- return CallableFunction(sig, py_fn, self_name=name, batch_size=batch_size)
143
+ return CallableFunction([sig], [py_fn], self_name=name, batch_size=batch_size)
115
144
 
116
145
  def validate_call(self, bound_args: dict[str, Any]) -> None:
117
- import pixeltable.exprs as exprs
146
+ from pixeltable import exprs
147
+
148
+ assert not self.is_polymorphic
118
149
  if self.is_batched:
119
- for param in self.signature.constant_parameters:
150
+ signature = self.signatures[0]
151
+ for param in signature.constant_parameters:
120
152
  if param.name in bound_args and isinstance(bound_args[param.name], exprs.Expr):
121
153
  raise ValueError(
122
154
  f'{self.display_name}(): '
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Any, Optional
2
+ from typing import Any, Optional, Sequence
3
3
 
4
4
  import pixeltable
5
5
  import pixeltable.exceptions as excs
@@ -8,15 +8,22 @@ from .function import Function
8
8
  from .signature import Signature
9
9
 
10
10
 
11
- class ExprTemplateFunction(Function):
12
- """A parameterized expression from which an executable Expr is created with a function call."""
11
+ class ExprTemplate:
12
+ """
13
+ Encapsulates a single signature of an `ExprTemplateFunction` and its associated parameterized expression,
14
+ along with various precomputed metadata. (This is analogous to a `Callable`-`Signature` pair in a
15
+ `CallableFunction`.)
16
+ """
17
+ expr: 'pixeltable.exprs.Expr'
18
+ signature: Signature
19
+ param_exprs: list['pixeltable.exprs.Variable']
20
+
21
+ def __init__(self, expr: 'pixeltable.exprs.Expr', signature: Signature):
22
+ from pixeltable import exprs
13
23
 
14
- def __init__(
15
- self, expr: 'pixeltable.exprs.Expr', signature: Signature, self_path: Optional[str] = None,
16
- name: Optional[str] = None):
17
- import pixeltable.exprs as exprs
18
24
  self.expr = expr
19
- self.self_name = name
25
+ self.signature = signature
26
+
20
27
  self.param_exprs = list(set(expr.subexprs(expr_class=exprs.Variable)))
21
28
  # make sure there are no duplicate names
22
29
  assert len(self.param_exprs) == len(set(p.name for p in self.param_exprs))
@@ -24,7 +31,7 @@ class ExprTemplateFunction(Function):
24
31
 
25
32
  # verify default values
26
33
  self.defaults: dict[str, exprs.Literal] = {} # key: param name, value: default value converted to a Literal
27
- for param in signature.parameters.values():
34
+ for param in self.signature.parameters.values():
28
35
  if param.default is inspect.Parameter.empty:
29
36
  continue
30
37
  param_expr = self.param_exprs_by_name[param.name]
@@ -35,18 +42,39 @@ class ExprTemplateFunction(Function):
35
42
  msg = str(e)
36
43
  raise excs.Error(f"Default value for parameter '{param.name}': {msg[0].lower() + msg[1:]}")
37
44
 
38
- super().__init__(signature, self_path=self_path)
39
45
 
40
- def instantiate(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.Expr':
46
+ class ExprTemplateFunction(Function):
47
+ """A parameterized expression from which an executable Expr is created with a function call."""
48
+ templates: list[ExprTemplate]
49
+ self_name: str
50
+
51
+ def __init__(self, templates: list[ExprTemplate], self_path: Optional[str] = None, name: Optional[str] = None):
52
+ self.templates = templates
53
+ self.self_name = name
54
+
55
+ super().__init__([t.signature for t in templates], self_path=self_path)
56
+
57
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
58
+ self.templates = [self.templates[signature_idx]]
59
+
60
+ @property
61
+ def template(self) -> ExprTemplate:
62
+ assert not self.is_polymorphic
63
+ return self.templates[0]
64
+
65
+ def instantiate(self, args: Sequence[Any], kwargs: dict[str, Any]) -> 'pixeltable.exprs.Expr':
66
+ from pixeltable import exprs
67
+
68
+ assert not self.is_polymorphic
69
+ template = self.template
41
70
  bound_args = self.signature.py_signature.bind(*args, **kwargs).arguments
42
71
  # apply defaults, otherwise we might have Parameters left over
43
72
  bound_args.update(
44
- {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
45
- result = self.expr.copy()
46
- import pixeltable.exprs as exprs
73
+ {param_name: default for param_name, default in template.defaults.items() if param_name not in bound_args})
74
+ result = template.expr.copy()
47
75
  arg_exprs: dict[exprs.Expr, exprs.Expr] = {}
48
76
  for param_name, arg in bound_args.items():
49
- param_expr = self.param_exprs_by_name[param_name]
77
+ param_expr = template.param_exprs_by_name[param_name]
50
78
  if not isinstance(arg, exprs.Expr):
51
79
  # TODO: use the available param_expr.col_type
52
80
  arg_expr = exprs.Expr.from_object(arg)
@@ -56,15 +84,22 @@ class ExprTemplateFunction(Function):
56
84
  arg_expr = arg
57
85
  arg_exprs[param_expr] = arg_expr
58
86
  result = result.substitute(arg_exprs)
59
- import pixeltable.exprs as exprs
60
87
  assert not result._contains(exprs.Variable)
61
88
  return result
62
89
 
63
- def exec(self, *args: Any, **kwargs: Any) -> Any:
64
- expr = self.instantiate(*args, **kwargs)
65
- import pixeltable.exprs as exprs
90
+ def _docstring(self) -> Optional[str]:
91
+ from pixeltable import exprs
92
+
93
+ if isinstance(self.templates[0].expr, exprs.FunctionCall):
94
+ return self.templates[0].expr.fn._docstring()
95
+ return None
96
+
97
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
98
+ from pixeltable import exec, exprs
99
+
100
+ assert not self.is_polymorphic
101
+ expr = self.instantiate(args, kwargs)
66
102
  row_builder = exprs.RowBuilder(output_exprs=[expr], columns=[], input_exprs=[])
67
- import pixeltable.exec as exec
68
103
  row_batch = exec.DataRowBatch(tbl=None, row_builder=row_builder, len=1)
69
104
  row = row_batch[0]
70
105
  row_builder.eval(row, ctx=row_builder.default_eval_ctx)
@@ -78,11 +113,16 @@ class ExprTemplateFunction(Function):
78
113
  def name(self) -> str:
79
114
  return self.self_name
80
115
 
116
+ def __str__(self) -> str:
117
+ return str(self.templates[0].expr)
118
+
81
119
  def _as_dict(self) -> dict:
82
120
  if self.self_path is not None:
83
121
  return super()._as_dict()
122
+ assert not self.is_polymorphic
123
+ assert len(self.templates) == 1
84
124
  return {
85
- 'expr': self.expr.as_dict(),
125
+ 'expr': self.template.expr.as_dict(),
86
126
  'signature': self.signature.as_dict(),
87
127
  'name': self.name,
88
128
  }
@@ -90,7 +130,8 @@ class ExprTemplateFunction(Function):
90
130
  @classmethod
91
131
  def _from_dict(cls, d: dict) -> Function:
92
132
  if 'expr' not in d:
93
- return super()._from_dict(d)
133
+ return super()._from_dict(d)
94
134
  assert 'signature' in d and 'name' in d
95
135
  import pixeltable.exprs as exprs
96
- return cls(exprs.Expr.from_dict(d['expr']), Signature.from_dict(d['signature']), name=d['name'])
136
+ template = ExprTemplate(exprs.Expr.from_dict(d['expr']), Signature.from_dict(d['signature']))
137
+ return cls([template], name=d['name'])