pixeltable 0.2.24__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (101) hide show
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +531 -251
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/catalog/view.py +8 -7
  12. pixeltable/dataframe.py +439 -105
  13. pixeltable/env.py +19 -5
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/exec_node.py +6 -7
  16. pixeltable/exec/expr_eval_node.py +1 -1
  17. pixeltable/exec/sql_node.py +92 -45
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/arithmetic_expr.py +1 -1
  20. pixeltable/exprs/array_slice.py +1 -1
  21. pixeltable/exprs/column_property_ref.py +1 -1
  22. pixeltable/exprs/column_ref.py +29 -2
  23. pixeltable/exprs/comparison.py +1 -1
  24. pixeltable/exprs/compound_predicate.py +1 -1
  25. pixeltable/exprs/expr.py +12 -5
  26. pixeltable/exprs/expr_set.py +8 -0
  27. pixeltable/exprs/function_call.py +147 -39
  28. pixeltable/exprs/in_predicate.py +1 -1
  29. pixeltable/exprs/inline_expr.py +25 -5
  30. pixeltable/exprs/is_null.py +1 -1
  31. pixeltable/exprs/json_mapper.py +1 -1
  32. pixeltable/exprs/json_path.py +1 -1
  33. pixeltable/exprs/method_ref.py +1 -1
  34. pixeltable/exprs/row_builder.py +1 -1
  35. pixeltable/exprs/rowid_ref.py +1 -1
  36. pixeltable/exprs/similarity_expr.py +17 -7
  37. pixeltable/exprs/sql_element_cache.py +4 -0
  38. pixeltable/exprs/type_cast.py +2 -2
  39. pixeltable/exprs/variable.py +3 -0
  40. pixeltable/func/__init__.py +5 -4
  41. pixeltable/func/aggregate_function.py +151 -68
  42. pixeltable/func/callable_function.py +48 -16
  43. pixeltable/func/expr_template_function.py +64 -23
  44. pixeltable/func/function.py +227 -23
  45. pixeltable/func/function_registry.py +2 -1
  46. pixeltable/func/query_template_function.py +51 -9
  47. pixeltable/func/signature.py +65 -7
  48. pixeltable/func/tools.py +153 -0
  49. pixeltable/func/udf.py +57 -35
  50. pixeltable/functions/__init__.py +2 -2
  51. pixeltable/functions/anthropic.py +51 -4
  52. pixeltable/functions/gemini.py +85 -0
  53. pixeltable/functions/globals.py +54 -34
  54. pixeltable/functions/huggingface.py +10 -28
  55. pixeltable/functions/json.py +3 -8
  56. pixeltable/functions/math.py +67 -0
  57. pixeltable/functions/mistralai.py +0 -2
  58. pixeltable/functions/ollama.py +8 -8
  59. pixeltable/functions/openai.py +51 -4
  60. pixeltable/functions/timestamp.py +1 -1
  61. pixeltable/functions/video.py +3 -9
  62. pixeltable/functions/vision.py +1 -1
  63. pixeltable/globals.py +374 -89
  64. pixeltable/index/embedding_index.py +106 -29
  65. pixeltable/io/__init__.py +1 -1
  66. pixeltable/io/label_studio.py +1 -1
  67. pixeltable/io/parquet.py +39 -19
  68. pixeltable/iterators/__init__.py +1 -0
  69. pixeltable/iterators/document.py +12 -0
  70. pixeltable/iterators/image.py +100 -0
  71. pixeltable/iterators/video.py +7 -8
  72. pixeltable/metadata/__init__.py +1 -1
  73. pixeltable/metadata/converters/convert_16.py +2 -1
  74. pixeltable/metadata/converters/convert_17.py +2 -1
  75. pixeltable/metadata/converters/convert_22.py +17 -0
  76. pixeltable/metadata/converters/convert_23.py +35 -0
  77. pixeltable/metadata/converters/convert_24.py +56 -0
  78. pixeltable/metadata/converters/convert_25.py +19 -0
  79. pixeltable/metadata/converters/util.py +4 -2
  80. pixeltable/metadata/notes.py +4 -0
  81. pixeltable/metadata/schema.py +1 -0
  82. pixeltable/plan.py +129 -51
  83. pixeltable/store.py +1 -1
  84. pixeltable/type_system.py +196 -54
  85. pixeltable/utils/arrow.py +8 -3
  86. pixeltable/utils/description_helper.py +89 -0
  87. pixeltable/utils/documents.py +14 -0
  88. {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/METADATA +32 -22
  89. pixeltable-0.3.0.dist-info/RECORD +155 -0
  90. {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
  91. pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
  92. pixeltable/tool/create_test_db_dump.py +0 -308
  93. pixeltable/tool/create_test_video.py +0 -81
  94. pixeltable/tool/doc_plugins/griffe.py +0 -50
  95. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  96. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  97. pixeltable/tool/embed_udf.py +0 -9
  98. pixeltable/tool/mypy_plugin.py +0 -55
  99. pixeltable-0.2.24.dist-info/RECORD +0 -153
  100. pixeltable-0.2.24.dist-info/entry_points.txt +0 -3
  101. {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
@@ -3,16 +3,22 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- from typing import Any, Callable, Optional
6
+ from copy import copy
7
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, cast
7
8
 
8
9
  import sqlalchemy as sql
10
+ from typing_extensions import Self
9
11
 
10
12
  import pixeltable as pxt
13
+ import pixeltable.exceptions as excs
11
14
  import pixeltable.type_system as ts
12
15
 
13
16
  from .globals import resolve_symbol
14
17
  from .signature import Signature
15
18
 
19
+ if TYPE_CHECKING:
20
+ from .expr_template_function import ExprTemplate, ExprTemplateFunction
21
+
16
22
 
17
23
  class Function(abc.ABC):
18
24
  """Base class for Pixeltable's function interface.
@@ -22,30 +28,42 @@ class Function(abc.ABC):
22
28
  via the member self_path.
23
29
  """
24
30
 
25
- signature: Signature
31
+ signatures: list[Signature]
26
32
  self_path: Optional[str]
27
33
  is_method: bool
28
34
  is_property: bool
29
35
  _conditional_return_type: Optional[Callable[..., ts.ColumnType]]
30
36
 
37
+ # We cache the overload resolutions in self.__resolved_fns. This ensures that each resolution is represented
38
+ # globally by a single Python object. We do this dynamically rather than pre-constructing them in order to
39
+ # avoid circular complexity in the `Function` initialization logic.
40
+ __resolved_fns: list[Self]
41
+
31
42
  # Translates a call to this function with the given arguments to its SQLAlchemy equivalent.
32
43
  # Overridden for specific Function instances via the to_sql() decorator. The override must accept the same
33
44
  # parameter names as the original function. Each parameter is going to be of type sql.ColumnElement.
34
45
  _to_sql: Callable[..., Optional[sql.ColumnElement]]
35
46
 
36
-
37
47
  def __init__(
38
- self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False
48
+ self,
49
+ signatures: list[Signature],
50
+ self_path: Optional[str] = None,
51
+ is_method: bool = False,
52
+ is_property: bool = False
39
53
  ):
40
54
  # Check that stored functions cannot be declared using `is_method` or `is_property`:
41
55
  assert not ((is_method or is_property) and self_path is None)
42
- self.signature = signature
56
+ assert isinstance(signatures, list)
57
+ assert len(signatures) > 0
58
+ self.signatures = signatures
43
59
  self.self_path = self_path # fully-qualified path to self
44
60
  self.is_method = is_method
45
61
  self.is_property = is_property
46
62
  self._conditional_return_type = None
47
63
  self._to_sql = self.__default_to_sql
48
64
 
65
+ self.__resolved_fns = []
66
+
49
67
  @property
50
68
  def name(self) -> str:
51
69
  assert self.self_path is not None
@@ -60,47 +78,198 @@ class Function(abc.ABC):
60
78
  return self.self_path[len(ptf_prefix):]
61
79
  return self.self_path
62
80
 
81
+ @property
82
+ def is_polymorphic(self) -> bool:
83
+ return len(self.signatures) > 1
84
+
85
+ @property
86
+ def signature(self) -> Signature:
87
+ assert not self.is_polymorphic
88
+ return self.signatures[0]
89
+
63
90
  @property
64
91
  def arity(self) -> int:
92
+ assert not self.is_polymorphic
65
93
  return len(self.signature.parameters)
66
94
 
95
+ def _docstring(self) -> Optional[str]:
96
+ return None
97
+
67
98
  def help_str(self) -> str:
68
- return self.display_name + str(self.signature)
99
+ docstring = self._docstring()
100
+ display = self.display_name + str(self.signatures[0])
101
+ if docstring is None:
102
+ return display
103
+ return f'{display}\n\n{docstring}'
104
+
105
+ @property
106
+ def _resolved_fns(self) -> list[Self]:
107
+ """
108
+ Return the list of overload resolutions for this `Function`, constructing it first if necessary.
109
+ Each resolution is a new `Function` instance that retains just the single signature at its index in `self.signatures`,
110
+ and is otherwise identical to this `Function`.
111
+ """
112
+ if len(self.__resolved_fns) == 0:
113
+ # The list of overload resolutions hasn't been constructed yet; do so now.
114
+ if len(self.signatures) == 1:
115
+ # Only one signature: no need to construct separate resolutions
116
+ self.__resolved_fns.append(self)
117
+ else:
118
+ # Multiple signatures: construct a resolution for each signature
119
+ for idx in range(len(self.signatures)):
120
+ resolution = cast(Self, copy(self))
121
+ resolution.signatures = [self.signatures[idx]]
122
+ resolution.__resolved_fns = [resolution] # Resolves to itself
123
+ resolution._update_as_overload_resolution(idx)
124
+ self.__resolved_fns.append(resolution)
125
+
126
+ return self.__resolved_fns
127
+
128
+ @property
129
+ def _has_resolved_fns(self) -> bool:
130
+ """
131
+ Returns true if the resolved_fns for this `Function` have been constructed (i.e., if self._resolved_fns
132
+ has been accessed).
133
+ """
134
+ return len(self.__resolved_fns) > 0
135
+
136
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
137
+ """
138
+ Subclasses must implement this in order to do any additional work when creating a resolution, beyond
139
+ simply updating `self.signatures`.
140
+ """
141
+ raise NotImplementedError()
69
142
 
70
143
  def __call__(self, *args: Any, **kwargs: Any) -> 'pxt.exprs.FunctionCall':
71
144
  from pixeltable import exprs
72
- bound_args = self.signature.py_signature.bind(*args, **kwargs)
73
- self.validate_call(bound_args.arguments)
74
- return exprs.FunctionCall(self, bound_args.arguments)
145
+
146
+ resolved_fn, bound_args = self._bind_to_matching_signature(args, kwargs)
147
+ return_type = resolved_fn.call_return_type(args, kwargs)
148
+ return exprs.FunctionCall(resolved_fn, bound_args, return_type)
149
+
150
+ def _bind_to_matching_signature(self, args: Sequence[Any], kwargs: dict[str, Any]) -> tuple[Self, dict[str, Any]]:
151
+ result: int = -1
152
+ bound_args: Optional[dict[str, Any]] = None
153
+ assert len(self.signatures) > 0
154
+ if len(self.signatures) == 1:
155
+ # Only one signature: call _bind_to_signature() and surface any errors directly
156
+ result = 0
157
+ bound_args = self._bind_to_signature(0, args, kwargs)
158
+ else:
159
+ # Multiple signatures: try each signature in declaration order and trap any errors.
160
+ # If none of them succeed, raise a generic error message.
161
+ for i in range(len(self.signatures)):
162
+ try:
163
+ bound_args = self._bind_to_signature(i, args, kwargs)
164
+ except (TypeError, excs.Error):
165
+ continue
166
+ result = i
167
+ break
168
+ if result == -1:
169
+ raise excs.Error(f'Function {self.name!r} has no matching signature for arguments')
170
+ assert result >= 0
171
+ assert bound_args is not None
172
+ return self._resolved_fns[result], bound_args
173
+
174
+ def _bind_to_signature(self, signature_idx: int, args: Sequence[Any], kwargs: dict[str, Any]) -> dict[str, Any]:
175
+ from pixeltable import exprs
176
+
177
+ signature = self.signatures[signature_idx]
178
+ bound_args = signature.py_signature.bind(*args, **kwargs).arguments
179
+ self._resolved_fns[signature_idx].validate_call(bound_args)
180
+ exprs.FunctionCall.normalize_args(self.name, signature, bound_args)
181
+ return bound_args
75
182
 
76
183
  def validate_call(self, bound_args: dict[str, Any]) -> None:
77
184
  """Override this to do custom validation of the arguments"""
78
- pass
185
+ assert not self.is_polymorphic
79
186
 
80
- def call_return_type(self, kwargs: dict[str, Any]) -> ts.ColumnType:
187
+ def call_return_type(self, args: Sequence[Any], kwargs: dict[str, Any]) -> ts.ColumnType:
81
188
  """Return the type of the value returned by calling this function with the given arguments"""
189
+ assert not self.is_polymorphic
82
190
  if self._conditional_return_type is None:
83
191
  return self.signature.return_type
84
- bound_args = self.signature.py_signature.bind(**kwargs)
192
+ bound_args = self.signature.py_signature.bind(*args, **kwargs).arguments
85
193
  kw_args: dict[str, Any] = {}
86
194
  sig = inspect.signature(self._conditional_return_type)
87
195
  for param in sig.parameters.values():
88
- if param.name in bound_args.arguments:
89
- kw_args[param.name] = bound_args.arguments[param.name]
196
+ if param.name in bound_args:
197
+ kw_args[param.name] = bound_args[param.name]
90
198
  return self._conditional_return_type(**kw_args)
91
199
 
92
200
  def conditional_return_type(self, fn: Callable[..., ts.ColumnType]) -> Callable[..., ts.ColumnType]:
93
201
  """Instance decorator for specifying a conditional return type for this function"""
94
202
  # verify that fn only has parameters that are also present in every signature
95
- sig = inspect.signature(fn)
96
- for param in sig.parameters.values():
97
- if param.name not in self.signature.parameters:
98
- raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in the signature')
203
+ fn_sig = inspect.signature(fn)
204
+ for param in fn_sig.parameters.values():
205
+ for self_sig in self.signatures:
206
+ if param.name not in self_sig.parameters:
207
+ raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in a signature')
99
208
  self._conditional_return_type = fn
100
209
  return fn
101
210
 
211
+ def using(self, **kwargs: Any) -> 'ExprTemplateFunction':
212
+ from .expr_template_function import ExprTemplateFunction
213
+
214
+ assert len(self.signatures) > 0
215
+ if len(self.signatures) == 1:
216
+ # Only one signature: call _bind_and_create_template() and surface any errors directly
217
+ template = self._bind_and_create_template(kwargs)
218
+ return ExprTemplateFunction([template])
219
+ else:
220
+ # Multiple signatures: iterate over each signature and generate a template for each
221
+ # successful binding. If there are no successful bindings, raise a generic error.
222
+ # (Note that the resulting ExprTemplateFunction may have strictly fewer signatures than
223
+ # this Function, in the event that only some of the signatures are successfully bound.)
224
+ templates: list['ExprTemplate'] = []
225
+ for i in range(len(self.signatures)):
226
+ try:
227
+ template = self._resolved_fns[i]._bind_and_create_template(kwargs)
228
+ templates.append(template)
229
+ except (TypeError, excs.Error):
230
+ continue
231
+ if len(templates) == 0:
232
+ raise excs.Error(f'Function {self.name!r} has no matching signature for arguments')
233
+ return ExprTemplateFunction(templates)
234
+
235
+ def _bind_and_create_template(self, kwargs: dict[str, Any]) -> 'ExprTemplate':
236
+ from pixeltable import exprs
237
+
238
+ from .expr_template_function import ExprTemplate
239
+
240
+ assert not self.is_polymorphic
241
+
242
+ # Resolve each kwarg into a parameter binding
243
+ bindings: dict[str, exprs.Expr] = {}
244
+ for k, v in kwargs.items():
245
+ if k not in self.signature.parameters:
246
+ raise excs.Error(f'Unknown parameter: {k}')
247
+ param = self.signature.parameters[k]
248
+ expr = exprs.Expr.from_object(v)
249
+ if not param.col_type.is_supertype_of(expr.col_type):
250
+ raise excs.Error(f'Expected type `{param.col_type}` for parameter `{k}`; got `{expr.col_type}`')
251
+ bindings[k] = v # Use the original value, not the Expr (The Expr is only for validation)
252
+
253
+ residual_params = [
254
+ p for p in self.signature.parameters.values() if p.name not in bindings
255
+ ]
256
+
257
+ # Bind each remaining parameter to a like-named variable
258
+ for param in residual_params:
259
+ bindings[param.name] = exprs.Variable(param.name, param.col_type)
260
+
261
+ return_type = self.call_return_type([], bindings)
262
+ call = exprs.FunctionCall(self, bindings, return_type)
263
+
264
+ # Construct the (n-k)-ary signature of the new function. We use `call.col_type` for this, rather than
265
+ # `self.signature.return_type`, because the return type of the new function may be specialized via a
266
+ # conditional return type.
267
+ new_signature = Signature(call.col_type, residual_params, self.signature.is_batched)
268
+
269
+ return ExprTemplate(call, new_signature)
270
+
102
271
  @abc.abstractmethod
103
- def exec(self, *args: Any, **kwargs: Any) -> Any:
272
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
104
273
  """Execute the function with the given arguments and return the result."""
105
274
  pass
106
275
 
@@ -122,19 +291,24 @@ class Function(abc.ABC):
122
291
  """Print source code"""
123
292
  print('source not available')
124
293
 
125
- def as_dict(self) -> dict:
294
+ def as_dict(self) -> dict[str, Any]:
126
295
  """
127
296
  Return a serialized reference to the instance that can be passed to json.dumps() and converted back
128
297
  to an instance with from_dict().
129
298
  Subclasses can override _as_dict().
130
299
  """
300
+ # We currently only ever serialize a function that has a specific signature (not a polymorphic form).
301
+ assert not self.is_polymorphic
131
302
  classpath = f'{self.__class__.__module__}.{self.__class__.__qualname__}'
132
303
  return {'_classpath': classpath, **self._as_dict()}
133
304
 
134
305
  def _as_dict(self) -> dict:
135
- """Default serialization: store the path to self (which includes the module path)"""
306
+ """Default serialization: store the path to self (which includes the module path) and signature."""
136
307
  assert self.self_path is not None
137
- return {'path': self.self_path}
308
+ return {
309
+ 'path': self.self_path,
310
+ 'signature': self.signature.as_dict(),
311
+ }
138
312
 
139
313
  @classmethod
140
314
  def from_dict(cls, d: dict) -> Function:
@@ -145,15 +319,45 @@ class Function(abc.ABC):
145
319
  module_path, class_name = d['_classpath'].rsplit('.', 1)
146
320
  class_module = importlib.import_module(module_path)
147
321
  func_class = getattr(class_module, class_name)
322
+ assert isinstance(func_class, type) and issubclass(func_class, Function)
148
323
  return func_class._from_dict(d)
149
324
 
150
325
  @classmethod
151
326
  def _from_dict(cls, d: dict) -> Function:
152
327
  """Default deserialization: load the symbol indicated by the stored symbol_path"""
153
328
  assert 'path' in d and d['path'] is not None
329
+ assert 'signature' in d and d['signature'] is not None
154
330
  instance = resolve_symbol(d['path'])
155
331
  assert isinstance(instance, Function)
156
- return instance
332
+
333
+ # Load the signature from the DB and check that it is still valid (i.e., is still consistent with a signature
334
+ # in the code).
335
+ signature = Signature.from_dict(d['signature'])
336
+ idx = instance.__find_matching_overload(signature)
337
+ if idx is None:
338
+ # No match; generate an informative error message.
339
+ signature_note_str = 'any of its signatures' if instance.is_polymorphic else 'its signature as'
340
+ instance_signature_str = (
341
+ f'{len(instance.signatures)} signatures' if instance.is_polymorphic else str(instance.signature)
342
+ )
343
+ # TODO: Handle this more gracefully (instead of failing the DB load, allow the DB load to succeed, but
344
+ # mark any enclosing FunctionCall as unusable). It's the same issue as dealing with a renamed UDF or
345
+ # FunctionCall return type mismatch.
346
+ raise excs.Error(
347
+ f'The signature stored in the database for the UDF `{instance.self_path}` no longer matches '
348
+ f'{signature_note_str} as currently defined in the code.\nThis probably means that the code for '
349
+ f'`{instance.self_path}` has changed in a backward-incompatible way.\n'
350
+ f'Signature in database: {signature}\n'
351
+ f'Signature in code: {instance_signature_str}'
352
+ )
353
+ # We found a match; specialize to the appropriate overload resolution (non-polymorphic form) and return that.
354
+ return instance._resolved_fns[idx]
355
+
356
+ def __find_matching_overload(self, sig: Signature) -> Optional[int]:
357
+ for idx, overload_sig in enumerate(self.signatures):
358
+ if sig.is_consistent_with(overload_sig):
359
+ return idx
360
+ return None
157
361
 
158
362
  def to_store(self) -> tuple[dict, bytes]:
159
363
  """
@@ -13,6 +13,7 @@ import pixeltable.env as env
13
13
  import pixeltable.exceptions as excs
14
14
  import pixeltable.type_system as ts
15
15
  from pixeltable.metadata import schema
16
+
16
17
  from .function import Function
17
18
 
18
19
  _logger = logging.getLogger('pixeltable')
@@ -68,7 +69,7 @@ class FunctionRegistry:
68
69
  raise excs.Error(f'A UDF with that name already exists: {fqn}')
69
70
  self.module_fns[fqn] = fn
70
71
  if fn.is_method or fn.is_property:
71
- base_type = fn.signature.parameters_by_pos[0].col_type.type_enum
72
+ base_type = fn.signatures[0].parameters_by_pos[0].col_type.type_enum
72
73
  if base_type not in self.type_methods:
73
74
  self.type_methods[base_type] = {}
74
75
  if fn.name in self.type_methods[base_type]:
@@ -1,23 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import Any, Callable, Optional
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, overload
5
5
 
6
6
  import sqlalchemy as sql
7
7
 
8
- import pixeltable as pxt
8
+ import pixeltable.exceptions as excs
9
+ import pixeltable.type_system as ts
9
10
  from pixeltable import exprs
10
11
 
11
12
  from .function import Function
12
13
  from .signature import Signature
13
14
 
15
+ if TYPE_CHECKING:
16
+ from pixeltable import DataFrame
17
+
14
18
 
15
19
  class QueryTemplateFunction(Function):
16
20
  """A parameterized query/DataFrame from which an executable DataFrame is created with a function call."""
17
21
 
18
22
  @classmethod
19
23
  def create(
20
- cls, template_callable: Callable, param_types: Optional[list[pxt.ColumnType]], path: str, name: str
24
+ cls, template_callable: Callable, param_types: Optional[list[ts.ColumnType]], path: str, name: str
21
25
  ) -> QueryTemplateFunction:
22
26
  # we need to construct a template df and a signature
23
27
  py_sig = inspect.signature(template_callable)
@@ -29,14 +33,15 @@ class QueryTemplateFunction(Function):
29
33
  from pixeltable import DataFrame
30
34
  assert isinstance(template_df, DataFrame)
31
35
  # we take params and return json
32
- sig = Signature(return_type=pxt.JsonType(), parameters=params)
36
+ sig = Signature(return_type=ts.JsonType(), parameters=params)
33
37
  return QueryTemplateFunction(template_df, sig, path=path, name=name)
34
38
 
35
39
  def __init__(
36
- self, template_df: Optional['pxt.DataFrame'], sig: Optional[Signature], path: Optional[str] = None,
40
+ self, template_df: Optional['DataFrame'], sig: Signature, path: Optional[str] = None,
37
41
  name: Optional[str] = None,
38
42
  ):
39
- super().__init__(sig, self_path=path)
43
+ assert sig is not None
44
+ super().__init__([sig], self_path=path)
40
45
  self.self_name = name
41
46
  self.template_df = template_df
42
47
 
@@ -48,16 +53,20 @@ class QueryTemplateFunction(Function):
48
53
  # convert defaults to Literals
49
54
  self.defaults: dict[str, exprs.Literal] = {} # key: param name, value: default value converted to a Literal
50
55
  param_types = self.template_df.parameters()
51
- for param in [p for p in self.signature.parameters.values() if p.has_default()]:
56
+ for param in [p for p in sig.parameters.values() if p.has_default()]:
52
57
  assert param.name in param_types
53
58
  param_type = param_types[param.name]
54
59
  literal_default = exprs.Literal(param.default, col_type=param_type)
55
60
  self.defaults[param.name] = literal_default
56
61
 
62
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
63
+ pass # only one signature supported for QueryTemplateFunction
64
+
57
65
  def set_conn(self, conn: Optional[sql.engine.Connection]) -> None:
58
66
  self.conn = conn
59
67
 
60
- def exec(self, *args: Any, **kwargs: Any) -> Any:
68
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
69
+ assert not self.is_polymorphic
61
70
  bound_args = self.signature.py_signature.bind(*args, **kwargs).arguments
62
71
  # apply defaults, otherwise we might have Parameters left over
63
72
  bound_args.update(
@@ -75,9 +84,42 @@ class QueryTemplateFunction(Function):
75
84
  return self.self_name
76
85
 
77
86
  def _as_dict(self) -> dict:
78
- return {'name': self.name, 'signature': self.signature.as_dict(), 'df': self.template_df.as_dict()}
87
+ return {'name': self.name, 'signature': self.signatures[0].as_dict(), 'df': self.template_df.as_dict()}
79
88
 
80
89
  @classmethod
81
90
  def _from_dict(cls, d: dict) -> Function:
82
91
  from pixeltable.dataframe import DataFrame
83
92
  return cls(DataFrame.from_dict(d['df']), Signature.from_dict(d['signature']), name=d['name'])
93
+
94
+
95
+ @overload
96
+ def query(self, py_fn: Callable) -> QueryTemplateFunction: ...
97
+
98
+ @overload
99
+ def query(
100
+ self, *, param_types: Optional[list[ts.ColumnType]] = None
101
+ ) -> Callable[[Callable], QueryTemplateFunction]: ...
102
+
103
+ def query(*args: Any, **kwargs: Any) -> Any:
104
+ def make_query_template(
105
+ py_fn: Callable, param_types: Optional[list[ts.ColumnType]]
106
+ ) -> QueryTemplateFunction:
107
+ if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
108
+ # this is a named function in a module
109
+ function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
110
+ else:
111
+ function_path = None
112
+ query_name = py_fn.__name__
113
+ query_fn = QueryTemplateFunction.create(
114
+ py_fn, param_types=param_types, path=function_path, name=query_name)
115
+ return query_fn
116
+
117
+ # TODO: verify that the inferred return type matches that of the template
118
+ # TODO: verify that the signature doesn't contain batched parameters
119
+
120
+ if len(args) == 1:
121
+ assert len(kwargs) == 0 and callable(args[0])
122
+ return make_query_template(args[0], None)
123
+ else:
124
+ assert len(args) == 0 and len(kwargs) == 1 and 'param_types' in kwargs
125
+ return lambda py_fn: make_query_template(py_fn, kwargs['param_types'])
@@ -91,6 +91,7 @@ class Signature:
91
91
  self.parameters_by_pos = parameters.copy()
92
92
  self.constant_parameters = [p for p in parameters if not p.is_batched]
93
93
  self.batched_parameters = [p for p in parameters if p.is_batched]
94
+ self.required_parameters = [p for p in parameters if not p.has_default()]
94
95
  self.py_signature = inspect.Signature([p.to_py_param() for p in self.parameters_by_pos])
95
96
 
96
97
  def get_return_type(self) -> ts.ColumnType:
@@ -110,6 +111,38 @@ class Signature:
110
111
  parameters = [Parameter.from_dict(param_dict) for param_dict in d['parameters']]
111
112
  return cls(ts.ColumnType.from_dict(d['return_type']), parameters, d['is_batched'])
112
113
 
114
+ def is_consistent_with(self, other: Signature) -> bool:
115
+ """
116
+ Returns True if this signature is consistent with the other signature.
117
+ S is consistent with T if we could safely replace S by T in any call where S is used. Specifically:
118
+ (i) S.return_type is a supertype of T.return_type
119
+ (ii) For each parameter p in S, there is a parameter q in T such that:
120
+ - p and q have the same name and kind
121
+ - q.col_type is a supertype of p.col_type
122
+ (iii) For each *required* parameter q in T, there is a parameter p in S with the same name (in which
123
+ case the kinds and types must also match, by condition (ii)).
124
+ """
125
+ # Check (i)
126
+ if not self.get_return_type().is_supertype_of(other.get_return_type(), ignore_nullable=True):
127
+ return False
128
+
129
+ # Check (ii)
130
+ for param_name, param in self.parameters.items():
131
+ if param_name not in other.parameters:
132
+ return False
133
+ other_param = other.parameters[param_name]
134
+ if (param.kind != other_param.kind or
135
+ (param.col_type is None) != (other_param.col_type is None) or # this can happen if they are varargs
136
+ param.col_type is not None and not other_param.col_type.is_supertype_of(param.col_type, ignore_nullable=True)):
137
+ return False
138
+
139
+ # Check (iii)
140
+ for other_param in other.required_parameters:
141
+ if other_param.name not in self.parameters:
142
+ return False
143
+
144
+ return True
145
+
113
146
  def __eq__(self, other: object) -> bool:
114
147
  if not isinstance(other, Signature):
115
148
  return False
@@ -155,8 +188,12 @@ class Signature:
155
188
 
156
189
  @classmethod
157
190
  def create_parameters(
158
- cls, py_fn: Optional[Callable] = None, py_params: Optional[list[inspect.Parameter]] = None,
159
- param_types: Optional[list[ts.ColumnType]] = None
191
+ cls,
192
+ py_fn: Optional[Callable] = None,
193
+ py_params: Optional[list[inspect.Parameter]] = None,
194
+ param_types: Optional[list[ts.ColumnType]] = None,
195
+ type_substitutions: Optional[dict] = None,
196
+ is_cls_method: bool = False
160
197
  ) -> list[Parameter]:
161
198
  assert (py_fn is None) != (py_params is None)
162
199
  if py_fn is not None:
@@ -164,7 +201,12 @@ class Signature:
164
201
  py_params = list(sig.parameters.values())
165
202
  parameters: list[Parameter] = []
166
203
 
204
+ if type_substitutions is None:
205
+ type_substitutions = {}
206
+
167
207
  for idx, param in enumerate(py_params):
208
+ if is_cls_method and idx == 0:
209
+ continue # skip 'self' or 'cls' parameter
168
210
  if param.name in cls.SPECIAL_PARAM_NAMES:
169
211
  raise excs.Error(f"'{param.name}' is a reserved parameter name")
170
212
  if param.kind == inspect.Parameter.VAR_POSITIONAL or param.kind == inspect.Parameter.VAR_KEYWORD:
@@ -178,7 +220,12 @@ class Signature:
178
220
  param_type = param_types[idx]
179
221
  is_batched = False
180
222
  else:
181
- param_type, is_batched = cls._infer_type(param.annotation)
223
+ py_type: Optional[type]
224
+ if param.annotation in type_substitutions:
225
+ py_type = type_substitutions[param.annotation]
226
+ else:
227
+ py_type = param.annotation
228
+ param_type, is_batched = cls._infer_type(py_type)
182
229
  if param_type is None:
183
230
  raise excs.Error(f'Cannot infer pixeltable type for parameter {param.name}')
184
231
 
@@ -189,18 +236,29 @@ class Signature:
189
236
 
190
237
  @classmethod
191
238
  def create(
192
- cls, py_fn: Callable,
239
+ cls,
240
+ py_fn: Callable,
193
241
  param_types: Optional[list[ts.ColumnType]] = None,
194
- return_type: Optional[ts.ColumnType] = None
242
+ return_type: Optional[ts.ColumnType] = None,
243
+ type_substitutions: Optional[dict] = None,
244
+ is_cls_method: bool = False
195
245
  ) -> Signature:
196
246
  """Create a signature for the given Callable.
197
247
  Infer the parameter and return types, if none are specified.
198
248
  Raises an exception if the types cannot be inferred.
199
249
  """
200
- parameters = cls.create_parameters(py_fn=py_fn, param_types=param_types)
250
+ if type_substitutions is None:
251
+ type_substitutions = {}
252
+
253
+ parameters = cls.create_parameters(py_fn=py_fn, param_types=param_types, is_cls_method=is_cls_method, type_substitutions=type_substitutions)
201
254
  sig = inspect.signature(py_fn)
202
255
  if return_type is None:
203
- return_type, return_is_batched = cls._infer_type(sig.return_annotation)
256
+ py_type: Optional[type]
257
+ if sig.return_annotation in type_substitutions:
258
+ py_type = type_substitutions[sig.return_annotation]
259
+ else:
260
+ py_type = sig.return_annotation
261
+ return_type, return_is_batched = cls._infer_type(py_type)
204
262
  if return_type is None:
205
263
  raise excs.Error('Cannot infer pixeltable return type')
206
264
  else: