pixeltable 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (94)
  1. pixeltable/__init__.py +5 -3
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -0
  4. pixeltable/catalog/catalog.py +335 -128
  5. pixeltable/catalog/column.py +21 -5
  6. pixeltable/catalog/dir.py +19 -6
  7. pixeltable/catalog/insertable_table.py +34 -37
  8. pixeltable/catalog/named_function.py +0 -4
  9. pixeltable/catalog/schema_object.py +28 -42
  10. pixeltable/catalog/table.py +195 -158
  11. pixeltable/catalog/table_version.py +187 -232
  12. pixeltable/catalog/table_version_handle.py +50 -0
  13. pixeltable/catalog/table_version_path.py +49 -33
  14. pixeltable/catalog/view.py +56 -96
  15. pixeltable/config.py +103 -0
  16. pixeltable/dataframe.py +90 -90
  17. pixeltable/env.py +98 -168
  18. pixeltable/exec/aggregation_node.py +5 -4
  19. pixeltable/exec/cache_prefetch_node.py +1 -1
  20. pixeltable/exec/component_iteration_node.py +13 -9
  21. pixeltable/exec/data_row_batch.py +3 -3
  22. pixeltable/exec/exec_context.py +0 -4
  23. pixeltable/exec/exec_node.py +3 -2
  24. pixeltable/exec/expr_eval/schedulers.py +2 -1
  25. pixeltable/exec/in_memory_data_node.py +9 -4
  26. pixeltable/exec/row_update_node.py +1 -2
  27. pixeltable/exec/sql_node.py +20 -16
  28. pixeltable/exprs/column_ref.py +9 -9
  29. pixeltable/exprs/comparison.py +1 -1
  30. pixeltable/exprs/data_row.py +4 -4
  31. pixeltable/exprs/expr.py +20 -5
  32. pixeltable/exprs/function_call.py +98 -58
  33. pixeltable/exprs/json_mapper.py +25 -8
  34. pixeltable/exprs/json_path.py +6 -5
  35. pixeltable/exprs/object_ref.py +16 -5
  36. pixeltable/exprs/row_builder.py +15 -15
  37. pixeltable/exprs/rowid_ref.py +21 -7
  38. pixeltable/func/__init__.py +1 -1
  39. pixeltable/func/function.py +38 -6
  40. pixeltable/func/query_template_function.py +3 -6
  41. pixeltable/func/tools.py +26 -26
  42. pixeltable/func/udf.py +1 -1
  43. pixeltable/functions/__init__.py +2 -0
  44. pixeltable/functions/anthropic.py +9 -3
  45. pixeltable/functions/fireworks.py +7 -4
  46. pixeltable/functions/globals.py +4 -5
  47. pixeltable/functions/huggingface.py +1 -5
  48. pixeltable/functions/image.py +17 -7
  49. pixeltable/functions/llama_cpp.py +1 -1
  50. pixeltable/functions/mistralai.py +1 -1
  51. pixeltable/functions/ollama.py +4 -4
  52. pixeltable/functions/openai.py +26 -23
  53. pixeltable/functions/string.py +23 -30
  54. pixeltable/functions/timestamp.py +11 -6
  55. pixeltable/functions/together.py +14 -12
  56. pixeltable/functions/util.py +1 -1
  57. pixeltable/functions/video.py +5 -4
  58. pixeltable/functions/vision.py +6 -9
  59. pixeltable/functions/whisper.py +3 -3
  60. pixeltable/globals.py +246 -260
  61. pixeltable/index/__init__.py +2 -0
  62. pixeltable/index/base.py +1 -1
  63. pixeltable/index/btree.py +3 -1
  64. pixeltable/index/embedding_index.py +11 -5
  65. pixeltable/io/external_store.py +11 -12
  66. pixeltable/io/label_studio.py +4 -3
  67. pixeltable/io/parquet.py +57 -56
  68. pixeltable/iterators/__init__.py +4 -2
  69. pixeltable/iterators/audio.py +11 -11
  70. pixeltable/iterators/document.py +10 -10
  71. pixeltable/iterators/string.py +1 -2
  72. pixeltable/iterators/video.py +14 -15
  73. pixeltable/metadata/__init__.py +9 -5
  74. pixeltable/metadata/converters/convert_10.py +0 -1
  75. pixeltable/metadata/converters/convert_15.py +0 -2
  76. pixeltable/metadata/converters/convert_23.py +0 -2
  77. pixeltable/metadata/converters/convert_24.py +3 -3
  78. pixeltable/metadata/converters/convert_25.py +1 -1
  79. pixeltable/metadata/converters/convert_27.py +0 -2
  80. pixeltable/metadata/converters/convert_28.py +0 -2
  81. pixeltable/metadata/converters/convert_29.py +7 -8
  82. pixeltable/metadata/converters/util.py +7 -7
  83. pixeltable/metadata/schema.py +27 -19
  84. pixeltable/plan.py +68 -40
  85. pixeltable/share/packager.py +12 -9
  86. pixeltable/store.py +37 -38
  87. pixeltable/type_system.py +41 -28
  88. pixeltable/utils/filecache.py +2 -1
  89. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/METADATA +1 -1
  90. pixeltable-0.3.7.dist-info/RECORD +174 -0
  91. pixeltable-0.3.5.dist-info/RECORD +0 -172
  92. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/LICENSE +0 -0
  93. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/WHEEL +0 -0
  94. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/entry_points.txt +0 -0
@@ -20,7 +20,8 @@ class InMemoryDataNode(ExecNode):
20
20
  - if an input row doesn't provide a value, sets the slot to the column default
21
21
  """
22
22
 
23
- tbl: catalog.TableVersion
23
+ tbl: catalog.TableVersionHandle
24
+
24
25
  input_rows: list[dict[str, Any]]
25
26
  start_row_id: int
26
27
  output_rows: Optional[DataRowBatch]
@@ -29,12 +30,16 @@ class InMemoryDataNode(ExecNode):
29
30
  output_exprs: list[exprs.ColumnRef]
30
31
 
31
32
  def __init__(
32
- self, tbl: catalog.TableVersion, rows: list[dict[str, Any]], row_builder: exprs.RowBuilder, start_row_id: int
33
+ self,
34
+ tbl: catalog.TableVersionHandle,
35
+ rows: list[dict[str, Any]],
36
+ row_builder: exprs.RowBuilder,
37
+ start_row_id: int,
33
38
  ):
34
39
  # we materialize the input slots
35
40
  output_exprs = list(row_builder.input_exprs)
36
41
  super().__init__(row_builder, output_exprs, [], None)
37
- assert tbl.is_insertable()
42
+ assert tbl.get().is_insertable()
38
43
  self.tbl = tbl
39
44
  self.input_rows = rows
40
45
  self.start_row_id = start_row_id
@@ -62,7 +67,7 @@ class InMemoryDataNode(ExecNode):
62
67
 
63
68
  if col_info.col.col_type.is_image_type() and isinstance(val, bytes):
64
69
  # this is a literal image, ie, a sequence of bytes; we save this as a media file and store the path
65
- path = str(MediaStore.prepare_media_path(self.tbl.id, col_info.col.id, self.tbl.version))
70
+ path = str(MediaStore.prepare_media_path(self.tbl.id, col_info.col.id, self.tbl.get().version))
66
71
  open(path, 'wb').write(val)
67
72
  val = path
68
73
  self.output_rows[row_idx][col_info.slot_idx] = val
@@ -3,7 +3,6 @@ from typing import Any, AsyncIterator
3
3
 
4
4
  import pixeltable.catalog as catalog
5
5
  import pixeltable.exprs as exprs
6
- from pixeltable.utils.media_store import MediaStore
7
6
 
8
7
  from .data_row_batch import DataRowBatch
9
8
  from .exec_node import ExecNode
@@ -40,7 +39,7 @@ class RowUpdateNode(ExecNode):
40
39
  if isinstance(col_ref, exprs.ColumnRef)
41
40
  }
42
41
  self.col_slot_idxs = {col: all_col_slot_idxs[col] for col in col_vals_batch[0].keys()}
43
- self.key_slot_idxs = {col: all_col_slot_idxs[col] for col in tbl.tbl_version.primary_key_columns()}
42
+ self.key_slot_idxs = {col: all_col_slot_idxs[col] for col in tbl.tbl_version.get().primary_key_columns()}
44
43
  self.matched_key_vals: set[tuple] = set()
45
44
 
46
45
  async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
@@ -1,13 +1,14 @@
1
1
  import logging
2
2
  import warnings
3
3
  from decimal import Decimal
4
- from typing import TYPE_CHECKING, AsyncIterator, Iterable, Iterator, NamedTuple, Optional, Sequence
4
+ from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Optional, Sequence
5
5
  from uuid import UUID
6
6
 
7
7
  import sqlalchemy as sql
8
8
 
9
9
  import pixeltable.catalog as catalog
10
10
  import pixeltable.exprs as exprs
11
+ from pixeltable.env import Env
11
12
 
12
13
  from .data_row_batch import DataRowBatch
13
14
  from .exec_node import ExecNode
@@ -122,7 +123,7 @@ class SqlNode(ExecNode):
122
123
  if set_pk:
123
124
  # we also need to retrieve the pk columns
124
125
  assert tbl is not None
125
- self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
126
+ self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
126
127
 
127
128
  # additional state
128
129
  self.result_cursor = None
@@ -142,7 +143,7 @@ class SqlNode(ExecNode):
142
143
  sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
143
144
  if self.set_pk:
144
145
  assert self.tbl is not None
145
- sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
146
+ sql_select_list += self.tbl.tbl_version.get().store_tbl.pk_columns()
146
147
  stmt = sql.select(*sql_select_list)
147
148
 
148
149
  where_clause_element = (
@@ -215,29 +216,31 @@ class SqlNode(ExecNode):
215
216
  exact_version_only = set()
216
217
  candidates = tbl.get_tbl_versions()
217
218
  assert len(candidates) > 0
218
- joined_tbls: list[catalog.TableVersion] = [candidates[0]]
219
+ joined_tbls: list[catalog.TableVersionHandle] = [candidates[0]]
219
220
  for tbl in candidates[1:]:
220
221
  if tbl.id in refd_tbl_ids:
221
222
  joined_tbls.append(tbl)
222
223
 
223
224
  first = True
224
- prev_tbl: catalog.TableVersion
225
+ prev_tbl: catalog.TableVersionHandle
225
226
  for tbl in joined_tbls[::-1]:
226
227
  if first:
227
- stmt = stmt.select_from(tbl.store_tbl.sa_tbl)
228
+ stmt = stmt.select_from(tbl.get().store_tbl.sa_tbl)
228
229
  first = False
229
230
  else:
230
231
  # join tbl to prev_tbl on prev_tbl's rowid cols
231
- prev_tbl_rowid_cols = prev_tbl.store_tbl.rowid_columns()
232
- tbl_rowid_cols = tbl.store_tbl.rowid_columns()
232
+ prev_tbl_rowid_cols = prev_tbl.get().store_tbl.rowid_columns()
233
+ tbl_rowid_cols = tbl.get().store_tbl.rowid_columns()
233
234
  rowid_clauses = [
234
235
  c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
235
236
  ]
236
- stmt = stmt.join(tbl.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
237
+ stmt = stmt.join(tbl.get().store_tbl.sa_tbl, sql.and_(*rowid_clauses))
237
238
  if tbl.id in exact_version_only:
238
- stmt = stmt.where(tbl.store_tbl.v_min_col == tbl.version)
239
+ stmt = stmt.where(tbl.get().store_tbl.v_min_col == tbl.get().version)
239
240
  else:
240
- stmt = stmt.where(tbl.store_tbl.v_min_col <= tbl.version).where(tbl.store_tbl.v_max_col > tbl.version)
241
+ stmt = stmt.where(tbl.get().store_tbl.v_min_col <= tbl.get().version).where(
242
+ tbl.get().store_tbl.v_max_col > tbl.get().version
243
+ )
241
244
  prev_tbl = tbl
242
245
  return stmt
243
246
 
@@ -264,10 +267,11 @@ class SqlNode(ExecNode):
264
267
  self.limit = limit
265
268
 
266
269
  def _log_explain(self, stmt: sql.Select) -> None:
270
+ conn = Env.get().conn
267
271
  try:
268
272
  # don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
269
273
  stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
270
- explain_result = self.ctx.conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
274
+ explain_result = conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
271
275
  explain_str = '\n'.join([str(row) for row in explain_result])
272
276
  _logger.debug(f'SqlScanNode explain:\n{explain_str}')
273
277
  except Exception as e:
@@ -275,7 +279,6 @@ class SqlNode(ExecNode):
275
279
 
276
280
  async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
277
281
  # run the query; do this here rather than in _open(), exceptions are only expected during iteration
278
- assert self.ctx.conn is not None
279
282
  with warnings.catch_warnings(record=True) as w:
280
283
  stmt = self._create_stmt()
281
284
  try:
@@ -286,7 +289,8 @@ class SqlNode(ExecNode):
286
289
  pass
287
290
  self._log_explain(stmt)
288
291
 
289
- result_cursor = self.ctx.conn.execute(stmt)
292
+ conn = Env.get().conn
293
+ result_cursor = conn.execute(stmt)
290
294
  for warning in w:
291
295
  pass
292
296
 
@@ -351,7 +355,7 @@ class SqlScanNode(SqlNode):
351
355
  Supports filtering and ordering.
352
356
  """
353
357
 
354
- exact_version_only: list[catalog.TableVersion]
358
+ exact_version_only: list[catalog.TableVersionHandle]
355
359
 
356
360
  def __init__(
357
361
  self,
@@ -359,7 +363,7 @@ class SqlScanNode(SqlNode):
359
363
  row_builder: exprs.RowBuilder,
360
364
  select_list: Iterable[exprs.Expr],
361
365
  set_pk: bool = False,
362
- exact_version_only: Optional[list[catalog.TableVersion]] = None,
366
+ exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
363
367
  ):
364
368
  """
365
369
  Args:
@@ -52,15 +52,15 @@ class ColumnRef(Expr):
52
52
  assert col.tbl is not None
53
53
  self.col = col
54
54
  self.is_unstored_iter_col = (
55
- col.tbl.is_component_view() and col.tbl.is_iterator_column(col) and not col.is_stored
55
+ col.tbl.get().is_component_view and col.tbl.get().is_iterator_column(col) and not col.is_stored
56
56
  )
57
57
  self.iter_arg_ctx = None
58
58
  # number of rowid columns in the base table
59
- self.base_rowid_len = col.tbl.base.num_rowid_columns() if self.is_unstored_iter_col else 0
59
+ self.base_rowid_len = col.tbl.get().base.get().num_rowid_columns() if self.is_unstored_iter_col else 0
60
60
  self.base_rowid = [None] * self.base_rowid_len
61
61
  self.iterator = None
62
62
  # index of the position column in the view's primary key; don't try to reference tbl.store_tbl here
63
- self.pos_idx = col.tbl.num_rowid_columns() - 1 if self.is_unstored_iter_col else None
63
+ self.pos_idx = col.tbl.get().num_rowid_columns() - 1 if self.is_unstored_iter_col else None
64
64
 
65
65
  self.perform_validation = False
66
66
  if col.col_type.is_media_type():
@@ -138,7 +138,7 @@ class ColumnRef(Expr):
138
138
  return self.col == other.col and self.perform_validation == other.perform_validation
139
139
 
140
140
  def _df(self) -> 'pxt.dataframe.DataFrame':
141
- tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
141
+ tbl = catalog.Catalog.get().get_tbl(self.col.tbl.id)
142
142
  return tbl.select(self)
143
143
 
144
144
  def show(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
@@ -166,9 +166,9 @@ class ColumnRef(Expr):
166
166
  return self._descriptors().to_html()
167
167
 
168
168
  def _descriptors(self) -> DescriptionHelper:
169
- tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
169
+ tbl = catalog.Catalog.get().get_tbl(self.col.tbl.id)
170
170
  helper = DescriptionHelper()
171
- helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
171
+ helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path()!r})')
172
172
  helper.append(tbl._col_descriptor([self.col.name]))
173
173
  idxs = tbl._index_descriptor([self.col.name])
174
174
  if len(idxs) > 0:
@@ -217,7 +217,7 @@ class ColumnRef(Expr):
217
217
  if self.base_rowid != data_row.pk[: self.base_rowid_len]:
218
218
  row_builder.eval(data_row, self.iter_arg_ctx)
219
219
  iterator_args = data_row[self.iter_arg_ctx.target_slot_idxs[0]]
220
- self.iterator = self.col.tbl.iterator_cls(**iterator_args)
220
+ self.iterator = self.col.tbl.get().iterator_cls(**iterator_args)
221
221
  self.base_rowid = data_row.pk[: self.base_rowid_len]
222
222
  self.iterator.set_pos(data_row.pk[self.pos_idx])
223
223
  res = next(self.iterator)
@@ -225,7 +225,7 @@ class ColumnRef(Expr):
225
225
 
226
226
  def _as_dict(self) -> dict:
227
227
  tbl = self.col.tbl
228
- version = tbl.version if tbl.is_snapshot else None
228
+ version = tbl.get().version if tbl.get().is_snapshot else None
229
229
  # we omit self.components, even if this is a validating ColumnRef, because init() will recreate the
230
230
  # non-validating component ColumnRef
231
231
  return {
@@ -238,7 +238,7 @@ class ColumnRef(Expr):
238
238
  @classmethod
239
239
  def get_column(cls, d: dict) -> catalog.Column:
240
240
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
241
- tbl_version = catalog.Catalog.get().tbl_versions[(tbl_id, version)]
241
+ tbl_version = catalog.Catalog.get().get_tbl_version(tbl_id, version)
242
242
  # don't use tbl_version.cols_by_id here, this might be a snapshot reference to a column that was then dropped
243
243
  col = next(col for col in tbl_version.cols if col.id == col_id)
244
244
  return col
@@ -84,7 +84,7 @@ class Comparison(Expr):
84
84
  if self.is_search_arg_comparison:
85
85
  # reference the index value column if there is an index and this is not a snapshot
86
86
  # (indices don't apply to snapshots)
87
- tbl = self._op1.col.tbl
87
+ tbl = self._op1.col.tbl.get()
88
88
  idx_info = [
89
89
  info for info in self._op1.col.get_idx_info().values() if isinstance(info.idx, index.BtreeIndex)
90
90
  ]
@@ -142,13 +142,13 @@ class DataRow:
142
142
  self.file_paths[slot_idx] = None
143
143
  self.file_urls[slot_idx] = None
144
144
 
145
- def __getitem__(self, index: object) -> Any:
145
+ def __getitem__(self, index: int) -> Any:
146
146
  """Returns in-memory value, ie, what is needed for expr evaluation"""
147
147
  assert isinstance(index, int)
148
148
  if not self.has_val[index]:
149
- # for debugging purposes
150
- pass
151
- assert self.has_val[index], index
149
+ # This is a sufficiently cheap and sensitive validation that it makes sense to keep the assertion around
150
+ # even if python is running with -O.
151
+ raise AssertionError(index)
152
152
 
153
153
  if self.file_urls[index] is not None and index in self.img_slot_idxs:
154
154
  # if we need to load this from a file, it should have been materialized locally
pixeltable/exprs/expr.py CHANGED
@@ -14,10 +14,7 @@ import numpy as np
14
14
  import sqlalchemy as sql
15
15
  from typing_extensions import Self, _AnnotatedAlias
16
16
 
17
- import pixeltable.catalog as catalog
18
- import pixeltable.exceptions as excs
19
- import pixeltable.func as func
20
- import pixeltable.type_system as ts
17
+ from pixeltable import catalog, exceptions as excs, func, type_system as ts
21
18
 
22
19
  from .data_row import DataRow
23
20
  from .globals import ArithmeticOperator, ComparisonOperator, LiteralPythonTypes, LogicalOperator
@@ -110,6 +107,24 @@ class Expr(abc.ABC):
110
107
  """
