pixeltable 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (62)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +96 -19
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/dataframe.py +201 -3
  12. pixeltable/env.py +9 -3
  13. pixeltable/exec/expr_eval_node.py +1 -1
  14. pixeltable/exec/sql_node.py +2 -2
  15. pixeltable/exprs/function_call.py +134 -29
  16. pixeltable/exprs/inline_expr.py +22 -2
  17. pixeltable/exprs/row_builder.py +1 -1
  18. pixeltable/exprs/similarity_expr.py +9 -2
  19. pixeltable/func/__init__.py +1 -0
  20. pixeltable/func/aggregate_function.py +151 -68
  21. pixeltable/func/callable_function.py +50 -16
  22. pixeltable/func/expr_template_function.py +62 -24
  23. pixeltable/func/function.py +191 -23
  24. pixeltable/func/function_registry.py +2 -1
  25. pixeltable/func/query_template_function.py +11 -6
  26. pixeltable/func/signature.py +64 -7
  27. pixeltable/func/tools.py +116 -0
  28. pixeltable/func/udf.py +57 -35
  29. pixeltable/functions/__init__.py +2 -2
  30. pixeltable/functions/anthropic.py +36 -2
  31. pixeltable/functions/globals.py +54 -34
  32. pixeltable/functions/json.py +3 -8
  33. pixeltable/functions/math.py +67 -0
  34. pixeltable/functions/ollama.py +4 -4
  35. pixeltable/functions/openai.py +31 -2
  36. pixeltable/functions/timestamp.py +1 -1
  37. pixeltable/functions/video.py +2 -8
  38. pixeltable/functions/vision.py +1 -1
  39. pixeltable/globals.py +347 -79
  40. pixeltable/index/embedding_index.py +44 -24
  41. pixeltable/metadata/__init__.py +1 -1
  42. pixeltable/metadata/converters/convert_16.py +2 -1
  43. pixeltable/metadata/converters/convert_17.py +2 -1
  44. pixeltable/metadata/converters/convert_23.py +35 -0
  45. pixeltable/metadata/converters/convert_24.py +47 -0
  46. pixeltable/metadata/converters/util.py +4 -2
  47. pixeltable/metadata/notes.py +2 -0
  48. pixeltable/metadata/schema.py +1 -0
  49. pixeltable/type_system.py +192 -48
  50. {pixeltable-0.2.28.dist-info → pixeltable-0.2.30.dist-info}/METADATA +4 -2
  51. {pixeltable-0.2.28.dist-info → pixeltable-0.2.30.dist-info}/RECORD +54 -57
  52. pixeltable-0.2.30.dist-info/entry_points.txt +3 -0
  53. pixeltable/tool/create_test_db_dump.py +0 -311
  54. pixeltable/tool/create_test_video.py +0 -81
  55. pixeltable/tool/doc_plugins/griffe.py +0 -50
  56. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  57. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  58. pixeltable/tool/embed_udf.py +0 -9
  59. pixeltable/tool/mypy_plugin.py +0 -55
  60. pixeltable-0.2.28.dist-info/entry_points.txt +0 -3
  61. {pixeltable-0.2.28.dist-info → pixeltable-0.2.30.dist-info}/LICENSE +0 -0
  62. {pixeltable-0.2.28.dist-info → pixeltable-0.2.30.dist-info}/WHEEL +0 -0
pixeltable/dataframe.py CHANGED
@@ -300,6 +300,20 @@ class DataFrame:
  return self.limit(n).collect()

  def head(self, n: int = 10) -> DataFrameResultSet:
+ """Return the first n rows of the DataFrame, in insertion order of the underlying Table.
+
+ head() is not supported for joins.
+
+ Args:
+ n: Number of rows to select. Default is 10.
+
+ Returns:
+ A DataFrameResultSet with the first n rows of the DataFrame.
+
+ Raises:
+ Error: If the DataFrame is the result of a join or
+ if the DataFrame has an order_by clause.
+ """
  if self.order_by_clause is not None:
  raise excs.Error(f'head() cannot be used with order_by()')
  if self._has_joins():
@@ -309,6 +323,20 @@ class DataFrame:
  return self.order_by(*order_by_clause, asc=True).limit(n).collect()

  def tail(self, n: int = 10) -> DataFrameResultSet:
+ """Return the last n rows of the DataFrame, in insertion order of the underlying Table.
+
+ tail() is not supported for joins.
+
+ Args:
+ n: Number of rows to select. Default is 10.
+
+ Returns:
+ A DataFrameResultSet with the last n rows of the DataFrame.
+
+ Raises:
+ Error: If the DataFrame is the result of a join or
+ if the DataFrame has an order_by clause.
+ """
  if self.order_by_clause is not None:
  raise excs.Error(f'tail() cannot be used with order_by()')
  if self._has_joins():
@@ -394,6 +422,11 @@ class DataFrame:
  return DataFrameResultSet(list(self._output_row_iterator(conn)), self.schema)

  def count(self) -> int:
+ """Return the number of rows in the DataFrame.
+
+ Returns:
+ The number of rows in the DataFrame.
+ """
  from pixeltable.plan import Planner

  stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
@@ -463,6 +496,36 @@ class DataFrame:
  return self._descriptors().to_html()

  def select(self, *items: Any, **named_items: Any) -> DataFrame:
+ """ Select columns or expressions from the DataFrame.
+
+ Args:
+ items: expressions to be selected
+ named_items: named expressions to be selected
+
+ Returns:
+ A new DataFrame with the specified select list.
+
+ Raises:
+ Error: If the select list is already specified,
+ or if any of the specified expressions are invalid,
+ or refer to tables not in the DataFrame.
+
+ Examples:
+ Given the DataFrame person from a table t with all its columns and rows:
+
+ >>> person = t.select()
+
+ Select the columns 'name' and 'age' (referenced in table t) from the DataFrame person:
+
+ >>> df = person.select(t.name, t.age)
+
+ Select the column 'name' (referenced in table t) from the DataFrame person,
+ and a named column 'is_adult' from the expression `age >= 18` where 'age' is
+ another column in table t:
+
+ >>> df = person.select(t.name, is_adult=(t.age >= 18))
+
+ """
  if self.select_list is not None:
  raise excs.Error(f'Select list already specified')
  for name, _ in named_items.items():
@@ -512,6 +575,29 @@ class DataFrame:
  )

  def where(self, pred: exprs.Expr) -> DataFrame:
+ """Filter rows based on a predicate.
+
+ Args:
+ pred: the predicate to filter rows
+
+ Returns:
+ A new DataFrame with the specified predicates replacing the where-clause.
+
+ Raises:
+ Error: If the predicate is not a Pixeltable expression,
+ or if it does not return a boolean value,
+ or refers to tables not in the DataFrame.
+
+ Examples:
+ Given the DataFrame person from a table t with all its columns and rows:
+
+ >>> person = t.select()
+
+ Filter the above DataFrame person to only include rows where the column 'age'
+ (referenced in table t) is greater than 30:
+
+ >>> df = person.where(t.age > 30)
+ """
  if not isinstance(pred, exprs.Expr):
  raise excs.Error(f'Where() requires a Pixeltable expression, but instead got {type(pred)}')
  if not pred.col_type.is_bool_type():
@@ -662,11 +748,45 @@ class DataFrame:
  )

  def group_by(self, *grouping_items: Any) -> DataFrame:
- """
- Add a group-by clause to this DataFrame.
+ """ Add a group-by clause to this DataFrame.
+
  Variants:
  - group_by(<base table>): group a component view by their respective base table rows
  - group_by(<expr>, ...): group by the given expressions
+
+ Note that grouping is applied to the rows and takes effect when
+ used with an aggregation function such as sum() or count().
+
+ Args:
+ grouping_items: expressions to group by
+
+ Returns:
+ A new DataFrame with the specified group-by clause.
+
+ Raises:
+ Error: If the group-by clause is already specified,
+ or if the specified expression is invalid,
+ or refers to tables not in the DataFrame,
+ or if the DataFrame is a result of a join.
+
+ Examples:
+ Given the DataFrame book from a table t with all its columns and rows:
+
+ >>> book = t.select()
+
+ Group the above DataFrame book by the 'genre' column (referenced in table t):
+
+ >>> df = book.group_by(t.genre)
+
+ Use the above DataFrame df grouped by genre to count the number of
+ books for each 'genre':
+
+ >>> df = book.group_by(t.genre).select(t.genre, count=count(t.genre)).show()
+
+ Use the above DataFrame df grouped by genre to compute the total price of
+ books for each 'genre':
+
+ >>> df = book.group_by(t.genre).select(t.genre, total=sum(t.price)).show()
  """
  if self.group_by_clause is not None:
  raise excs.Error(f'Group-by already specified')
@@ -699,6 +819,35 @@ class DataFrame:
  )

  def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
+ """ Add an order-by clause to this DataFrame.
+
+ Args:
+ expr_list: expressions to order by
+ asc: whether to order in ascending order (True) or descending order (False).
+ Default is True.
+
+ Returns:
+ A new DataFrame with the specified order-by clause.
+
+ Raises:
+ Error: If the order-by clause is already specified,
+ or if the specified expression is invalid,
+ or refers to tables not in the DataFrame.
+
+ Examples:
+ Given the DataFrame book from a table t with all its columns and rows:
+
+ >>> book = t.select()
+
+ Order the above DataFrame book by two columns (price, pages) in descending order:
+
+ >>> df = book.order_by(t.price, t.pages, asc=False)
+
+ Order the above DataFrame book by price in descending order, but order the pages
+ in ascending order:
+
+ >>> df = book.order_by(t.price, asc=False).order_by(t.pages)
+ """
  for e in expr_list:
  if not isinstance(e, exprs.Expr):
  raise excs.Error(f'Invalid expression in order_by(): {e}')
@@ -715,6 +864,14 @@ class DataFrame:
  )

  def limit(self, n: int) -> DataFrame:
+ """ Limit the number of rows in the DataFrame.
+
+ Args:
+ n: Number of rows to select.
+
+ Returns:
+ A new DataFrame with the specified limited rows.
+ """
  # TODO: allow n to be a Variable that can be substituted in bind()
  assert n is not None and isinstance(n, int)
  return DataFrame(
@@ -728,17 +885,58 @@ class DataFrame:
  )

  def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
+ """ Update rows in the underlying table of the DataFrame.
+
+ Update rows in the table with the specified value_spec.
+
+ Args:
+ value_spec: a dict mapping the column names to update to their new values.
+ cascade: if True, also update all computed columns that transitively depend
+ on the updated columns, including within views. Default is True.
+
+ Returns:
+ UpdateStatus: the status of the update operation.
+
+ Example:
+ Given the DataFrame person from a table t with all its columns and rows:
+
+ >>> person = t.select()
+
+ Via the above DataFrame person, update the column 'city' to 'Oakland' and 'state' to 'CA' in the table t:
+
+ >>> df = person.update({'city': 'Oakland', 'state': 'CA'})
+
+ Via the above DataFrame person, update the column 'age' to 30 for any rows where 'year' is 2014 in the table t:
+
+ >>> df = person.where(t.year == 2014).update({'age': 30})
+ """
  self._validate_mutable('update')
  return self._first_tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)

  def delete(self) -> UpdateStatus:
+ """ Delete rows from the underlying table of the DataFrame.
+
+ The delete operation is only allowed for DataFrames on base tables.
+
+ Returns:
+ UpdateStatus: the status of the delete operation.
+
+ Example:
+ Given the DataFrame person from a table t with all its columns and rows:
+
+ >>> person = t.select()
+
+ Via the above DataFrame person, delete all rows from the table t where the column 'age' is less than 18:
+
+ >>> df = person.where(t.age < 18).delete()
+ """
  self._validate_mutable('delete')
  if not self._first_tbl.is_insertable():
  raise excs.Error(f'Cannot delete from view')
  return self._first_tbl.tbl_version.delete(where=self.where_clause)

  def _validate_mutable(self, op_name: str) -> None:
- """Tests whether this `DataFrame` can be mutated (such as by an update operation)."""
+ """Tests whether this DataFrame can be mutated (such as by an update operation)."""
  if self.group_by_clause is not None or self.grouping_tbl is not None:
  raise excs.Error(f'Cannot use `{op_name}` after `group_by`')
  if self.order_by_clause is not None:
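
Taken together, the new docstrings document a fluent query-builder API. A minimal usage sketch assembled from the docstring examples above (the table name 'books' and the count/sum imports are assumptions for illustration, not part of this diff):

>>> import pixeltable as pxt
>>> from pixeltable.functions import count, sum
>>> t = pxt.get_table('books')   # assumed table with 'genre' and 'price' columns
>>> book = t.select()
>>> book.where(t.price > 10).order_by(t.price, asc=False).limit(5).collect()
>>> book.group_by(t.genre).select(t.genre, n=count(t.genre), total=sum(t.price)).show()
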
pixeltable/env.py CHANGED
@@ -8,6 +8,7 @@ import importlib.util
  import inspect
  import logging
  import os
+ import platform
  import shutil
  import subprocess
  import sys
@@ -311,8 +312,12 @@ class Env:
  self._db_name = os.environ.get('PIXELTABLE_DB', 'pixeltable')
  self._pgdata_dir = Path(os.environ.get('PIXELTABLE_PGDATA', str(self._home / 'pgdata')))

- # in pixeltable_pgserver.get_server(): cleanup_mode=None will leave db on for debugging purposes
- self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
+ # cleanup_mode=None will leave the postgres process running after Python exits
+ # cleanup_mode='stop' will terminate the postgres process when Python exits
+ # On Windows, we need cleanup_mode='stop' because child processes are killed automatically when the parent
+ # process (such as Terminal or VSCode) exits, potentially leaving it in an unusable state.
+ cleanup_mode = 'stop' if platform.system() == 'Windows' else None
+ self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=cleanup_mode)
  self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')

  tz_name = self.config.get_string_value('time_zone')
@@ -357,7 +362,7 @@ class Env:
  self.db_url,
  echo=echo,
  future=True,
- isolation_level='AUTOCOMMIT',
+ isolation_level='REPEATABLE READ',
  connect_args=connect_args,
  )
  self._logger.info(f'Created SQLAlchemy engine at: {self.db_url}')
@@ -506,6 +511,7 @@ class Env:
  self.__register_package('openai')
  self.__register_package('openpyxl')
  self.__register_package('pyarrow')
+ self.__register_package('pydantic')
  self.__register_package('replicate')
  self.__register_package('sentencepiece')
  self.__register_package('sentence_transformers', library_name='sentence-transformers')
pixeltable/exec/expr_eval_node.py CHANGED
@@ -208,7 +208,7 @@ class ExprEvalNode(ExecNode):
  }
  start_ts = time.perf_counter()
  assert isinstance(fn_call.fn, CallableFunction)
- result_batch = fn_call.fn.exec_batch(*call_args, **call_kwargs)
+ result_batch = fn_call.fn.exec_batch(call_args, call_kwargs)
  self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
  self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows

pixeltable/exec/sql_node.py CHANGED
@@ -262,7 +262,7 @@ class SqlNode(ExecNode):
  explain_str = '\n'.join([str(row) for row in explain_result])
  _logger.debug(f'SqlScanNode explain:\n{explain_str}')
  except Exception as e:
- _logger.warning(f'EXPLAIN failed')
+ _logger.warning(f'EXPLAIN failed with error: {e}')

  def __iter__(self) -> Iterator[DataRowBatch]:
  # run the query; do this here rather than in _open(), exceptions are only expected during iteration
@@ -468,4 +468,4 @@ class SqlJoinNode(SqlNode):
  stmt = stmt.join(
  self.input_ctes[i + 1], onclause=on_clause, isouter=is_outer,
  full=join_clause == plan.JoinType.FULL_OUTER)
- return stmt
+ return stmt
pixeltable/exprs/function_call.py CHANGED
@@ -15,6 +15,7 @@ import pixeltable.type_system as ts
  from .data_row import DataRow
  from .expr import Expr
  from .inline_expr import InlineDict, InlineList
+ from .literal import Literal
  from .row_builder import RowBuilder
  from .rowid_ref import RowidRef
  from .sql_element_cache import SqlElementCache
@@ -34,6 +35,7 @@ class FunctionCall(Expr):

  arg_types: list[ts.ColumnType]
  kwarg_types: dict[str, ts.ColumnType]
+ return_type: ts.ColumnType
  group_by_start_idx: int
  group_by_stop_idx: int
  fn_expr_idx: int
@@ -43,17 +45,25 @@ class FunctionCall(Expr):
  current_partition_vals: Optional[list[Any]]

  def __init__(
- self, fn: func.Function, bound_args: dict[str, Any], order_by_clause: Optional[list[Any]] = None,
- group_by_clause: Optional[list[Any]] = None, is_method_call: bool = False):
+ self,
+ fn: func.Function,
+ bound_args: dict[str, Any],
+ return_type: ts.ColumnType,
+ order_by_clause: Optional[list[Any]] = None,
+ group_by_clause: Optional[list[Any]] = None,
+ is_method_call: bool = False
+ ):
  if order_by_clause is None:
  order_by_clause = []
  if group_by_clause is None:
  group_by_clause = []
- signature = fn.signature
- return_type = fn.call_return_type(bound_args)
+
+ assert not fn.is_polymorphic
+
  self.fn = fn
  self.is_method_call = is_method_call
- self.normalize_args(fn.name, signature, bound_args)
+
+ signature = fn.signature

  # If `return_type` is non-nullable, but the function call has a nullable input to any of its non-nullable
  # parameters, then we need to make it nullable. This is because Pixeltable defaults a function output to
@@ -67,6 +77,8 @@ class FunctionCall(Expr):
  return_type = return_type.copy(nullable=True)
  break

+ self.return_type = return_type
+
  super().__init__(return_type)

  self.agg_init_args = {}
@@ -74,9 +86,9 @@ class FunctionCall(Expr):
  # we separate out the init args for the aggregator
  assert isinstance(fn, func.AggregateFunction)
  self.agg_init_args = {
- arg_name: arg for arg_name, arg in bound_args.items() if arg_name in fn.init_param_names
+ arg_name: arg for arg_name, arg in bound_args.items() if arg_name in fn.init_param_names[0]
  }
- bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names}
+ bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names[0]}

  # construct components, args, kwargs
  self.args = []
@@ -88,7 +100,7 @@ class FunctionCall(Expr):

  # the prefix of parameters that are bound can be passed by position
  processed_args: set[str] = set()
- for py_param in fn.signature.py_signature.parameters.values():
+ for py_param in signature.py_signature.parameters.values():
  if py_param.name not in bound_args or py_param.kind == inspect.Parameter.KEYWORD_ONLY:
  break
  arg = bound_args[py_param.name]
@@ -110,7 +122,7 @@ class FunctionCall(Expr):
  self.components.append(arg.copy())
  else:
  self.kwargs[param_name] = (None, arg)
- if fn.signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
+ if signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
  self.kwarg_types[param_name] = signature.parameters[param_name].col_type

  # window function state:
@@ -129,7 +141,7 @@ class FunctionCall(Expr):

  if isinstance(self.fn, func.ExprTemplateFunction):
  # we instantiate the template to create an Expr that can be evaluated and record that as a component
- fn_expr = self.fn.instantiate(**bound_args)
+ fn_expr = self.fn.instantiate([], bound_args)
  self.components.append(fn_expr)
  self.fn_expr_idx = len(self.components) - 1
  else:
@@ -187,11 +199,6 @@ class FunctionCall(Expr):
  pass

  if not isinstance(arg, Expr):
- # make sure that non-Expr args are json-serializable and are literals of the correct type
- try:
- _ = json.dumps(arg)
- except TypeError:
- raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg} (of type {type(arg)})')
  if arg is not None:
  try:
  param_type = param.col_type
@@ -360,7 +367,7 @@ class FunctionCall(Expr):
  """
  assert self.is_agg_fn_call
  assert isinstance(self.fn, func.AggregateFunction)
- self.aggregator = self.fn.agg_cls(**self.agg_init_args)
+ self.aggregator = self.fn.agg_class(**self.agg_init_args)

  def update(self, data_row: DataRow) -> None:
  """
@@ -432,27 +439,32 @@ class FunctionCall(Expr):
  data_row[self.slot_idx] = self.fn.py_fn(*args, **kwargs)
  elif self.is_window_fn_call:
  assert isinstance(self.fn, func.AggregateFunction)
+ agg_cls = self.fn.agg_class
  if self.has_group_by():
  if self.current_partition_vals is None:
  self.current_partition_vals = [None] * len(self.group_by)
  partition_vals = [data_row[e.slot_idx] for e in self.group_by]
  if partition_vals != self.current_partition_vals:
  # new partition
- self.aggregator = self.fn.agg_cls(**self.agg_init_args)
+ self.aggregator = agg_cls(**self.agg_init_args)
  self.current_partition_vals = partition_vals
  elif self.aggregator is None:
- self.aggregator = self.fn.agg_cls(**self.agg_init_args)
+ self.aggregator = agg_cls(**self.agg_init_args)
  self.aggregator.update(*args)
  data_row[self.slot_idx] = self.aggregator.value()
  else:
- data_row[self.slot_idx] = self.fn.exec(*args, **kwargs)
+ data_row[self.slot_idx] = self.fn.exec(args, kwargs)

  def _as_dict(self) -> dict:
  result = {
- 'fn': self.fn.as_dict(), 'args': self.args, 'kwargs': self.kwargs,
- 'group_by_start_idx': self.group_by_start_idx, 'group_by_stop_idx': self.group_by_stop_idx,
+ 'fn': self.fn.as_dict(),
+ 'args': self.args,
+ 'kwargs': self.kwargs,
+ 'return_type': self.return_type.as_dict(),
+ 'group_by_start_idx': self.group_by_start_idx,
+ 'group_by_stop_idx': self.group_by_stop_idx,
  'order_by_start_idx': self.order_by_start_idx,
- **super()._as_dict()
+ **super()._as_dict(),
  }
  return result

@@ -461,15 +473,108 @@ class FunctionCall(Expr):
  assert 'fn' in d
  assert 'args' in d
  assert 'kwargs' in d
- # reassemble bound args
+
  fn = func.Function.from_dict(d['fn'])
- param_names = list(fn.signature.parameters.keys())
- bound_args = {param_names[i]: arg if idx is None else components[idx] for i, (idx, arg) in enumerate(d['args'])}
- bound_args.update(
- {param_name: val if idx is None else components[idx] for param_name, (idx, val) in d['kwargs'].items()})
+ assert not fn.is_polymorphic
+ return_type = ts.ColumnType.from_dict(d['return_type']) if 'return_type' in d else None
  group_by_exprs = components[d['group_by_start_idx']:d['group_by_stop_idx']]
  order_by_exprs = components[d['order_by_start_idx']:]
+
+ args = [
+ expr if idx is None else components[idx]
+ for idx, expr in d['args']
+ ]
+ kwargs = {
+ param_name: (expr if idx is None else components[idx])
+ for param_name, (idx, expr) in d['kwargs'].items()
+ }
+
+ # `Function.from_dict()` does signature matching, so it is safe to assume that `args` and `kwargs` are
+ # consistent with its signature.
+
+ # Reassemble bound_args. Note that args and kwargs represent "already bound arguments": they are not bindable
+ # in the Python sense, because variable args (such as *args and **kwargs) have already been condensed.
+ param_names = list(fn.signature.parameters.keys())
+ bound_args = {param_names[i]: arg for i, arg in enumerate(args)}
+ bound_args.update(kwargs.items())
+
+ # TODO: In order to properly invoke call_return_type, we need to ensure that any InlineLists or InlineDicts
+ # in bound_args are unpacked into Python lists/dicts. There is an open task to ensure this is true in general;
+ # for now, as a hack, we do the unpacking here for the specific case of an InlineList of Literals (the only
+ # case where this is necessary to support existing conditional_return_type implementations). Once the general
+ # pattern is implemented, we can remove this hack.
+ unpacked_bound_args = {
+ param_name: cls.__unpack_bound_arg(arg) for param_name, arg in bound_args.items()
+ }
+
+ # Evaluate the call_return_type as defined in the current codebase.
+ call_return_type = fn.call_return_type([], unpacked_bound_args)
+
+ if return_type is None:
+ # Schema versions prior to 25 did not store the return_type in metadata, and there is no obvious way to
+ # infer it during DB migration, so we might encounter a stored return_type of None. In that case, we use
+ # the call_return_type that we just inferred (which matches the deserialization behavior prior to
+ # version 25).
+ return_type = call_return_type
+ else:
+ # There is a return_type stored in metadata (schema version >= 25).
+ # Check that the stored return_type of the UDF call matches the column type of the FunctionCall, and
+ # fail-fast if it doesn't (otherwise we risk getting downstream database errors).
+ # TODO: Handle this more gracefully (instead of failing the DB load, allow the DB load to succeed, but
+ # mark this FunctionCall as unusable). It's the same issue as dealing with a renamed UDF or Function
+ # signature mismatch.
+ if not return_type.is_supertype_of(call_return_type, ignore_nullable=True):
+ raise excs.Error(
+ f'The return type stored in the database for a UDF call to `{fn.self_path}` no longer matches the '
+ f'return type of the UDF as currently defined in the code.\nThis probably means that the code for '
+ f'`{fn.self_path}` has changed in a backward-incompatible way.\n'
+ f'Return type in database: `{return_type}`\n'
+ f'Return type as currently defined: `{call_return_type}`'
+ )
+
  fn_call = cls(
- func.Function.from_dict(d['fn']), bound_args, group_by_clause=group_by_exprs,
- order_by_clause=order_by_exprs)
+ fn,
+ bound_args,
+ return_type,
+ group_by_clause=group_by_exprs,
+ order_by_clause=order_by_exprs
+ )
  return fn_call
+
+ @classmethod
+ def __find_matching_signature(cls, fn: func.Function, args: list[Any], kwargs: dict[str, Any]) -> Optional[int]:
+ for idx, sig in enumerate(fn.signatures):
+ if cls.__signature_matches(sig, args, kwargs):
+ return idx
+ return None
+
+ @classmethod
+ def __signature_matches(cls, sig: func.Signature, args: list[Any], kwargs: dict[str, Any]) -> bool:
+ unbound_parameters = set(sig.parameters.keys())
+ for i, arg in enumerate(args):
+ if i >= len(sig.parameters_by_pos):
+ return False
+ param = sig.parameters_by_pos[i]
+ arg_type = arg.col_type if isinstance(arg, Expr) else ts.ColumnType.infer_literal_type(arg)
+ if param.col_type is not None and not param.col_type.is_supertype_of(arg_type, ignore_nullable=True):
+ return False
+ unbound_parameters.remove(param.name)
+ for param_name, arg in kwargs.items():
+ if param_name not in unbound_parameters:
+ return False
+ param = sig.parameters[param_name]
+ arg_type = arg.col_type if isinstance(arg, Expr) else ts.ColumnType.infer_literal_type(arg)
+ if param.col_type is not None and not param.col_type.is_supertype_of(arg_type, ignore_nullable=True):
+ return False
+ unbound_parameters.remove(param_name)
+ for param_name in unbound_parameters:
+ param = sig.parameters[param_name]
+ if not param.has_default:
+ return False
+ return True
+
+ @classmethod
+ def __unpack_bound_arg(cls, arg: Any) -> Any:
+ if isinstance(arg, InlineList) and all(isinstance(el, Literal) for el in arg.components):
+ return [el.val for el in arg.components]
+ return arg
pixeltable/exprs/inline_expr.py CHANGED
@@ -101,7 +101,13 @@ class InlineList(Expr):
  else:
  exprs.append(Literal(el))

- super().__init__(ts.JsonType())
+ json_schema = {
+ 'type': 'array',
+ 'prefixItems': [expr.col_type.to_json_schema() for expr in exprs],
+ 'items': False  # No additional items (fixed length)
+ }
+
+ super().__init__(ts.JsonType(json_schema))
  self.components.extend(exprs)
  self.id = self._create_id()

@@ -149,7 +155,21 @@ class InlineDict(Expr):
  else:
  exprs.append(Literal(val))

- super().__init__(ts.JsonType())
+ json_schema: Optional[dict[str, Any]]
+ try:
+ json_schema = {
+ 'type': 'object',
+ 'properties': {
+ key: expr.col_type.to_json_schema()
+ for key, expr in zip(self.keys, exprs)
+ },
+ }
+ except excs.Error:
+ # InlineDicts are used to store iterator arguments, which are not required to be valid JSON types,
+ # so we can't always construct a valid schema.
+ json_schema = None
+
+ super().__init__(ts.JsonType(json_schema))
  self.components.extend(exprs)
  self.id = self._create_id()

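As a rough illustration (not part of the diff): for an inline list whose two elements resolve to an int expression and a string literal, the json_schema built in InlineList.__init__ above would come out along these lines, assuming ColumnType.to_json_schema() emits standard JSON Schema type names:

# Hypothetical value of json_schema for a two-element inline list (int, string):
json_schema = {
    'type': 'array',
    'prefixItems': [{'type': 'integer'}, {'type': 'string'}],
    'items': False,  # fixed length: no elements beyond the two prefixItems
}
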
pixeltable/exprs/row_builder.py CHANGED
@@ -368,7 +368,7 @@ class RowBuilder:
  if not ignore_errors:
  input_vals = [data_row[d.slot_idx] for d in expr.dependencies()]
  raise excs.ExprEvalError(
- expr, f'expression {expr}', data_row.get_exc(expr.slot_idx), exc_tb, input_vals, 0)
+ expr, f'expression {expr}', data_row.get_exc(expr.slot_idx), exc_tb, input_vals, 0) from exc

  def create_table_row(self, data_row: DataRow, exc_col_ids: set[int]) -> tuple[dict[str, Any], int]:
  """Create a table row from the slots that have an output column assigned