pixeltable 0.2.25__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (97)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +421 -231
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/catalog/view.py +5 -7
  12. pixeltable/dataframe.py +439 -105
  13. pixeltable/env.py +19 -5
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/exec_node.py +6 -7
  16. pixeltable/exec/expr_eval_node.py +1 -1
  17. pixeltable/exec/sql_node.py +92 -45
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/arithmetic_expr.py +1 -1
  20. pixeltable/exprs/array_slice.py +1 -1
  21. pixeltable/exprs/column_property_ref.py +1 -1
  22. pixeltable/exprs/column_ref.py +29 -2
  23. pixeltable/exprs/comparison.py +1 -1
  24. pixeltable/exprs/compound_predicate.py +1 -1
  25. pixeltable/exprs/expr.py +12 -5
  26. pixeltable/exprs/expr_set.py +8 -0
  27. pixeltable/exprs/function_call.py +147 -39
  28. pixeltable/exprs/in_predicate.py +1 -1
  29. pixeltable/exprs/inline_expr.py +25 -5
  30. pixeltable/exprs/is_null.py +1 -1
  31. pixeltable/exprs/json_mapper.py +1 -1
  32. pixeltable/exprs/json_path.py +1 -1
  33. pixeltable/exprs/method_ref.py +1 -1
  34. pixeltable/exprs/row_builder.py +1 -1
  35. pixeltable/exprs/rowid_ref.py +1 -1
  36. pixeltable/exprs/similarity_expr.py +14 -7
  37. pixeltable/exprs/sql_element_cache.py +4 -0
  38. pixeltable/exprs/type_cast.py +2 -2
  39. pixeltable/exprs/variable.py +3 -0
  40. pixeltable/func/__init__.py +5 -4
  41. pixeltable/func/aggregate_function.py +151 -68
  42. pixeltable/func/callable_function.py +48 -16
  43. pixeltable/func/expr_template_function.py +64 -23
  44. pixeltable/func/function.py +195 -27
  45. pixeltable/func/function_registry.py +2 -1
  46. pixeltable/func/query_template_function.py +51 -9
  47. pixeltable/func/signature.py +64 -7
  48. pixeltable/func/tools.py +153 -0
  49. pixeltable/func/udf.py +57 -35
  50. pixeltable/functions/__init__.py +2 -2
  51. pixeltable/functions/anthropic.py +51 -4
  52. pixeltable/functions/gemini.py +85 -0
  53. pixeltable/functions/globals.py +54 -34
  54. pixeltable/functions/huggingface.py +10 -28
  55. pixeltable/functions/json.py +3 -8
  56. pixeltable/functions/math.py +67 -0
  57. pixeltable/functions/ollama.py +8 -8
  58. pixeltable/functions/openai.py +51 -4
  59. pixeltable/functions/timestamp.py +1 -1
  60. pixeltable/functions/video.py +3 -9
  61. pixeltable/functions/vision.py +1 -1
  62. pixeltable/globals.py +354 -80
  63. pixeltable/index/embedding_index.py +106 -34
  64. pixeltable/io/__init__.py +1 -1
  65. pixeltable/io/label_studio.py +1 -1
  66. pixeltable/io/parquet.py +39 -19
  67. pixeltable/iterators/document.py +12 -0
  68. pixeltable/metadata/__init__.py +1 -1
  69. pixeltable/metadata/converters/convert_16.py +2 -1
  70. pixeltable/metadata/converters/convert_17.py +2 -1
  71. pixeltable/metadata/converters/convert_22.py +17 -0
  72. pixeltable/metadata/converters/convert_23.py +35 -0
  73. pixeltable/metadata/converters/convert_24.py +56 -0
  74. pixeltable/metadata/converters/convert_25.py +19 -0
  75. pixeltable/metadata/converters/util.py +4 -2
  76. pixeltable/metadata/notes.py +4 -0
  77. pixeltable/metadata/schema.py +1 -0
  78. pixeltable/plan.py +128 -50
  79. pixeltable/store.py +1 -1
  80. pixeltable/type_system.py +196 -54
  81. pixeltable/utils/arrow.py +8 -3
  82. pixeltable/utils/description_helper.py +89 -0
  83. pixeltable/utils/documents.py +14 -0
  84. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/METADATA +30 -20
  85. pixeltable-0.3.0.dist-info/RECORD +155 -0
  86. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
  87. pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
  88. pixeltable/tool/create_test_db_dump.py +0 -311
  89. pixeltable/tool/create_test_video.py +0 -81
  90. pixeltable/tool/doc_plugins/griffe.py +0 -50
  91. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  92. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  93. pixeltable/tool/embed_udf.py +0 -9
  94. pixeltable/tool/mypy_plugin.py +0 -55
  95. pixeltable-0.2.25.dist-info/RECORD +0 -154
  96. pixeltable-0.2.25.dist-info/entry_points.txt +0 -3
  97. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- from typing import TYPE_CHECKING, Any, Callable, Optional