111
108
  return None
112
109
 
110
+ @property
111
+ def validation_error(self) -> Optional[str]:
112
+ """
113
+ Subclasses can override this to indicate that validation has failed after a catalog load.
114
+
115
+ If an Expr (or any of its transitive components) is invalid, then it cannot be evaluated, but its metadata
116
+ will still be preserved in the catalog (so that the user can take appropriate corrective action).
117
+ """
118
+ for c in self.components:
119
+ error = c.validation_error
120
+ if error is not None:
121
+ return error
122
+ return None
123
+
124
+ @property
125
+ def is_valid(self) -> bool:
126
+ return self.validation_error is None
127
+
113
128
  def equals(self, other: Expr) -> bool:
114
129
  """
115
130
  Subclass-specific comparison. Implemented as a function because __eq__() is needed to construct Comparisons.
@@ -245,7 +260,7 @@ class Expr(abc.ABC):
245
260
 
246
261
  def retarget(self, tbl: catalog.TableVersionPath) -> Self:
247
262
  """Retarget ColumnRefs in this expr to the specific TableVersions in tbl."""
248
- tbl_versions = {tbl_version.id: tbl_version for tbl_version in tbl.get_tbl_versions()}
263
+ tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
249
264
  return self._retarget(tbl_versions)
250
265
 
251
266
  def _retarget(self, tbl_versions: dict[UUID, catalog.TableVersion]) -> Self:
@@ -1,7 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import logging
4
5
  import sys
6
+ import warnings
7
+ from textwrap import dedent
5
8
  from typing import Any, Optional, Sequence, Union
6
9
 
7
10
  import sqlalchemy as sql
@@ -18,6 +21,8 @@ from .row_builder import RowBuilder
18
21
  from .rowid_ref import RowidRef
19
22
  from .sql_element_cache import SqlElementCache
20
23
 
24
+ _logger = logging.getLogger('pixeltable')
25
+
21
26
 
22
27
  class FunctionCall(Expr):
23
28
  fn: func.Function
@@ -45,6 +50,8 @@ class FunctionCall(Expr):
45
50
  aggregator: Optional[Any]
46
51
  current_partition_vals: Optional[list[Any]]
47
52
 
53
+ _validation_error: Optional[str]
54
+
48
55
  def __init__(
49
56
  self,
50
57
  fn: func.Function,
@@ -54,6 +61,7 @@ class FunctionCall(Expr):
54
61
  order_by_clause: Optional[list[Any]] = None,
55
62
  group_by_clause: Optional[list[Any]] = None,
56
63
  is_method_call: bool = False,
64
+ validation_error: Optional[str] = None,
57
65
  ):
58
66
  assert not fn.is_polymorphic
59
67
  assert all(isinstance(arg, Expr) for arg in args)
@@ -76,26 +84,6 @@ class FunctionCall(Expr):
76
84
  self.components.extend(arg.copy() for arg in kwargs.values())
77
85
  self.kwarg_idxs = {name: i + len(args) for i, name in enumerate(kwargs.keys())}
78
86
 
79
- # Now generate bound_idxs for the args and kwargs indices.
80
- # This is guaranteed to work, because at this point the call has already been validated.
81
- # These will be used later to dereference specific parameter values.
82
- bindings = fn.signature.py_signature.bind(*self.arg_idxs, **self.kwarg_idxs)
83
- self.bound_idxs = bindings.arguments
84
-
85
- # Separately generate bound_args for purposes of determining the resource pool.
86
- bindings = fn.signature.py_signature.bind(*args, **kwargs)
87
- bound_args = bindings.arguments
88
- self.resource_pool = fn.call_resource_pool(bound_args)
89
-
90
- self.agg_init_args = {}
91
- if self.is_agg_fn_call:
92
- # We separate out the init args for the aggregator. Unpack Literals in init args.
93
- assert isinstance(fn, func.AggregateFunction)
94
- for arg_name, arg in bound_args.items():
95
- if arg_name in fn.init_param_names[0]:
96
- assert isinstance(arg, Literal) # This was checked during validate_call
97
- self.agg_init_args[arg_name] = arg.val
98
-
99
87
  # window function state:
100
88
  # self.components[self.group_by_start_idx:self.group_by_stop_idx] contains group_by exprs
101
89
  self.group_by_start_idx, self.group_by_stop_idx = 0, 0
@@ -125,10 +113,35 @@ class FunctionCall(Expr):
125
113
  raise excs.Error(
126
114
  f'order_by argument needs to be a Pixeltable expression, but instead is a {type(order_by_clause[0])}'
127
115
  )
128
- # don't add components after this, everthing after order_by_start_idx is part of the order_by clause
129
116
  self.order_by_start_idx = len(self.components)
130
117
  self.components.extend(order_by_clause)
131
118
 
119
+ self._validation_error = validation_error
120
+
121
+ if validation_error is not None:
122
+ self.resource_pool = None
123
+ return
124
+
125
+ # Now generate bound_idxs for the args and kwargs indices.
126
+ # This is guaranteed to work, because at this point the call has already been validated.
127
+ # These will be used later to dereference specific parameter values.
128
+ bindings = fn.signature.py_signature.bind(*self.arg_idxs, **self.kwarg_idxs)
129
+ self.bound_idxs = bindings.arguments
130
+
131
+ # Separately generate bound_args for purposes of determining the resource pool.
132
+ bindings = fn.signature.py_signature.bind(*args, **kwargs)
133
+ bound_args = bindings.arguments
134
+ self.resource_pool = fn.call_resource_pool(bound_args)
135
+
136
+ self.agg_init_args = {}
137
+ if self.is_agg_fn_call:
138
+ # We separate out the init args for the aggregator. Unpack Literals in init args.
139
+ assert isinstance(fn, func.AggregateFunction)
140
+ for arg_name, arg in bound_args.items():
141
+ if arg_name in fn.init_param_names[0]:
142
+ assert isinstance(arg, Literal) # This was checked during validate_call
143
+ self.agg_init_args[arg_name] = arg.val
144
+
132
145
  # execution state for aggregate functions
133
146
  self.aggregator = None
134
147
  self.current_partition_vals = None
@@ -137,7 +150,7 @@ class FunctionCall(Expr):
137
150
 
138
151
  def _create_rowid_refs(self, tbl: catalog.Table) -> list[Expr]:
139
152
  target = tbl._tbl_version_path.tbl_version
140
- return [RowidRef(target, i) for i in range(target.num_rowid_columns())]
153
+ return [RowidRef(target, i) for i in range(target.get().num_rowid_columns())]
141
154
 
142
155
  def default_column_name(self) -> Optional[str]:
143
156
  return self.fn.name
@@ -165,12 +178,16 @@ class FunctionCall(Expr):
165
178
  ('group_by_start_idx', self.group_by_start_idx),
166
179
  ('group_by_stop_idx', self.group_by_stop_idx),
167
180
  ('fn_expr_idx', self.fn_expr_idx),
168
- ('order_by_idx', self.order_by_start_idx),
181
+ ('order_by_start_idx', self.order_by_start_idx),
169
182
  ]
170
183
 
171
184
  def __repr__(self) -> str:
172
185
  return self.display_str()
173
186
 
187
+ @property
188
+ def validation_error(self) -> Optional[str]:
189
+ return self._validation_error or super().validation_error
190
+
174
191
  def display_str(self, inline: bool = True) -> str:
175
192
  if self.is_method_call:
176
193
  return f'{self.components[0]}.{self.fn.name}({self._print_args(1, inline)})'
@@ -232,6 +249,8 @@ class FunctionCall(Expr):
232
249
  return self.order_by
233
250
 
234
251
  def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
252
+ assert self.is_valid
253
+
235
254
  # we currently can't translate aggregate functions with grouping and/or ordering to SQL
236
255
  if self.has_group_by() or len(self.order_by) > 0:
237
256
  return None
@@ -304,6 +323,7 @@ class FunctionCall(Expr):
304
323
  Returns a list of dicts mapping each param name to its value when this FunctionCall is evaluated against
305
324
  data_rows
306
325
  """
