pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245) hide show
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import inspect
5
- from typing import TYPE_CHECKING, Any, Callable, Optional
5
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Sequence, overload
6
6
 
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
@@ -12,15 +12,15 @@ from .globals import validate_symbol_path
12
12
  from .signature import Parameter, Signature
13
13
 
14
14
  if TYPE_CHECKING:
15
- import pixeltable
15
+ from pixeltable import exprs
16
16
 
17
17
 
18
18
  class Aggregator(abc.ABC):
19
- def update(self, *args: Any, **kwargs: Any) -> None:
20
- pass
19
+ @abc.abstractmethod
20
+ def update(self, *args: Any, **kwargs: Any) -> None: ...
21
21
 
22
- def value(self) -> Any:
23
- pass
22
+ @abc.abstractmethod
23
+ def value(self) -> Any: ...
24
24
 
25
25
 
26
26
  class AggregateFunction(Function):
@@ -31,66 +31,149 @@ class AggregateFunction(Function):
31
31
  allows_std_agg: if True, the aggregate function can be used as a standard aggregate function w/o a window
32
32
  allows_window: if True, the aggregate function can be used with a window
33
33
  """
34
- ORDER_BY_PARAM = 'order_by'
35
- GROUP_BY_PARAM = 'group_by'
36
- RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}
34
+
35
+ ORDER_BY_PARAM: ClassVar[str] = 'order_by'
36
+ GROUP_BY_PARAM: ClassVar[str] = 'group_by'
37
+ RESERVED_PARAMS: ClassVar[set[str]] = {ORDER_BY_PARAM, GROUP_BY_PARAM}
38
+
39
+ agg_classes: list[type[Aggregator]] # classes for each signature, in signature order
40
+ init_param_names: list[list[str]] # names of the __init__ parameters for each signature
37
41
 
38
42
  def __init__(
39
- self, aggregator_class: type[Aggregator], self_path: str,
40
- init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
41
- requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
42
- self.agg_cls = aggregator_class
43
+ self,
44
+ agg_class: type[Aggregator],
45
+ type_substitutions: Sequence[dict] | None,
46
+ self_path: str,
47
+ requires_order_by: bool,
48
+ allows_std_agg: bool,
49
+ allows_window: bool,
50
+ ) -> None:
51
+ if type_substitutions is None:
52
+ type_substitutions = [None] # single signature with no substitutions
53
+ self.agg_classes = [agg_class]
54
+ else:
55
+ self.agg_classes = [agg_class] * len(type_substitutions)
56
+ self.init_param_names = []
43
57
  self.requires_order_by = requires_order_by
44
58
  self.allows_std_agg = allows_std_agg
45
59
  self.allows_window = allows_window
46
- self.__doc__ = aggregator_class.__doc__
60
+ self.__doc__ = agg_class.__doc__
61
+
62
+ signatures: list[Signature] = []
63
+
64
+ # If no type_substitutions were provided, construct a single signature for the class.
65
+ # Otherwise, construct one signature for each type substitution instance.
66
+ for subst in type_substitutions:
67
+ signature, init_param_names = self.__cls_to_signature(agg_class, subst)
68
+ signatures.append(signature)
69
+ self.init_param_names.append(init_param_names)
70
+
71
+ super().__init__(signatures, self_path=self_path)
72
+
73
+ def _update_as_overload_resolution(self, signature_idx: int) -> None:
74
+ self.agg_classes = [self.agg_classes[signature_idx]]
75
+ self.init_param_names = [self.init_param_names[signature_idx]]
76
+
77
+ def __cls_to_signature(
78
+ self, cls: type[Aggregator], type_substitutions: dict | None = None
79
+ ) -> tuple[Signature, list[str]]:
80
+ """Inspects the Aggregator class to infer the corresponding function signature. Returns the
81
+ inferred signature along with the list of init_param_names (for downstream error handling).
82
+ """
83
+ from pixeltable import exprs
84
+
85
+ # infer type parameters; set return_type=InvalidType() because it has no meaning here
86
+ init_sig = Signature.create(
87
+ py_fn=cls.__init__, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions
88
+ )
89
+ update_sig = Signature.create(
90
+ py_fn=cls.update, return_type=ts.InvalidType(), is_cls_method=True, type_substitutions=type_substitutions
91
+ )
92
+ value_sig = Signature.create(py_fn=cls.value, is_cls_method=True, type_substitutions=type_substitutions)
93
+
94
+ init_types = [p.col_type for p in init_sig.parameters.values()]
95
+ update_types = [p.col_type for p in update_sig.parameters.values()]
96
+ value_type = value_sig.return_type
97
+ assert value_type is not None
98
+
99
+ if len(update_types) == 0:
100
+ raise excs.Error('update() must have at least one parameter')
47
101
 
48
102
  # our signature is the signature of 'update', but without self,
49
103
  # plus the parameters of 'init' as keyword-only parameters
50
- py_update_params = list(inspect.signature(self.agg_cls.update).parameters.values())[1:] # leave out self
104
+ py_update_params = list(inspect.signature(cls.update).parameters.values())[1:] # leave out self
51
105
  assert len(py_update_params) == len(update_types)
52
106
  update_params = [
53
- Parameter(p.name, col_type=update_types[i], kind=p.kind, default=p.default)
107
+ Parameter(
108
+ p.name,
109
+ col_type=update_types[i],
110
+ kind=p.kind,
111
+ default=exprs.Expr.from_object(p.default), # type: ignore[arg-type]
112
+ )
54
113
  for i, p in enumerate(py_update_params)
55
114
  ]
56
115
  # starting at 1: leave out self
57
- py_init_params = list(inspect.signature(self.agg_cls.__init__).parameters.values())[1:]
116
+ py_init_params = list(inspect.signature(cls.__init__).parameters.values())[1:]
58
117
  assert len(py_init_params) == len(init_types)
59
118
  init_params = [
60
- Parameter(p.name, col_type=init_types[i], kind=inspect.Parameter.KEYWORD_ONLY, default=p.default)
119
+ Parameter(
120
+ p.name,
121
+ col_type=init_types[i],
122
+ kind=inspect.Parameter.KEYWORD_ONLY,
123
+ default=exprs.Expr.from_object(p.default), # type: ignore[arg-type]
124
+ )
61
125
  for i, p in enumerate(py_init_params)
62
126
  ]
63
- duplicate_params = set(p.name for p in init_params) & set(p.name for p in update_params)
127
+ duplicate_params = {p.name for p in init_params} & {p.name for p in update_params}
64
128
  if len(duplicate_params) > 0:
65
129
  raise excs.Error(
66
- f'__init__() and update() cannot have parameters with the same name: '
67
- f'{", ".join(duplicate_params)}'
130
+ f'__init__() and update() cannot have parameters with the same name: {", ".join(duplicate_params)}'
68
131
  )
69
132
  params = update_params + init_params # init_params are keyword-only and come last
133
+ init_param_names = [p.name for p in init_params]
134
+
135
+ return Signature(value_type, params), init_param_names
70
136
 
71
- signature = Signature(value_type, params)
72
- super().__init__(signature, self_path=self_path)
73
- self.init_param_names = [p.name for p in init_params]
137
+ @property
138
+ def agg_class(self) -> type[Aggregator]:
139
+ assert not self.is_polymorphic
140
+ return self.agg_classes[0]
74
141
 
75
- # make sure the signature doesn't contain reserved parameter names;
76
- # do this after super().__init__(), otherwise self.name is invalid
77
- for param in signature.parameters:
78
- if param.lower() in self.RESERVED_PARAMS:
79
- raise excs.Error(f'{self.name}(): parameter name {param} is reserved')
142
+ @property
143
+ def is_async(self) -> bool:
144
+ return False
80
145
 
81
- def exec(self, *args: Any, **kwargs: Any) -> Any:
146
+ def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
82
147
  raise NotImplementedError
83
148
 
149
+ def overload(self, cls: type[Aggregator]) -> AggregateFunction:
150
+ if not isinstance(cls, type) or not issubclass(cls, Aggregator):
151
+ raise excs.Error(f'Invalid argument to @overload decorator: {cls}')
152
+ if self._has_resolved_fns:
153
+ raise excs.Error('New `overload` not allowed after the UDF has already been called')
154
+ if self._conditional_return_type is not None:
155
+ raise excs.Error('New `overload` not allowed after a conditional return type has been specified')
156
+ sig, init_param_names = self.__cls_to_signature(cls)
157
+ self.signatures.append(sig)
158
+ self.agg_classes.append(cls)
159
+ self.init_param_names.append(init_param_names)
160
+ return self
161
+
162
+ def comment(self) -> str | None:
163
+ return inspect.getdoc(self.agg_classes[0])
164
+
84
165
  def help_str(self) -> str:
85
166
  res = super().help_str()
86
- res += '\n\n' + inspect.getdoc(self.agg_cls.update)
167
+ # We need to reference agg_classes[0] rather than agg_class here, because we want this to work even if the
168
+ # aggregator is polymorphic (in which case we use the docstring of the originally decorated UDA).
169
+ res += '\n\n' + inspect.getdoc(self.agg_classes[0].update)
87
170
  return res
88
171
 
89
- def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
172
+ def __call__(self, *args: Any, **kwargs: Any) -> 'exprs.FunctionCall':
90
173
  from pixeltable import exprs
91
174
 
92
175
  # perform semantic analysis of special parameters 'order_by' and 'group_by'
93
- order_by_clause: Optional[Any] = None
176
+ order_by_clause: Any | None = None
94
177
  if self.ORDER_BY_PARAM in kwargs:
95
178
  if self.requires_order_by:
96
179
  raise excs.Error(
@@ -99,7 +182,8 @@ class AggregateFunction(Function):
99
182
  )
100
183
  if not self.allows_window:
101
184
  raise excs.Error(
102
- f'{self.display_name}(): order_by invalid with an aggregate function that does not allow windows')
185
+ f'{self.display_name}(): order_by invalid with an aggregate function that does not allow windows'
186
+ )
103
187
  order_by_clause = kwargs.pop(self.ORDER_BY_PARAM)
104
188
  elif self.requires_order_by:
105
189
  # the first argument is the order-by expr
@@ -114,42 +198,61 @@ class AggregateFunction(Function):
114
198
  # don't pass the first parameter on, the Function doesn't get to see it
115
199
  args = args[1:]
116
200
 
117
- group_by_clause: Optional[Any] = None
201
+ group_by_clause: Any | None = None
118
202
  if self.GROUP_BY_PARAM in kwargs:
119
203
  if not self.allows_window:
120
204
  raise excs.Error(
121
- f'{self.display_name}(): group_by invalid with an aggregate function that does not allow windows')
205
+ f'{self.display_name}(): group_by invalid with an aggregate function that does not allow windows'
206
+ )
122
207
  group_by_clause = kwargs.pop(self.GROUP_BY_PARAM)
123
208
 
124
- bound_args = self.signature.py_signature.bind(*args, **kwargs)
125
- self.validate_call(bound_args.arguments)
209
+ args = [exprs.Expr.from_object(arg) for arg in args]
210
+ kwargs = {k: exprs.Expr.from_object(v) for k, v in kwargs.items()}
211
+
212
+ resolved_fn, bound_args = self._bind_to_matching_signature(args, kwargs)
213
+ return_type = resolved_fn.call_return_type(bound_args)
214
+
126
215
  return exprs.FunctionCall(
127
- self, bound_args.arguments,
216
+ resolved_fn,
217
+ args,
218
+ kwargs,
219
+ return_type,
128
220
  order_by_clause=[order_by_clause] if order_by_clause is not None else [],
129
- group_by_clause=[group_by_clause] if group_by_clause is not None else [])
221
+ group_by_clause=[group_by_clause] if group_by_clause is not None else [],
222
+ )
223
+
224
+ def validate_call(self, bound_args: dict[str, 'exprs.Expr']) -> None:
225
+ from pixeltable import exprs
226
+
227
+ super().validate_call(bound_args)
130
228
 
131
- def validate_call(self, bound_args: dict[str, Any]) -> None:
132
229
  # check that init parameters are not Exprs
133
230
  # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
134
- import pixeltable.exprs as exprs
135
- for param_name in self.init_param_names:
136
- if param_name in bound_args and isinstance(bound_args[param_name], exprs.Expr):
137
- raise excs.Error(
138
- f'{self.display_name}(): init() parameter {param_name} needs to be a constant, not a Pixeltable '
139
- f'expression'
140
- )
231
+ for param_name in self.init_param_names[0]:
232
+ if param_name in bound_args and not isinstance(bound_args[param_name], exprs.Literal):
233
+ raise excs.Error(f'{self.display_name}(): init() parameter {param_name!r} must be a constant value')
141
234
 
142
235
  def __repr__(self) -> str:
143
236
  return f'<Pixeltable Aggregator {self.name}>'
144
237
 
145
238
 
239
+ # Decorator invoked without parentheses: @pxt.uda
240
+ @overload
241
+ def uda(decorated_fn: Callable) -> AggregateFunction: ...
242
+
243
+
244
+ # Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
245
+ @overload
146
246
  def uda(
147
- *,
148
- value_type: ts.ColumnType,
149
- update_types: list[ts.ColumnType],
150
- init_types: Optional[list[ts.ColumnType]] = None,
151
- requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
152
- ) -> Callable[[type[Aggregator]], AggregateFunction]:
247
+ *,
248
+ requires_order_by: bool = False,
249
+ allows_std_agg: bool = True,
250
+ allows_window: bool = False,
251
+ type_substitutions: Sequence[dict] | None = None,
252
+ ) -> Callable[[type[Aggregator]], AggregateFunction]: ...
253
+
254
+
255
+ def uda(*args, **kwargs): # type: ignore[no-untyped-def]
153
256
  """Decorator for user-defined aggregate functions.