6
+ from copy import copy
7
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, cast
7
8
 
8
9
  import sqlalchemy as sql
10
+ from typing_extensions import Self
9
11
 
10
12
  import pixeltable as pxt
11
13
  import pixeltable.exceptions as excs
@@ -15,7 +17,7 @@ from .globals import resolve_symbol
15
17
  from .signature import Signature
16
18
 
17
19
  if TYPE_CHECKING:
18
- from .expr_template_function import ExprTemplateFunction
20
+ from .expr_template_function import ExprTemplate, ExprTemplateFunction
19
21
 
20
22
 
21
23
  class Function(abc.ABC):
@@ -26,30 +28,42 @@ class Function(abc.ABC):
26
28
  via the member self_path.
27
29
  """
28
30
 
29
- signature: Signature
31
+ signatures: list[Signature]
30
32
  self_path: Optional[str]
31
33
  is_method: bool
32
34
  is_property: bool
33
35
  _conditional_return_type: Optional[Callable[..., ts.ColumnType]]
34
36
 
37
+ # We cache the overload resolutions in self.__resolved_fns. This ensures that each resolution is represented
38
+ # globally by a single Python object. We do this dynamically rather than pre-constructing them in order to
39
+ # avoid circular complexity in the `Function` initialization logic.
40
+ __resolved_fns: list[Self]
41
+
35
42
  # Translates a call to this function with the given arguments to its SQLAlchemy equivalent.
36
43
  # Overridden for specific Function instances via the to_sql() decorator. The override must accept the same
37
44
  # parameter names as the original function. Each parameter is going to be of type sql.ColumnElement.
38
45
  _to_sql: Callable[..., Optional[sql.ColumnElement]]
39
46
 
40
-
41
47
  def __init__(
42
- self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False
48
+ self,
49
+ signatures: list[Signature],
50
+ self_path: Optional[str] = None,
51
+ is_method: bool = False,
52
+ is_property: bool = False
43
53
  ):
44
54
  # Check that stored functions cannot be declared using `is_method` or `is_property`:
45
55
  assert not ((is_method or is_property) and self_path is None)
46
- self.signature = signature
56
+ assert isinstance(signatures, list)
57
+ assert len(signatures) > 0
58
+ self.signatures = signatures
47
59
  self.self_path = self_path # fully-qualified path to self
48
60
  self.is_method = is_method
49
61
  self.is_property = is_property
50
62
  self._conditional_return_type = None
51
63
  self._to_sql = self.__default_to_sql
52
64
 
65
+ self.__resolved_fns = []
66
+
53
67
  @property
54
68
  def name(self) -> str:
55
69
  assert self.self_path is not None
@@ -64,49 +78,166 @@ class Function(abc.ABC):
64
78
  return self.self_path[len(ptf_prefix):]
65
79
  return self.self_path
66
80
 
81
+ @property
82
+ def is_polymorphic(self) -> bool:
83
+ return len(self.signatures) > 1
84
+
85
+ @property
86
+ def signature(self) -> Signature:
87
+ assert not self.is_polymorphic
88
+ return self.signatures[0]
89
+
67
90
  @property
68
91
  def arity(self) -> int:
92
+ assert not self.is_polymorphic
69
93
  return len(self.signature.parameters)
70
94
 
95
+ def _docstring(self) -> Optional[str]:
96
+ return None
97
+
71
98
  def help_str(self) -> str:
72
- return self.display_name + str(self.signature)
99
+ docstring = self._docstring()
100
+ display = self.display_name + str(self.signatures[0])
101
+ if docstring is None:
102
+ return display
103
+ return f'{display}\n\n{docstring}'
104
+
105
+ @property
106
+ def _resolved_fns(self) -> list[Self]:
107
+ """
108
+ Return the list of overload resolutions for this `Function`, constructing it first if necessary.
109
+ Each resolution is a new `Function` instance that retains just the single signature at index `signature_idx`,
110
+ and is otherwise identical to this `Function`.
111
+ """
112
+ if len(self.__resolved_fns) == 0:
113
+ # The list of overload resolutions hasn't been constructed yet; do so now.
114
+ if len(self.signatures) == 1:
115
+ # Only one signature: no need to construct separate resolutions
116
+ self.__resolved_fns.append(self)
117
+ else:
118
+ # Multiple signatures: construct a resolution for each signature
119
+ for idx in range(len(self.signatures)):
120
+ resolution = cast(Self, copy(self))
121
+ resolution.signatures = [self.signatures[idx]]
122
+ resolution.__resolved_fns = [resolution] # Resolves to itself
123
+ resolution._update_as_overload_resolution(idx)
124
+ self.__resolved_fns.append(resolution)
125
+
126
+ return self.__resolved_fns
127
+
128
+ @property
129
+ def _has_resolved_fns(self) -> bool:
130
+ """
131
+ Returns true if the resolved_fns for this `Function` have been constructed (i.e., if self._resolved_fns
132
+ has been accessed).
133
+ """
134
+ return len(self.__resolved_fns) > 0
135
+
136
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
137
+ """
138
+ Subclasses must implement this in order to do any additional work when creating a resolution, beyond
139
+ simply updating `self.signatures`.
140
+ """
141
+ raise NotImplementedError()
73
142
 
74
143
  def __call__(self, *args: Any, **kwargs: Any) -> 'pxt.exprs.FunctionCall':
75
144
  from pixeltable import exprs
76
- bound_args = self.signature.py_signature.bind(*args, **kwargs)
77
- self.validate_call(bound_args.arguments)
78
- return exprs.FunctionCall(self, bound_args.arguments)
145
+
146
+ resolved_fn, bound_args = self._bind_to_matching_signature(args, kwargs)
147
+ return_type = resolved_fn.call_return_type(args, kwargs)
148
+ return exprs.FunctionCall(resolved_fn, bound_args, return_type)
149
+
150
+ def _bind_to_matching_signature(self, args: Sequence[Any], kwargs: dict[str, Any]) -> tuple[Self, dict[str, Any]]:
151
+ result: int = -1
152
+ bound_args: Optional[dict[str, Any]] = None
153
+ assert len(self.signatures) > 0
154
+ if len(self.signatures) == 1:
155
+ # Only one signature: call _bind_to_signature() and surface any errors directly
156
+ result = 0
157
+ bound_args = self._bind_to_signature(0, args, kwargs)
158
+ else:
159
+ # Multiple signatures: try each signature in declaration order and trap any errors.
160
+ # If none of them succeed, raise a generic error message.
161
+ for i in range(len(self.signatures)):
162
+ try:
163
+ bound_args = self._bind_to_signature(i, args, kwargs)
164
+ except (TypeError, excs.Error):
165
+ continue
166
+ result = i
167
+ break
168
+ if result == -1:
169
+ raise excs.Error(f'Function {self.name!r} has no matching signature for arguments')
170
+ assert result >= 0
171
+ assert bound_args is not None
172
+ return self._resolved_fns[result], bound_args
173
+
174
+ def _bind_to_signature(self, signature_idx: int, args: Sequence[Any], kwargs: dict[str, Any]) -> dict[str, Any]:
175
+ from pixeltable import exprs
176
+
177
+ signature = self.signatures[signature_idx]
178
+ bound_args = signature.py_signature.bind(*args, **kwargs).arguments
179
+ self._resolved_fns[signature_idx].validate_call(bound_args)
180
+ exprs.FunctionCall.normalize_args(self.name, signature, bound_args)
181
+ return bound_args
79
182
 
80
183
  def validate_call(self, bound_args: dict[str, Any]) -> None:
81
184
  """Override this to do custom validation of the arguments"""
82
- pass
185
+ assert not self.is_polymorphic
83
186
 
84
- def call_return_type(self, kwargs: dict[str, Any]) -> ts.ColumnType:
187
+ def call_return_type(self, args: Sequence[Any], kwargs: dict[str, Any]) -> ts.ColumnType:
85
188
  """Return the type of the value returned by calling this function with the given arguments"""
189
+ assert not self.is_polymorphic
86
190
  if self._conditional_return_type is None:
87
191
  return self.signature.return_type
88
- bound_args = self.signature.py_signature.bind(**kwargs)
192
+ bound_args = self.signature.py_signature.bind(*args, **kwargs).arguments
89
193
  kw_args: dict[str, Any] = {}
90
194
  sig = inspect.signature(self._conditional_return_type)
91
195
  for param in sig.parameters.values():
92
- if param.name in bound_args.arguments:
93
- kw_args[param.name] = bound_args.arguments[param.name]
196
+ if param.name in bound_args:
197
+ kw_args[param.name] = bound_args[param.name]
94
198
  return self._conditional_return_type(**kw_args)
95
199
 
96
200
  def conditional_return_type(self, fn: Callable[..., ts.ColumnType]) -> Callable[..., ts.ColumnType]:
97
201
  """Instance decorator for specifying a conditional return type for this function"""
98
202
  # verify that fn (the conditional return type callable) only has parameters that are also present in every signature
99
- sig = inspect.signature(fn)
100
- for param in sig.parameters.values():
101
- if param.name not in self.signature.parameters:
102
- raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in the signature')
203
+ fn_sig = inspect.signature(fn)
204
+ for param in fn_sig.parameters.values():
205
+ for self_sig in self.signatures:
206
+ if param.name not in self_sig.parameters:
207
+ raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in a signature')
103
208
  self._conditional_return_type = fn
104
209
  return fn
105
210
 
106
211
  def using(self, **kwargs: Any) -> 'ExprTemplateFunction':
212
+ from .expr_template_function import ExprTemplateFunction
213
+
214
+ assert len(self.signatures) > 0
215
+ if len(self.signatures) == 1:
216
+ # Only one signature: call _bind_and_create_template() and surface any errors directly
217
+ template = self._bind_and_create_template(kwargs)
218
+ return ExprTemplateFunction([template])
219
+ else:
220
+ # Multiple signatures: iterate over each signature and generate a template for each
221
+ # successful binding. If there are no successful bindings, raise a generic error.
222
+ # (Note that the resulting ExprTemplateFunction may have strictly fewer signatures than
223
+ # this Function, in the event that only some of the signatures are successfully bound.)
224
+ templates: list['ExprTemplate'] = []
225
+ for i in range(len(self.signatures)):
226
+ try:
227
+ template = self._resolved_fns[i]._bind_and_create_template(kwargs)
228
+ templates.append(template)
229
+ except (TypeError, excs.Error):
230
+ continue
231
+ if len(templates) == 0:
232
+ raise excs.Error(f'Function {self.name!r} has no matching signature for arguments')
233
+ return ExprTemplateFunction(templates)
234
+
235
+ def _bind_and_create_template(self, kwargs: dict[str, Any]) -> 'ExprTemplate':
107
236
  from pixeltable import exprs
108
237
 
109
- from .expr_template_function import ExprTemplateFunction
238
+ from .expr_template_function import ExprTemplate
239
+
240
+ assert not self.is_polymorphic
110
241
 
111
242
  # Resolve each kwarg into a parameter binding
112
243
  bindings: dict[str, exprs.Expr] = {}
@@ -127,16 +258,18 @@ class Function(abc.ABC):
127
258
  for param in residual_params:
128
259
  bindings[param.name] = exprs.Variable(param.name, param.col_type)
129
260
 
130
- call = exprs.FunctionCall(self, bindings)
261
+ return_type = self.call_return_type([], bindings)
262
+ call = exprs.FunctionCall(self, bindings, return_type)
131
263
 
132
264
  # Construct the (n-k)-ary signature of the new function. We use `call.col_type` for this, rather than
133
265
  # `self.signature.return_type`, because the return type of the new function may be specialized via a
134
266
  # conditional return type.
135
267
  new_signature = Signature(call.col_type, residual_params, self.signature.is_batched)
136
- return ExprTemplateFunction(call, new_signature)
268
+
269
+ return ExprTemplate(call, new_signature)
137
270
 
138
271
  @abc.abstractmethod
139
- def exec(self, *args: Any, **kwargs: Any) -> Any:
272
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
140
273
  """Execute the function with the given arguments and return the result."""
141
274
  pass
142
275
 
@@ -158,19 +291,24 @@ class Function(abc.ABC):
158
291
  """Print source code"""
159
292
  print('source not available')
160
293
 
161
- def as_dict(self) -> dict:
294
+ def as_dict(self) -> dict[str, Any]:
162
295
  """