326
+ assert self.is_valid
307
327
  assert all(name in self.fn.signature.parameters for name in param_names), f'{param_names}, {self.fn.signature}'
308
328
  result: list[dict[str, Any]] = []
309
329
  for row in data_rows:
@@ -327,6 +347,8 @@ class FunctionCall(Expr):
327
347
  return result
328
348
 
329
349
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
350
+ assert self.is_valid
351
+
330
352
  if isinstance(self.fn, func.ExprTemplateFunction):
331
353
  # we need to evaluate the template
332
354
  # TODO: can we get rid of this extra copy?
@@ -396,51 +418,68 @@ class FunctionCall(Expr):
396
418
  group_by_exprs = components[group_by_start_idx:group_by_stop_idx]
397
419
  order_by_exprs = components[order_by_start_idx:]
398
420
 
421
+ validation_error: Optional[str] = None
422
+
423
+ if isinstance(fn, func.InvalidFunction):
424
+ validation_error = (
425
+ dedent(
426
+ f"""
427
+ The UDF '{fn.self_path}' cannot be located, because
428
+ {{errormsg}}
429
+ """
430
+ )
431
+ .strip()
432
+ .format(errormsg=fn.errormsg)
433
+ )
434
+ return cls(fn, args, kwargs, return_type, is_method_call=is_method_call, validation_error=validation_error)
435
+
399
436
  # Now re-bind args and kwargs using the version of `fn` that is currently represented in code. This ensures