154
257
 
155
258
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -161,46 +264,48 @@ def uda(
161
264
  to the module where the class is defined.
162
265
 
163
266
  Parameters:
164
- - init_types: list of types for the __init__() parameters; must match the number of parameters
165
- - update_types: list of types for the update() parameters; must match the number of parameters
166
- - value_type: return type of the aggregator
167
267
  - requires_order_by: if True, the first parameter to the function is the order-by expression
168
268
  - allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
169
269
  - allows_window: if True, the function can be used with a window
170
270
  """
171
- if init_types is None:
172
- init_types = []
173
-
174
- def decorator(cls: type[Aggregator]) -> AggregateFunction:
175
- # validate type parameters
176
- num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
177
- if num_init_params > 0:
178
- if len(init_types) != num_init_params:
179
- raise excs.Error(
180
- f'init_types must be a list of {num_init_params} types, one for each parameter of __init__()')
181
- num_update_params = len(inspect.signature(cls.update).parameters) - 1
182
- if num_update_params == 0:
183
- raise excs.Error('update() must have at least one parameter')
184
- if len(update_types) != num_update_params:
185
- raise excs.Error(
186
- f'update_types must be a list of {num_update_params} types, one for each parameter of update()')
187
- assert value_type is not None
188
-
189
- # the AggregateFunction instance resides in the same module as cls
190
- class_path = f'{cls.__module__}.{cls.__qualname__}'
191
- # nonlocal name
192
- # name = name or cls.__name__
193
- # instance_path_elements = class_path.split('.')[:-1] + [name]
194
- # instance_path = '.'.join(instance_path_elements)
195
-
196
- # create the corresponding AggregateFunction instance
197
- instance = AggregateFunction(
198
- cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
199
- # do the path validation at the very end, in order to be able to write tests for the other failure cases
200
- validate_symbol_path(class_path)
201
- #module = importlib.import_module(cls.__module__)
202
- #setattr(module, name, instance)
203
-
204
- return instance
271
+ if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
272
+ # Decorator invoked without parentheses: @pxt.uda
273
+ # Simply call make_aggregator with defaults.
274
+ return make_aggregator(cls=args[0])
275
+
276
+ else:
277
+ # Decorator schema invoked with parentheses: @pxt.uda(**kwargs)
278
+ # Create a decorator for the specified schema.
279
+ requires_order_by = kwargs.pop('requires_order_by', False)
280
+ allows_std_agg = kwargs.pop('allows_std_agg', True)
281
+ allows_window = kwargs.pop('allows_window', False)
282
+ type_substitutions = kwargs.pop('type_substitutions', None)
283
+ if len(kwargs) > 0:
284
+ raise excs.Error(f'Invalid @uda decorator kwargs: {", ".join(kwargs.keys())}')
285
+ if len(args) > 0:
286
+ raise excs.Error('Unexpected @uda decorator arguments.')
287
+
288
+ def decorator(cls: type[Aggregator]) -> AggregateFunction:
289
+ return make_aggregator(
290
+ cls,
291
+ requires_order_by=requires_order_by,
292
+ allows_std_agg=allows_std_agg,
293
+ allows_window=allows_window,
294
+ type_substitutions=type_substitutions,
295
+ )
205
296
 
206
- return decorator
297
+ return decorator
298
+
299
+
300
+ def make_aggregator(
301
+ cls: type[Aggregator],
302
+ requires_order_by: bool = False,
303
+ allows_std_agg: bool = True,
304
+ allows_window: bool = False,
305
+ type_substitutions: Sequence[dict] | None = None,
306
+ ) -> AggregateFunction:
307
+ class_path = f'{cls.__module__}.{cls.__qualname__}'
308
+ instance = AggregateFunction(cls, type_substitutions, class_path, requires_order_by, allows_std_agg, allows_window)
309
+ # do the path validation at the very end, in order to be able to write tests for the other failure cases
310
+ validate_symbol_path(class_path)
311
+ return instance
@@ -1,14 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import Any, Callable, Optional
4
+ from typing import TYPE_CHECKING, Any, Callable, Sequence
5
5
  from uuid import UUID
6
6
 
7
7
  import cloudpickle # type: ignore[import-untyped]
8
8
 
9
+ import pixeltable.exceptions as excs
10
+ from pixeltable.utils.coroutine import run_coroutine_synchronously
11
+
9
12
  from .function import Function
10
13
  from .signature import Signature
11
14
 
15
+ if TYPE_CHECKING:
16
+ from pixeltable import exprs
17
+
12
18
 
13
19
  class CallableFunction(Function):
14
20
  """Pixeltable Function backed by a Python Callable.
@@ -18,55 +24,121 @@ class CallableFunction(Function):
18
24
  - functions that are defined in modules are serialized via the default mechanism
19
25
  """
20
26
 
27
+ py_fns: list[Callable]
28
+ self_name: str | None
29
+ batch_size: int | None
30
+
21
31
def __init__(
    self,
    signatures: list[Signature],
    py_fns: list[Callable],
    self_path: str | None = None,
    self_name: str | None = None,
    batch_size: int | None = None,
    is_method: bool = False,
    is_property: bool = False,
):
    """Create a UDF backed by one Python callable per signature.

    Args:
        signatures: one Signature per overload; must be non-empty.
        py_fns: the callables, parallel to `signatures`.
        self_path: module path of the UDF; None for locally defined UDFs.
        self_name: the UDF's name.
        batch_size: if not None, the UDF is batched.
        is_method: forwarded to `Function.__init__`.
        is_property: forwarded to `Function.__init__`.

    Raises:
        excs.Error: if a locally defined UDF (no `self_path`) is given more than one signature.
    """
    n_signatures = len(signatures)
    assert n_signatures > 0
    assert n_signatures == len(py_fns)
    is_local_udf = self_path is None
    if is_local_udf and n_signatures > 1:
        raise excs.Error('Multiple signatures are only allowed for module UDFs (not locally defined UDFs)')
    self.py_fns = py_fns
    self.self_name = self_name
    self.batch_size = batch_size
    # The UDF's docstring is taken from the first (primary) callable
    self.__doc__ = py_fns[0].__doc__
    super().__init__(signatures, self_path=self_path, is_method=is_method, is_property=is_property)
+
51
def _update_as_overload_resolution(self, signature_idx: int) -> None:
    """Narrow this UDF to the single Python callable chosen by overload resolution.

    Args:
        signature_idx: index of the selected overload; `self.py_fns` is parallel to the signature list.
    """
    # Also reject negative indices: Python would otherwise silently select an overload counting from the end.
    assert 0 <= signature_idx < len(self.py_fns)
    self.py_fns = [self.py_fns[signature_idx]]
54
 
38
55
@property
def is_batched(self) -> bool:
    """True when a batch size was supplied for this UDF (i.e., it processes inputs in batches)."""
    has_batch_size = self.batch_size is not None
    return has_batch_size
41
58
 
42
- def exec(self, *args: Any, **kwargs: Any) -> Any:
59
@property
def is_async(self) -> bool:
    """True when the (single) underlying Python callable is a coroutine function."""
    target_fn = self.py_fn
    return inspect.iscoroutinefunction(target_fn)
62
+
63
def comment(self) -> str | None:
    """Return the cleaned docstring (via `inspect.getdoc`) of the first registered callable."""
    primary_fn = self.py_fns[0]
    return inspect.getdoc(primary_fn)
65
+
66
@property
def py_fn(self) -> Callable:
    """The underlying Python callable; only valid for a non-polymorphic (single-signature) UDF."""
    assert not self.is_polymorphic
    callables = self.py_fns
    return callables[0]
70
+
71
async def aexec(self, *args: Any, **kwargs: Any) -> Any:
    """Await the underlying coroutine function for a single set of arguments.

    For a batched UDF, every positional argument and every non-constant keyword argument is wrapped in a
    one-element list, constant keyword arguments are passed unchanged, and the single element of the
    returned list is unwrapped.
    """
    assert not self.is_polymorphic
    assert self.is_async
    if not self.is_batched:
        return await self.py_fn(*args, **kwargs)

    # Pack the batched parameters into singleton lists
    constant_names = [p.name for p in self.signature.constant_parameters]
    packed_args = [[value] for value in args]
    const_kw: dict[str, Any] = {}
    batch_kw: dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in constant_names:
            const_kw[key] = value
        else:
            batch_kw[key] = [value]
    results = await self.py_fn(*packed_args, **const_kw, **batch_kw)
    assert len(results) == 1
    return results[0]
85
+
86
def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
    """Synchronously invoke the UDF for a single set of arguments.

    Coroutine functions are driven to completion with `run_coroutine_synchronously`. For a batched UDF,
    positional and non-constant keyword arguments are wrapped in one-element lists, constant keyword
    arguments are passed unchanged, and the single element of the returned list is unwrapped.
    """
    assert not self.is_polymorphic
    fn = self.py_fn
    is_coroutine = inspect.iscoroutinefunction(fn)

    if not self.is_batched:
        if is_coroutine:
            # TODO: This is temporary (see note in utils/coroutine.py)
            return run_coroutine_synchronously(fn(*args, **kwargs))
        return fn(*args, **kwargs)

    # Pack the batched parameters into singleton lists
    constant_names = [p.name for p in self.signature.constant_parameters]
    packed_args = [[value] for value in args]
    const_kw = {key: value for key, value in kwargs.items() if key in constant_names}
    batch_kw = {key: [value] for key, value in kwargs.items() if key not in constant_names}
    outcome = fn(*packed_args, **const_kw, **batch_kw)
    # TODO: This is temporary (see note in utils/coroutine.py)
    result: list[Any] = run_coroutine_synchronously(outcome) if is_coroutine else outcome
    assert len(result) == 1
    return result[0]
54
107
 
55
- def exec_batch(self, *args: Any, **kwargs: Any) -> list:
108
async def aexec_batch(self, *args: Any, **kwargs: Any) -> list:
    """Asynchronously execute a batched coroutine UDF on a whole batch and return its list of results.

    Positional arguments are passed to the function unchanged, so they are expected to already be batched
    (lists). Keyword arguments are split by `create_batch_kwargs()`. Every keyword value, including those
    for constant parameters, is expected to be a list. For a constant parameter only the first element is
    passed to the function; all other keyword arguments are passed through as lists.
    """
    assert self.is_batched
    assert self.is_async
    assert not self.is_polymorphic
    # Collapse the (list-valued) constant parameters to single values
    constant_kwargs, batched_kwargs = self.create_batch_kwargs(kwargs)
    return await self.py_fn(*args, **constant_kwargs, **batched_kwargs)
120
+
121
def exec_batch(self, args: list[Any], kwargs: dict[str, Any]) -> list:
    """Synchronously execute a batched (non-coroutine) UDF on a whole batch and return its list of results.

    `args` are passed to the function unchanged, so they are expected to already be batched (lists).
    `kwargs` are split by `create_batch_kwargs()`. Every keyword value, including those for constant
    parameters, is expected to be a list. For a constant parameter only the first element is passed to
    the function; all other keyword arguments are passed through as lists.
    """
    assert self.is_batched
    assert not self.is_polymorphic
    assert not self.is_async
    # Collapse the (list-valued) constant parameters to single values
    constant_kwargs, batched_kwargs = self.create_batch_kwargs(kwargs)
    return self.py_fn(*args, **constant_kwargs, **batched_kwargs)
133
+
134
def create_batch_kwargs(self, kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, list[Any]]]:
    """Converts kwargs containing lists into constant and batched kwargs in the format expected by a batched udf.

    Every value in `kwargs` is expected to be a list. Values of constant parameters (per
    `self.signature.constant_parameters`) are collapsed to their first element. All other values are
    passed through unchanged.

    Returns:
        A `(constant_kwargs, batched_kwargs)` tuple.
    """
    # A set gives O(1) membership tests, and each kwarg is classified in a single pass.
    constant_param_names = {p.name for p in self.signature.constant_parameters}
    constant_kwargs: dict[str, Any] = {}
    batched_kwargs: dict[str, list[Any]] = {}
    for k, v in kwargs.items():
        if k in constant_param_names:
            constant_kwargs[k] = v[0]
        else:
            batched_kwargs[k] = v
    return constant_kwargs, batched_kwargs
67
140
 
68
- # TODO(aaron-siegel): Implement conditional batch sizing
69
- def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
141
def get_batch_size(self, *args: Any, **kwargs: Any) -> int | None:
    """Return the configured batch size; the call arguments are accepted but currently ignored."""
    size = self.batch_size
    return size
71
143
 
72
144
  @property
@@ -77,16 +149,26 @@ class CallableFunction(Function):
77
149
def name(self) -> str:
    """The UDF's name, as supplied at construction (`self_name`)."""
    udf_name = self.self_name
    return udf_name
79
151
 
80
- def help_str(self) -> str:
81
- res = super().help_str()
82
- res += '\n\n' + inspect.getdoc(self.py_fn)
83
- return res
152
def overload(self, fn: Callable) -> CallableFunction:
    """Register `fn` as an additional overload (signature + callable) of this UDF and return `self`.

    Returning `self` presumably allows usage as a decorator (`@my_udf.overload`).

    Raises:
        excs.Error: if this is a locally defined UDF, a method/property UDF, if the UDF has already been
            resolved (called), or if a conditional return type has already been specified.
    """
    # Checks run in order; the first failing one determines the error that is raised.
    preconditions = (
        (lambda: self.self_path is None, '`overload` can only be used with module UDFs (not locally defined UDFs)'),
        (lambda: self.is_method or self.is_property, '`overload` cannot be used with `is_method` or `is_property`'),
        (lambda: self._has_resolved_fns, 'New `overload` not allowed after the UDF has already been called'),
        (
            lambda: self._conditional_return_type is not None,
            'New `overload` not allowed after a conditional return type has been specified',
        ),
    )
    for violated, message in preconditions:
        if violated():
            raise excs.Error(message)
    new_signature = Signature.create(fn)
    self.signatures.append(new_signature)
    self.py_fns.append(fn)
    return self
84
165
 
85
166
def _as_dict(self) -> dict:
    """Serialize this UDF.

    Module UDFs (with a `self_path`) defer to `Function._as_dict()`. Locally defined UDFs are registered
    via `FunctionRegistry.create_stored_function()` and serialized as `{'id': <hex id>}`.
    """
    if self.self_path is not None:
        return super()._as_dict()
    # this is not a module function
    assert not self.is_method and not self.is_property
    from .function_registry import FunctionRegistry

    stored_id = FunctionRegistry.get().create_stored_function(self)
    return {'id': stored_id.hex}
@@ -95,33 +177,35 @@ class CallableFunction(Function):
95
177
def _from_dict(cls, d: dict) -> Function:
    """Inverse of `_as_dict()`: an `'id'` entry denotes a stored (locally defined) UDF in the FunctionRegistry;
    anything else is delegated to `Function._from_dict()`."""
    if 'id' not in d:
        return super()._from_dict(d)
    from .function_registry import FunctionRegistry

    registry = FunctionRegistry.get()
    return registry.get_stored_function(UUID(hex=d['id']))
100
183
 
101
184
def to_store(self) -> tuple[dict, bytes]:
    """Serialize a stored (non-module) UDF as `(metadata, cloudpickled callable)`."""
    assert not self.is_polymorphic  # multi-signature UDFs not allowed for stored fns
    metadata = dict(signature=self.signature.as_dict(), batch_size=self.batch_size)
    payload = cloudpickle.dumps(self.py_fn)
    return metadata, payload
  return md, cloudpickle.dumps(self.py_fn)
107
188
 
108
189
  @classmethod
109
- def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
190
+ def from_store(cls, name: str | None, md: dict, binary_obj: bytes) -> Function:
110
191
  py_fn = cloudpickle.loads(binary_obj)
111
192
  assert callable(py_fn)
112
193
  sig = Signature.from_dict(md['signature'])
113
194
  batch_size = md['batch_size']
114
- return CallableFunction(sig, py_fn, self_name=name, batch_size=batch_size)
195
+ return CallableFunction([sig], [py_fn], self_name=name, batch_size=batch_size)
196
+
197
def validate_call(self, bound_args: dict[str, 'exprs.Expr']) -> None:
    """Validate the bound arguments of a call to this UDF.

    In addition to the base-class checks, for a batched UDF every constant parameter must be bound to an
    `exprs.Literal` or `exprs.Variable`.

    Raises:
        ValueError: if a constant parameter of a batched UDF is bound to a non-constant expression.
    """
    from pixeltable import exprs

    super().validate_call(bound_args)
    if not self.is_batched:
        return
    # A Variable is acceptable: the FunctionCall is then part of an unresolved template, and the check is
    # repeated once the template is fully resolved.
    for param in self.signatures[0].constant_parameters:
        if param.name not in bound_args:
            continue
        if not isinstance(bound_args[param.name], (exprs.Literal, exprs.Variable)):
            raise ValueError(f'{self.display_name}(): parameter {param.name} must be a constant value')
125
209
 
126
210
def __repr__(self) -> str:
    """Debug representation showing the UDF's name."""
    return '<Pixeltable UDF {}>'.format(self.name)