163
296
  Return a serialized reference to the instance that can be passed to json.dumps() and converted back
164
297
  to an instance with from_dict().
165
298
  Subclasses can override _as_dict().
166
299
  """
300
+ # We currently only ever serialize a function that has a specific signature (not a polymorphic form).
301
+ assert not self.is_polymorphic
167
302
  classpath = f'{self.__class__.__module__}.{self.__class__.__qualname__}'
168
303
  return {'_classpath': classpath, **self._as_dict()}
169
304
 
170
305
  def _as_dict(self) -> dict:
171
- """Default serialization: store the path to self (which includes the module path)"""
306
+ """Default serialization: store the path to self (which includes the module path) and signature."""
172
307
  assert self.self_path is not None
173
- return {'path': self.self_path}
308
+ return {
309
+ 'path': self.self_path,
310
+ 'signature': self.signature.as_dict(),
311
+ }
174
312
 
175
313
  @classmethod
176
314
  def from_dict(cls, d: dict) -> Function:
@@ -181,15 +319,45 @@ class Function(abc.ABC):
181
319
  module_path, class_name = d['_classpath'].rsplit('.', 1)
182
320
  class_module = importlib.import_module(module_path)
183
321
  func_class = getattr(class_module, class_name)
322
+ assert isinstance(func_class, type) and issubclass(func_class, Function)
184
323
  return func_class._from_dict(d)
185
324
 
186
325
  @classmethod
187
326
  def _from_dict(cls, d: dict) -> Function:
188
327
  """Default deserialization: load the symbol indicated by the stored symbol_path"""
189
328
  assert 'path' in d and d['path'] is not None
329
+ assert 'signature' in d and d['signature'] is not None
190
330
  instance = resolve_symbol(d['path'])
191
331
  assert isinstance(instance, Function)
192
- return instance
332
+
333
+ # Load the signature from the DB and check that it is still valid (i.e., is still consistent with a signature
334
+ # in the code).
335
+ signature = Signature.from_dict(d['signature'])
336
+ idx = instance.__find_matching_overload(signature)
337
+ if idx is None:
338
+ # No match; generate an informative error message.
339
+ signature_note_str = 'any of its signatures' if instance.is_polymorphic else 'its signature as'
340
+ instance_signature_str = (
341
+ f'{len(instance.signatures)} signatures' if instance.is_polymorphic else str(instance.signature)
342
+ )
343
+ # TODO: Handle this more gracefully (instead of failing the DB load, allow the DB load to succeed, but
344
+ # mark any enclosing FunctionCall as unusable). It's the same issue as dealing with a renamed UDF or
345
+ # FunctionCall return type mismatch.
346
+ raise excs.Error(
347
+ f'The signature stored in the database for the UDF `{instance.self_path}` no longer matches '
348
+ f'{signature_note_str} as currently defined in the code.\nThis probably means that the code for '
349
+ f'`{instance.self_path}` has changed in a backward-incompatible way.\n'
350
+ f'Signature in database: {signature}\n'
351
+ f'Signature in code: {instance_signature_str}'
352
+ )
353
+ # We found a match; specialize to the appropriate overload resolution (non-polymorphic form) and return that.
354
+ return instance._resolved_fns[idx]
355
+
356
+ def __find_matching_overload(self, sig: Signature) -> Optional[int]:
357
+ for idx, overload_sig in enumerate(self.signatures):
358
+ if sig.is_consistent_with(overload_sig):
359
+ return idx
360
+ return None
193
361
 
194
362
  def to_store(self) -> tuple[dict, bytes]:
195
363
  """
@@ -13,6 +13,7 @@ import pixeltable.env as env
13
13
  import pixeltable.exceptions as excs
14
14
  import pixeltable.type_system as ts
15
15
  from pixeltable.metadata import schema
16
+
16
17
  from .function import Function
17
18
 
18
19
  _logger = logging.getLogger('pixeltable')
@@ -68,7 +69,7 @@ class FunctionRegistry:
68
69
  raise excs.Error(f'A UDF with that name already exists: {fqn}')
69
70
  self.module_fns[fqn] = fn
70
71
  if fn.is_method or fn.is_property:
71
- base_type = fn.signature.parameters_by_pos[0].col_type.type_enum
72
+ base_type = fn.signatures[0].parameters_by_pos[0].col_type.type_enum
72
73
  if base_type not in self.type_methods:
73
74
  self.type_methods[base_type] = {}
74
75
  if fn.name in self.type_methods[base_type]:
@@ -1,23 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import Any, Callable, Optional
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, overload
5
5
 
6
6
  import sqlalchemy as sql
7
7
 
8
- import pixeltable as pxt
8
+ import pixeltable.exceptions as excs
9
+ import pixeltable.type_system as ts
9
10
  from pixeltable import exprs
10
11
 
11
12
  from .function import Function
12
13
  from .signature import Signature