400
437
  # that we get a valid binding even if the signatures of `fn` have changed since the FunctionCall was
401
438
  # serialized.
402
439
 
403
- resolved_fn: func.Function
404
- bound_args: dict[str, Expr]
440
+ resolved_fn: func.Function = fn
405
441
 
406
442
  try:
443
+ # Bind args and kwargs to the function signature in the current codebase.
407
444
  resolved_fn, bound_args = fn._bind_to_matching_signature(args, kwargs)
408
445
  except (TypeError, excs.Error):
409
- # TODO: Handle this more gracefully (instead of failing the DB load, allow the DB load to succeed, but
410
- # mark any enclosing FunctionCall as unusable). It's the same issue as dealing with a renamed UDF or
411
- # FunctionCall return type mismatch.
412
446
  signature_note_str = 'any of its signatures' if fn.is_polymorphic else 'its signature'
413
- instance_signature_str = f'{len(fn.signatures)} signatures' if fn.is_polymorphic else str(fn.signature)
414
- raise excs.Error(
415
- f'The signature stored in the database for the UDF `{fn.self_path}` no longer matches '
416
- f'{signature_note_str} as currently defined in the code.\nThis probably means that the code for '
417
- f'`{fn.self_path}` has changed in a backward-incompatible way.\n'
418
- f'Signature in database: {fn}\n'
419
- f'Signature as currently defined in code: {instance_signature_str}'
420
- )
421
-
422
- # Evaluate the call_return_type as defined in the current codebase.
423
- call_return_type = resolved_fn.call_return_type(bound_args)
424
-
425
- if return_type is None:
426
- # Schema versions prior to 25 did not store the return_type in metadata, and there is no obvious way to
427
- # infer it during DB migration, so we might encounter a stored return_type of None. In that case, we use
428
- # the call_return_type that we just inferred (which matches the deserialization behavior prior to
429
- # version 25).
430
- return_type = call_return_type
447
+ args_str = [str(arg.col_type) for arg in args]
448
+ args_str.extend(f'{name}: {arg.col_type}' for name, arg in kwargs.items())
449
+ call_signature_str = f'({", ".join(args_str)}) -> {return_type}'
450
+ fn_signature_str = f'{len(fn.signatures)} signatures' if fn.is_polymorphic else str(fn.signature)
451
+ validation_error = dedent(
452
+ f"""
453
+ The signature stored in the database for a UDF call to {fn.self_path!r} no longer
454
+ matches {signature_note_str} as currently defined in the code. This probably means that the
455
+ code for {fn.self_path!r} has changed in a backward-incompatible way.
456
+ Signature of UDF call in the database: {call_signature_str}
457
+ Signature of UDF as currently defined in code: {fn_signature_str}
458
+ """
459
+ ).strip()
431
460
  else:
432
- # There is a return_type stored in metadata (schema version >= 25).
433
- # Check that the stored return_type of the UDF call matches the column type of the FunctionCall, and
434
- # fail-fast if it doesn't (otherwise we risk getting downstream database errors).
435
- # TODO: Handle this more gracefully (as noted above).
436
- if not return_type.is_supertype_of(call_return_type, ignore_nullable=True):
437
- raise excs.Error(
438
- f'The return type stored in the database for a UDF call to `{fn.self_path}` no longer matches the '
439
- f'return type of the UDF as currently defined in the code.\nThis probably means that the code for '
440
- f'`{fn.self_path}` has changed in a backward-incompatible way.\n'
441
- f'Return type in database: `{return_type}`\n'
442
- f'Return type as currently defined in code: `{call_return_type}`'
443
- )
461
+ # Evaluate the call_return_type as defined in the current codebase.
462
+ call_return_type = resolved_fn.call_return_type(bound_args)
463
+ if return_type is None:
464
+ # Schema versions prior to 25 did not store the return_type in metadata, and there is no obvious way to
465
+ # infer it during DB migration, so we might encounter a stored return_type of None. In that case, we use
466
+ # the call_return_type that we just inferred (which matches the deserialization behavior prior to
467
+ # version 25).
468
+ return_type = call_return_type
469
+ else:
470
+ # There is a return_type stored in metadata (schema version >= 25).
471
+ # Check that the stored return_type of the UDF call matches the column type of the FunctionCall, and
472
+ # fail-fast if it doesn't (otherwise we risk getting downstream database errors).
473
+ if not return_type.is_supertype_of(call_return_type, ignore_nullable=True):
474
+ validation_error = dedent(
475
+ f"""
476
+ The return type stored in the database for a UDF call to {fn.self_path!r} no longer
477
+ matches its return type as currently defined in the code. This probably means that the
478
+ code for {fn.self_path!r} has changed in a backward-incompatible way.
479
+ Return type of UDF call in the database: {return_type}
480
+ Return type of UDF as currently defined in code: {call_return_type}
481
+ """
482
+ ).strip()
444
483
 
445
484
  fn_call = cls(
446
485
  resolved_fn,
@@ -450,6 +489,7 @@ class FunctionCall(Expr):
450
489
  group_by_clause=group_by_exprs,
451
490
  order_by_clause=order_by_exprs,
452
491
  is_method_call=is_method_call,
492
+ validation_error=validation_error,
453
493
  )
454
494
 
455
495
  return fn_call
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
3
+ from typing import TYPE_CHECKING, Optional
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
@@ -11,6 +11,9 @@ from .expr import _GLOBAL_SCOPE, Expr, ExprScope
11
11
  from .row_builder import RowBuilder
12
12
  from .sql_element_cache import SqlElementCache
13
13
 
14
+ if TYPE_CHECKING:
15
+ from .object_ref import ObjectRef
16
+
14
17
 
15
18
  class JsonMapper(Expr):
16
19
  """
@@ -19,6 +22,10 @@ class JsonMapper(Expr):
19
22
  is populated by JsonMapper.eval(). The JsonMapper effectively creates a new scope for its target expr.
20
23
  """
21
24
 
25
+ target_expr_scope: ExprScope
26
+ parent_mapper: Optional[JsonMapper]
27
+ target_expr_eval_ctx: Optional[RowBuilder.EvalCtx]
28
+
22
29
  def __init__(self, src_expr: Expr, target_expr: Expr):
23
30
  # TODO: type spec should be list[target_expr.col_type]
24
31
  super().__init__(ts.JsonType())
@@ -29,12 +36,18 @@ class JsonMapper(Expr):
29
36
 
30
37
  from .object_ref import ObjectRef
31
38
 
32
- scope_anchor = ObjectRef(self.target_expr_scope, self)
33
- self.components = [src_expr, target_expr, scope_anchor]
34
- self.parent_mapper: Optional[JsonMapper] = None
35
- self.target_expr_eval_ctx: Optional[RowBuilder.EvalCtx] = None
39
+ self.components = [src_expr, target_expr]
40
+ self.parent_mapper = None
41
+ self.target_expr_eval_ctx = None
42
+
43
+ # Intentionally create the id now, before adding the scope anchor; this ensures that JsonMappers will
44
+ # be recognized as equal so long as they have the same src_expr and target_expr.
45
+ # TODO: Might this cause problems after certain substitutions?
36
46
  self.id = self._create_id()
37
47
 
48
+ scope_anchor = ObjectRef(self.target_expr_scope, self)
49
+ self.components.append(scope_anchor)
50
+
38
51
  def bind_rel_paths(self, mapper: Optional[JsonMapper] = None) -> None:
39
52
  self._src_expr.bind_rel_paths(mapper)
40
53
  self._target_expr.bind_rel_paths(self)
@@ -84,8 +97,12 @@ class JsonMapper(Expr):
84
97
  return self.components[1]
85
98
 
86
99
  @property
87
- def scope_anchor(self) -> Expr:
88
- return self.components[2]
100
+ def scope_anchor(self) -> 'ObjectRef':
101
+ from .object_ref import ObjectRef
102
+
103
+ result = self.components[2]
104
+ assert isinstance(result, ObjectRef)
105
+ return result
89
106
 
90
107
  def _equals(self, _: JsonMapper) -> bool:
91
108
  return True
@@ -107,7 +124,7 @@ class JsonMapper(Expr):
107
124
  for i, val in enumerate(src):
108
125
  data_row[self.scope_anchor.slot_idx] = val
109
126
  # stored target_expr
110
- row_builder.eval(data_row, self.target_expr_eval_ctx)
127
+ row_builder.eval(data_row, self.target_expr_eval_ctx, force_eval=self._target_expr.scope())
111
128
  result[i] = data_row[self._target_expr.slot_idx]
112
129
  data_row[self.slot_idx] = result
113
130