13
14
 
15
+ if TYPE_CHECKING:
16
+ from pixeltable import DataFrame
17
+
14
18
 
15
19
  class QueryTemplateFunction(Function):
16
20
  """A parameterized query/DataFrame from which an executable DataFrame is created with a function call."""
17
21
 
18
22
  @classmethod
19
23
  def create(
20
- cls, template_callable: Callable, param_types: Optional[list[pxt.ColumnType]], path: str, name: str
24
+ cls, template_callable: Callable, param_types: Optional[list[ts.ColumnType]], path: str, name: str
21
25
  ) -> QueryTemplateFunction:
22
26
  # we need to construct a template df and a signature
23
27
  py_sig = inspect.signature(template_callable)
@@ -29,14 +33,15 @@ class QueryTemplateFunction(Function):
29
33
  from pixeltable import DataFrame
30
34
  assert isinstance(template_df, DataFrame)
31
35
  # we take params and return json
32
- sig = Signature(return_type=pxt.JsonType(), parameters=params)
36
+ sig = Signature(return_type=ts.JsonType(), parameters=params)
33
37
  return QueryTemplateFunction(template_df, sig, path=path, name=name)
34
38
 
35
39
  def __init__(
36
- self, template_df: Optional['pxt.DataFrame'], sig: Optional[Signature], path: Optional[str] = None,
40
+ self, template_df: Optional['DataFrame'], sig: Signature, path: Optional[str] = None,
37
41
  name: Optional[str] = None,
38
42
  ):
39
- super().__init__(sig, self_path=path)
43
+ assert sig is not None
44
+ super().__init__([sig], self_path=path)
40
45
  self.self_name = name
41
46
  self.template_df = template_df
42
47
 
@@ -48,16 +53,20 @@ class QueryTemplateFunction(Function):
48
53
  # convert defaults to Literals
49
54
  self.defaults: dict[str, exprs.Literal] = {} # key: param name, value: default value converted to a Literal
50
55
  param_types = self.template_df.parameters()
51
- for param in [p for p in self.signature.parameters.values() if p.has_default()]:
56
+ for param in [p for p in sig.parameters.values() if p.has_default()]:
52
57
  assert param.name in param_types
53
58
  param_type = param_types[param.name]
54
59
  literal_default = exprs.Literal(param.default, col_type=param_type)
55
60
  self.defaults[param.name] = literal_default
56
61
 
62
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
63
+ pass # only one signature supported for QueryTemplateFunction
64
+
57
65
  def set_conn(self, conn: Optional[sql.engine.Connection]) -> None:
58
66
  self.conn = conn
59
67
 
60
- def exec(self, *args: Any, **kwargs: Any) -> Any:
68
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
69
+ assert not self.is_polymorphic
61
70
  bound_args = self.signature.py_signature.bind(*args, **kwargs).arguments
62
71
  # apply defaults, otherwise we might have Parameters left over
63
72
  bound_args.update(
@@ -75,9 +84,42 @@ class QueryTemplateFunction(Function):
75
84
  return self.self_name
76
85
 
77
86
  def _as_dict(self) -> dict:
78
- return {'name': self.name, 'signature': self.signature.as_dict(), 'df': self.template_df.as_dict()}
87
+ return {'name': self.name, 'signature': self.signatures[0].as_dict(), 'df': self.template_df.as_dict()}
79
88
 
80
89
  @classmethod
81
90
  def _from_dict(cls, d: dict) -> Function:
82
91
  from pixeltable.dataframe import DataFrame
83
92
  return cls(DataFrame.from_dict(d['df']), Signature.from_dict(d['signature']), name=d['name'])
93
+
94
+
95
+ @overload
96
+ def query(self, py_fn: Callable) -> QueryTemplateFunction: ...
97
+
98
+ @overload
99
+ def query(
100
+ self, *, param_types: Optional[list[ts.ColumnType]] = None
101
+ ) -> Callable[[Callable], QueryTemplateFunction]: ...
102
+
103
+ def query(*args: Any, **kwargs: Any) -> Any:
104
+ def make_query_template(
105
+ py_fn: Callable, param_types: Optional[list[ts.ColumnType]]
106
+ ) -> QueryTemplateFunction:
107
+ if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
108
+ # this is a named function in a module
109
+ function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
110
+ else:
111
+ function_path = None
112
+ query_name = py_fn.__name__
113
+ query_fn = QueryTemplateFunction.create(
114
+ py_fn, param_types=param_types, path=function_path, name=query_name)
115
+ return query_fn
116
+
117
+ # TODO: verify that the inferred return type matches that of the template
118
+ # TODO: verify that the signature doesn't contain batched parameters
119
+
120
+ if len(args) == 1:
121
+ assert len(kwargs) == 0 and callable(args[0])
122
+ return make_query_template(args[0], None)
123
+ else:
124
+ assert len(args) == 0 and len(kwargs) == 1 and 'param_types' in kwargs
125
+ return lambda py_fn: make_query_template(py_fn, kwargs['param_types'])
@@ -111,6 +111,38 @@ class Signature:
111
111
  parameters = [Parameter.from_dict(param_dict) for param_dict in d['parameters']]
112
112
  return cls(ts.ColumnType.from_dict(d['return_type']), parameters, d['is_batched'])
113
113
 
114
+ def is_consistent_with(self, other: Signature) -> bool:
115
+ """
116
+ Returns True if this signature is consistent with the other signature.
117
+ S is consistent with T if we could safely replace S by T in any call where S is used. Specifically:
118
+ (i) S.return_type is a supertype of T.return_type
119
+ (ii) For each parameter p in S, there is a parameter q in T such that:
120
+ - p and q have the same name and kind
121
+ - q.col_type is a supertype of p.col_type
122
+ (iii) For each *required* parameter q in T, there is a parameter p in S with the same name (in which
123
+ case the kinds and types must also match, by condition (ii)).
124
+ """
125
+ # Check (i)
126
+ if not self.get_return_type().is_supertype_of(other.get_return_type(), ignore_nullable=True):
127
+ return False
128
+
129
+ # Check (ii)
130
+ for param_name, param in self.parameters.items():
131
+ if param_name not in other.parameters:
132
+ return False
133
+ other_param = other.parameters[param_name]
134
+ if (param.kind != other_param.kind or
135
+ (param.col_type is None) != (other_param.col_type is None) or # this can happen if they are varargs
136
+ param.col_type is not None and not other_param.col_type.is_supertype_of(param.col_type, ignore_nullable=True)):
137
+ return False
138
+
139
+ # Check (iii)
140
+ for other_param in other.required_parameters:
141
+ if other_param.name not in self.parameters:
142
+ return False
143
+
144
+ return True
145
+
114
146
  def __eq__(self, other: object) -> bool:
115
147
  if not isinstance(other, Signature):
116
148
  return False
@@ -156,8 +188,12 @@ class Signature:
156
188
 
157
189
  @classmethod
158
190
  def create_parameters(
159
- cls, py_fn: Optional[Callable] = None, py_params: Optional[list[inspect.Parameter]] = None,
160
- param_types: Optional[list[ts.ColumnType]] = None
191
+ cls,
192
+ py_fn: Optional[Callable] = None,
193
+ py_params: Optional[list[inspect.Parameter]] = None,
194
+ param_types: Optional[list[ts.ColumnType]] = None,
195
+ type_substitutions: Optional[dict] = None,
196
+ is_cls_method: bool = False
161
197
  ) -> list[Parameter]:
162
198
  assert (py_fn is None) != (py_params is None)
163
199
  if py_fn is not None:
@@ -165,7 +201,12 @@ class Signature:
165
201
  py_params = list(sig.parameters.values())
166
202
  parameters: list[Parameter] = []
167
203
 
204
+ if type_substitutions is None:
205
+ type_substitutions = {}
206
+
168
207
  for idx, param in enumerate(py_params):
208
+ if is_cls_method and idx == 0:
209
+ continue # skip 'self' or 'cls' parameter
169
210
  if param.name in cls.SPECIAL_PARAM_NAMES:
170
211
  raise excs.Error(f"'{param.name}' is a reserved parameter name")
171
212
  if param.kind == inspect.Parameter.VAR_POSITIONAL or param.kind == inspect.Parameter.VAR_KEYWORD:
@@ -179,7 +220,12 @@ class Signature:
179
220
  param_type = param_types[idx]
180
221
  is_batched = False
181
222
  else:
182
- param_type, is_batched = cls._infer_type(param.annotation)
223
+ py_type: Optional[type]
224
+ if param.annotation in type_substitutions:
225
+ py_type = type_substitutions[param.annotation]
226
+ else:
227
+ py_type = param.annotation
228
+ param_type, is_batched = cls._infer_type(py_type)
183
229
  if param_type is None:
184
230
  raise excs.Error(f'Cannot infer pixeltable type for parameter {param.name}')
185
231
 
@@ -190,18 +236,29 @@ class Signature:
190
236
 
191
237
  @classmethod
192
238
  def create(
193
- cls, py_fn: Callable,
239
+ cls,
240
+ py_fn: Callable,
194
241
  param_types: Optional[list[ts.ColumnType]] = None,
195
- return_type: Optional[ts.ColumnType] = None
242
+ return_type: Optional[ts.ColumnType] = None,
243
+ type_substitutions: Optional[dict] = None,
244
+ is_cls_method: bool = False
196
245
  ) -> Signature:
197
246
  """Create a signature for the given Callable.
198
247
  Infer the parameter and return types, if none are specified.
199
248
  Raises an exception if the types cannot be inferred.
200
249
  """
201
- parameters = cls.create_parameters(py_fn=py_fn, param_types=param_types)
250
+ if type_substitutions is None:
251
+ type_substitutions = {}
252
+
253
+ parameters = cls.create_parameters(py_fn=py_fn, param_types=param_types, is_cls_method=is_cls_method, type_substitutions=type_substitutions)
202
254
  sig = inspect.signature(py_fn)
203
255
  if return_type is None:
204
- return_type, return_is_batched = cls._infer_type(sig.return_annotation)
256
+ py_type: Optional[type]
257
+ if sig.return_annotation in type_substitutions:
258
+ py_type = type_substitutions[sig.return_annotation]
259
+ else:
260
+ py_type = sig.return_annotation
261
+ return_type, return_is_batched = cls._infer_type(py_type)
205
262
  if return_type is None:
206
263
  raise excs.Error('Cannot infer pixeltable return type')
207
264
  else: