pixeltable 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/table.py +247 -83
  3. pixeltable/catalog/view.py +5 -2
  4. pixeltable/dataframe.py +240 -92
  5. pixeltable/exec/__init__.py +1 -1
  6. pixeltable/exec/exec_node.py +6 -7
  7. pixeltable/exec/sql_node.py +91 -44
  8. pixeltable/exprs/__init__.py +1 -0
  9. pixeltable/exprs/arithmetic_expr.py +1 -1
  10. pixeltable/exprs/array_slice.py +1 -1
  11. pixeltable/exprs/column_property_ref.py +1 -1
  12. pixeltable/exprs/column_ref.py +29 -2
  13. pixeltable/exprs/comparison.py +1 -1
  14. pixeltable/exprs/compound_predicate.py +1 -1
  15. pixeltable/exprs/expr.py +11 -5
  16. pixeltable/exprs/expr_set.py +8 -0
  17. pixeltable/exprs/function_call.py +14 -11
  18. pixeltable/exprs/in_predicate.py +1 -1
  19. pixeltable/exprs/inline_expr.py +3 -3
  20. pixeltable/exprs/is_null.py +1 -1
  21. pixeltable/exprs/json_mapper.py +1 -1
  22. pixeltable/exprs/json_path.py +1 -1
  23. pixeltable/exprs/method_ref.py +1 -1
  24. pixeltable/exprs/rowid_ref.py +1 -1
  25. pixeltable/exprs/similarity_expr.py +4 -1
  26. pixeltable/exprs/sql_element_cache.py +4 -0
  27. pixeltable/exprs/type_cast.py +2 -2
  28. pixeltable/exprs/variable.py +3 -0
  29. pixeltable/func/expr_template_function.py +3 -0
  30. pixeltable/func/function.py +37 -1
  31. pixeltable/func/signature.py +1 -0
  32. pixeltable/functions/mistralai.py +0 -2
  33. pixeltable/functions/ollama.py +4 -4
  34. pixeltable/globals.py +32 -18
  35. pixeltable/index/embedding_index.py +6 -1
  36. pixeltable/io/__init__.py +1 -1
  37. pixeltable/io/parquet.py +39 -19
  38. pixeltable/iterators/__init__.py +1 -0
  39. pixeltable/iterators/image.py +100 -0
  40. pixeltable/iterators/video.py +7 -8
  41. pixeltable/metadata/__init__.py +1 -1
  42. pixeltable/metadata/converters/convert_22.py +17 -0
  43. pixeltable/metadata/notes.py +1 -0
  44. pixeltable/plan.py +129 -51
  45. pixeltable/store.py +1 -1
  46. pixeltable/tool/create_test_db_dump.py +4 -1
  47. pixeltable/type_system.py +1 -1
  48. pixeltable/utils/arrow.py +8 -3
  49. pixeltable/utils/description_helper.py +89 -0
  50. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/METADATA +28 -12
  51. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/RECORD +54 -51
  52. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/WHEEL +1 -1
  53. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/LICENSE +0 -0
  54. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/entry_points.txt +0 -0
@@ -1,17 +1,19 @@
1
1
  import logging
2
2
  import warnings
3
3
  from decimal import Decimal
4
- from typing import Iterable, Iterator, NamedTuple, Optional
4
+ from typing import Iterable, Iterator, NamedTuple, Optional, TYPE_CHECKING, Sequence
5
5
  from uuid import UUID
6
6
 
7
7
  import sqlalchemy as sql
8
8
 
9
9
  import pixeltable.catalog as catalog
10
10
  import pixeltable.exprs as exprs
11
-
12
11
  from .data_row_batch import DataRowBatch
13
12
  from .exec_node import ExecNode
14
13
 
14
+ if TYPE_CHECKING:
15
+ import pixeltable.plan
16
+
15
17
  _logger = logging.getLogger('pixeltable')
16
18
 
17
19
 
@@ -67,12 +69,17 @@ class SqlNode(ExecNode):
67
69
  select_list: exprs.ExprSet
68
70
  set_pk: bool
69
71
  num_pk_cols: int
70
- filter: Optional[exprs.Expr]
71
- filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
72
+ py_filter: Optional[exprs.Expr] # a predicate that can only be run in Python
73
+ py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
72
74
  cte: Optional[sql.CTE]
73
75
  sql_elements: exprs.SqlElementCache
74
- limit: Optional[int]
76
+
77
+ # where_clause/-_element: allow subclass to set one or the other (but not both)
78
+ where_clause: Optional[exprs.Expr]
79
+ where_clause_element: Optional[sql.ColumnElement]
80
+
75
81
  order_by_clause: OrderByClause
82
+ limit: Optional[int]
76
83
 
77
84
  def __init__(
78
85
  self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
@@ -89,6 +96,7 @@ class SqlNode(ExecNode):
89
96
  # create Select stmt
90
97
  self.sql_elements = sql_elements
91
98
  self.tbl = tbl
99
+ assert all(not isinstance(e, exprs.Literal) for e in select_list) # we're never asked to materialize literals
92
100
  self.select_list = exprs.ExprSet(select_list)
93
101
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
94
102
  for iter_arg in row_builder.unstored_iter_args.values():
@@ -112,10 +120,12 @@ class SqlNode(ExecNode):
112
120
  # additional state
113
121
  self.result_cursor = None
114
122
  # the filter is provided by the subclass
115
- self.filter = None
116
- self.filter_eval_ctx = None
123
+ self.py_filter = None
124
+ self.py_filter_eval_ctx = None
117
125
  self.cte = None
118
126
  self.limit = None
127
+ self.where_clause = None
128
+ self.where_clause_element = None
119
129
  self.order_by_clause = []
120
130
 
121
131
  def _create_stmt(self) -> sql.Select:
@@ -124,9 +134,16 @@ class SqlNode(ExecNode):
124
134
  assert self.sql_elements.contains_all(self.select_list)
125
135
  sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
126
136
  if self.set_pk:
137
+ assert self.tbl is not None
127
138
  sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
128
139
  stmt = sql.select(*sql_select_list)
129
140
 
141
+ where_clause_element = (
142
+ self.sql_elements.get(self.where_clause) if self.where_clause is not None else self.where_clause_element
143
+ )
144
+ if where_clause_element is not None:
145
+ stmt = stmt.where(where_clause_element)
146
+
130
147
  order_by_clause: list[sql.ColumnElement] = []
131
148
  for e, asc in self.order_by_clause:
132
149
  if isinstance(e, exprs.SimilarityExpr):
@@ -135,7 +152,7 @@ class SqlNode(ExecNode):
135
152
  order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
136
153
  stmt = stmt.order_by(*order_by_clause)
137
154
 
138
- if self.filter is None and self.limit is not None:
155
+ if self.py_filter is None and self.limit is not None:
139
156
  # if we don't have a Python filter, we can apply the limit to stmt
140
157
  stmt = stmt.limit(self.limit)
141
158
 
@@ -151,7 +168,7 @@ class SqlNode(ExecNode):
151
168
  Returns:
152
169
  (CTE, dict from Expr to output column)
153
170
  """
154
- if self.filter is not None:
171
+ if self.py_filter is not None:
155
172
  # the filter needs to run in Python
156
173
  return None
157
174
  self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
@@ -215,8 +232,17 @@ class SqlNode(ExecNode):
215
232
  prev_tbl = tbl
216
233
  return stmt
217
234
 
218
- def add_order_by(self, ordering: OrderByClause) -> None:
219
- """Add Order By clause to stmt"""
235
+ def set_where(self, where_clause: exprs.Expr) -> None:
236
+ assert self.where_clause_element is None
237
+ self.where_clause = where_clause
238
+
239
+ def set_py_filter(self, py_filter: exprs.Expr) -> None:
240
+ assert self.py_filter is None
241
+ self.py_filter = py_filter
242
+ self.py_filter_eval_ctx = self.row_builder.create_eval_ctx([py_filter], exclude=self.select_list)
243
+
244
+ def set_order_by(self, ordering: OrderByClause) -> None:
245
+ """Add Order By clause"""
220
246
  if self.tbl is not None:
221
247
  # change rowid refs against a base table to rowid refs against the target table, so that we minimize
222
248
  # the number of tables that need to be joined to the target table
@@ -280,10 +306,10 @@ class SqlNode(ExecNode):
280
306
  else:
281
307
  output_row[slot_idx] = sql_row[i]
282
308
 
283
- if self.filter is not None:
309
+ if self.py_filter is not None:
284
310
  # evaluate filter
285
- self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
286
- if self.filter is not None and not output_row[self.filter.slot_idx]:
311
+ self.row_builder.eval(output_row, self.py_filter_eval_ctx, profile=self.ctx.profile)
312
+ if self.py_filter is not None and not output_row[self.py_filter.slot_idx]:
287
313
  # we re-use this row for the next sql row since it didn't pass the filter
288
314
  output_row = output_batch.pop_row()
289
315
  output_row.clear()
@@ -315,21 +341,16 @@ class SqlScanNode(SqlNode):
315
341
 
316
342
  Supports filtering and ordering.
317
343
  """
318
- where_clause: Optional[exprs.Expr]
319
344
  exact_version_only: list[catalog.TableVersion]
320
345
 
321
346
  def __init__(
322
- self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
323
- select_list: Iterable[exprs.Expr],
324
- where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
325
- set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
347
+ self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
348
+ select_list: Iterable[exprs.Expr],
349
+ set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
326
350
  ):
327
351
  """
328
352
  Args:
329
353
  select_list: output of the query
330
- sql_where_clause: SQL Where clause
331
- filter: additional Where-clause predicate that can't be evaluated via SQL
332
- limit: max number of rows to return: 0 = no limit
333
354
  set_pk: if True, sets the primary for each DataRow
334
355
  exact_version_only: tables for which we only want to see rows created at the current version
335
356
  """
@@ -338,12 +359,7 @@ class SqlScanNode(SqlNode):
338
359
  # create Select stmt
339
360
  if exact_version_only is None:
340
361
  exact_version_only = []
341
- target = tbl.tbl_version # the stored table we're scanning
342
- self.filter = filter
343
- self.filter_eval_ctx = \
344
- row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
345
362
 
346
- self.where_clause = where_clause
347
363
  self.exact_version_only = exact_version_only
348
364
 
349
365
  def _create_stmt(self) -> sql.Select:
@@ -352,12 +368,6 @@ class SqlScanNode(SqlNode):
352
368
  refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
353
369
  stmt = self.create_from_clause(
354
370
  self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
355
-
356
- if self.where_clause is not None:
357
- sql_where_clause = self.sql_elements.get(self.where_clause)
358
- assert sql_where_clause is not None
359
- stmt = stmt.where(sql_where_clause)
360
-
361
371
  return stmt
362
372
 
363
373
 
@@ -366,11 +376,9 @@ class SqlLookupNode(SqlNode):
366
376
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
367
377
  """
368
378
 
369
- where_clause: sql.ColumnElement
370
-
371
379
  def __init__(
372
- self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
373
- select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
380
+ self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
381
+ select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
374
382
  ):
375
383
  """
376
384
  Args:
@@ -381,15 +389,15 @@ class SqlLookupNode(SqlNode):
381
389
  sql_elements = exprs.SqlElementCache()
382
390
  super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
383
391
  # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
384
- self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
392
+ self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
385
393
 
386
394
  def _create_stmt(self) -> sql.Select:
387
395
  stmt = super()._create_stmt()
388
396
  refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
389
397
  stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
390
- stmt = stmt.where(self.where_clause)
391
398
  return stmt
392
399
 
400
+
393
401
  class SqlAggregationNode(SqlNode):
394
402
  """
395
403
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
@@ -398,11 +406,11 @@ class SqlAggregationNode(SqlNode):
398
406
  group_by_items: Optional[list[exprs.Expr]]
399
407
 
400
408
  def __init__(
401
- self, row_builder: exprs.RowBuilder,
402
- input: SqlNode,
403
- select_list: Iterable[exprs.Expr],
404
- group_by_items: Optional[list[exprs.Expr]] = None,
405
- limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
409
+ self, row_builder: exprs.RowBuilder,
410
+ input: SqlNode,
411
+ select_list: Iterable[exprs.Expr],
412
+ group_by_items: Optional[list[exprs.Expr]] = None,
413
+ limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
406
414
  ):
407
415
  """
408
416
  Args:
@@ -422,3 +430,42 @@ class SqlAggregationNode(SqlNode):
422
430
  assert all(e is not None for e in sql_group_by_items)
423
431
  stmt = stmt.group_by(*sql_group_by_items)
424
432
  return stmt
433
+
434
+
435
class SqlJoinNode(SqlNode):
    """
    Materializes data from the store via a Select ... From ... that contains joins

    Every input node is turned into a CTE via to_cte(). The CTEs are joined left to right:
    join_clauses[i] describes the join between the result so far and input i + 1.
    """
    input_ctes: list[sql.CTE]
    join_clauses: list['pixeltable.plan.JoinClause']

    def __init__(
        self, row_builder: exprs.RowBuilder,
        inputs: Sequence[SqlNode], join_clauses: list['pixeltable.plan.JoinClause'], select_list: Iterable[exprs.Expr]
    ):
        """
        Args:
            row_builder: passed through to SqlNode
            inputs: the nodes to join; at least two are required
            join_clauses: one clause per join, ie, exactly len(inputs) - 1 entries
            select_list: output of the query
        """
        assert len(inputs) > 1
        assert len(inputs) == len(join_clauses) + 1
        self.input_ctes = []
        self.join_clauses = join_clauses
        # the output columns of all input CTEs become the sql elements this node selects from
        sql_elements = exprs.SqlElementCache()
        for input_node in inputs:
            input_cte, input_col_map = input_node.to_cte()
            self.input_ctes.append(input_cte)
            sql_elements.extend(input_col_map)
        # there is no single target table for a join, hence tbl=None
        super().__init__(None, row_builder, select_list, sql_elements)

    def _create_stmt(self) -> sql.Select:
        """Returns the Select stmt of SqlNode with a From clause that joins all input CTEs."""
        # deferred import: at module level, pixeltable.plan is only imported under TYPE_CHECKING
        from pixeltable import plan
        stmt = super()._create_stmt()
        stmt = stmt.select_from(self.input_ctes[0])
        # __init__ asserts len(input_ctes) == len(join_clauses) + 1, so zip() pairs clause i with input i + 1
        for join_clause, right_cte in zip(self.join_clauses, self.input_ctes[1:]):
            join_type = join_clause.join_type
            # a cross join has no predicate; join on a constant TRUE instead
            on_clause = (
                self.sql_elements.get(join_clause.join_predicate) if join_type != plan.JoinType.CROSS
                else sql.sql.expression.literal(True)
            )
            is_outer = join_type in (plan.JoinType.LEFT, plan.JoinType.FULL_OUTER)
            # compare the clause's join_type (not the JoinClause object itself) against the enum;
            # otherwise 'full' is never True and a full outer join is rendered as a left outer join
            is_full = join_type == plan.JoinType.FULL_OUTER
            stmt = stmt.join(right_cte, onclause=on_clause, isouter=is_outer, full=is_full)
        return stmt
@@ -23,3 +23,4 @@ from .similarity_expr import SimilarityExpr
23
23
  from .sql_element_cache import SqlElementCache
24
24
  from .type_cast import TypeCast
25
25
  from .variable import Variable
26
+ from .globals import ComparisonOperator, LogicalOperator, ArithmeticOperator
@@ -35,7 +35,7 @@ class ArithmeticExpr(Expr):
35
35
 
36
36
  self.id = self._create_id()
37
37
 
38
- def __str__(self) -> str:
38
+ def __repr__(self) -> str:
39
39
  # add parentheses around operands that are ArithmeticExprs to express precedence
40
40
  op1_str = f'({self._op1})' if isinstance(self._op1, ArithmeticExpr) else str(self._op1)
41
41
  op2_str = f'({self._op2})' if isinstance(self._op2, ArithmeticExpr) else str(self._op2)
@@ -23,7 +23,7 @@ class ArraySlice(Expr):
23
23
  self.index = index
24
24
  self.id = self._create_id()
25
25
 
26
- def __str__(self) -> str:
26
+ def __repr__(self) -> str:
27
27
  index_strs: list[str] = []
28
28
  for el in self.index:
29
29
  if isinstance(el, int):
@@ -46,7 +46,7 @@ class ColumnPropertyRef(Expr):
46
46
  assert isinstance(col_ref, ColumnRef)
47
47
  return col_ref
48
48
 
49
- def __str__(self) -> str:
49
+ def __repr__(self) -> str:
50
50
  return f'{self._col_ref}.{self.prop.name.lower()}'
51
51
 
52
52
  def is_error_prop(self) -> bool:
@@ -5,10 +5,12 @@ from uuid import UUID
5
5
 
6
6
  import sqlalchemy as sql
7
7
 
8
+ import pixeltable as pxt
8
9
  import pixeltable.catalog as catalog
9
10
  import pixeltable.exceptions as excs
10
11
  import pixeltable.iterators as iters
11
12
 
13
+ from ..utils.description_helper import DescriptionHelper
12
14
  from .data_row import DataRow
13
15
  from .expr import Expr
14
16
  from .row_builder import RowBuilder
@@ -126,6 +128,22 @@ class ColumnRef(Expr):
126
128
  def _equals(self, other: ColumnRef) -> bool:
127
129
  return self.col == other.col and self.perform_validation == other.perform_validation
128
130
 
131
+ def _df(self) -> 'pxt.dataframe.DataFrame':
132
+ tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
133
+ return tbl.select(self)
134
+
135
+ def show(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
136
+ return self._df().show(*args, **kwargs)
137
+
138
+ def head(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
139
+ return self._df().head(*args, **kwargs)
140
+
141
+ def tail(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
142
+ return self._df().tail(*args, **kwargs)
143
+
144
+ def count(self) -> int:
145
+ return self._df().count()
146
+
129
147
  def __str__(self) -> str:
130
148
  if self.col.name is None:
131
149
  return f'<unnamed column {self.col.id}>'
@@ -133,11 +151,20 @@ class ColumnRef(Expr):
133
151
  return self.col.name
134
152
 
135
153
  def __repr__(self) -> str:
136
- return f'ColumnRef({self.col!r})'
154
+ return self._descriptors().to_string()
137
155
 
138
156
  def _repr_html_(self) -> str:
157
+ return self._descriptors().to_html()
158
+
159
+ def _descriptors(self) -> DescriptionHelper:
139
160
  tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
140
- return tbl._description_html(cols=[self.col])._repr_html_() # type: ignore[attr-defined]
161
+ helper = DescriptionHelper()
162
+ helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
163
+ helper.append(tbl._col_descriptor([self.col.name]))
164
+ idxs = tbl._index_descriptor([self.col.name])
165
+ if len(idxs) > 0:
166
+ helper.append(idxs)
167
+ return helper
141
168
 
142
169
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
143
170
  return None if self.perform_validation else self.col.sa_col
@@ -49,7 +49,7 @@ class Comparison(Expr):
49
49
 
50
50
  self.id = self._create_id()
51
51
 
52
- def __str__(self) -> str:
52
+ def __repr__(self) -> str:
53
53
  return f'{self._op1} {self.operator} {self._op2}'
54
54
 
55
55
  def _equals(self, other: Comparison) -> bool:
@@ -30,7 +30,7 @@ class CompoundPredicate(Expr):
30
30
 
31
31
  self.id = self._create_id()
32
32
 
33
- def __str__(self) -> str:
33
+ def __repr__(self) -> str:
34
34
  if self.operator == LogicalOperator.NOT:
35
35
  return f'~({self.components[0]})'
36
36
  return f' {self.operator} '.join([f'({e})' for e in self.components])
pixeltable/exprs/expr.py CHANGED
@@ -216,12 +216,12 @@ class Expr(abc.ABC):
216
216
  return result
217
217
  result = result.substitute({ref: ref.col.value_expr for ref in target_col_refs})
218
218
 
219
- def is_bound_by(self, tbl: catalog.TableVersionPath) -> bool:
220
- """Returns True if this expr can be evaluated in the context of tbl."""
219
+ def is_bound_by(self, tbls: list[catalog.TableVersionPath]) -> bool:
220
+ """Returns True if this expr can be evaluated in the context of tbls."""
221
221
  from .column_ref import ColumnRef
222
222
  col_refs = self.subexprs(ColumnRef)
223
223
  for col_ref in col_refs:
224
- if not tbl.has_column(col_ref.col):
224
+ if not any(tbl.has_column(col_ref.col) for tbl in tbls):
225
225
  return False
226
226
  return True
227
227
 
@@ -235,7 +235,7 @@ class Expr(abc.ABC):
235
235
  self.components[i] = self.components[i]._retarget(tbl_versions)
236
236
  return self
237
237
 
238
- def __str__(self) -> str:
238
+ def __repr__(self) -> str:
239
239
  return f'<Expression of type {type(self)}>'
240
240
 
241
241
  def display_str(self, inline: bool = True) -> str:
@@ -450,7 +450,13 @@ class Expr(abc.ABC):
450
450
 
451
451
  def astype(self, new_type: Union[ts.ColumnType, type, _AnnotatedAlias]) -> 'exprs.TypeCast':
452
452
  from pixeltable.exprs import TypeCast
453
- return TypeCast(self, ts.ColumnType.normalize_type(new_type))
453
+ # Interpret the type argument the same way we would if given in a schema
454
+ col_type = ts.ColumnType.normalize_type(new_type, nullable_default=True, allow_builtin_types=False)
455
+ if not self.col_type.nullable:
456
+ # This expression is non-nullable; we can prove that the output is non-nullable, regardless of
457
+ # whether new_type is given as nullable.
458
+ col_type = col_type.copy(nullable=False)
459
+ return TypeCast(self, col_type)
454
460
 
455
461
  def apply(self, fn: Callable, *, col_type: Union[ts.ColumnType, type, _AnnotatedAlias, None] = None) -> 'exprs.FunctionCall':
456
462
  if col_type is not None:
@@ -60,6 +60,14 @@ class ExprSet(Generic[T]):
60
60
  def __le__(self, other: ExprSet[T]) -> bool:
61
61
  return other.issuperset(self)
62
62
 
63
+ def union(self, *others: Iterable[T]) -> ExprSet[T]:
64
+ result = ExprSet(self.exprs.values())
65
+ result.update(*others)
66
+ return result
67
+
68
+ def __or__(self, other: ExprSet[T]) -> ExprSet[T]:
69
+ return self.union(other)
70
+
63
71
  def difference(self, *others: Iterable[T]) -> ExprSet[T]:
64
72
  id_diff = set(self.exprs.keys()).difference(e.id for other_set in others for e in other_set)
65
73
  return ExprSet(self.exprs[id] for id in id_diff)
@@ -85,7 +85,9 @@ class FunctionCall(Expr):
85
85
  # we record the types of non-variable parameters for runtime type checks
86
86
  self.arg_types = []
87
87
  self.kwarg_types = {}
88
+
88
89
  # the prefix of parameters that are bound can be passed by position
90
+ processed_args: set[str] = set()
89
91
  for py_param in fn.signature.py_signature.parameters.values():
90
92
  if py_param.name not in bound_args or py_param.kind == inspect.Parameter.KEYWORD_ONLY:
91
93
  break
@@ -97,18 +99,19 @@ class FunctionCall(Expr):
97
99
  self.args.append((None, arg))
98
100
  if py_param.kind != inspect.Parameter.VAR_POSITIONAL and py_param.kind != inspect.Parameter.VAR_KEYWORD:
99
101
  self.arg_types.append(signature.parameters[py_param.name].col_type)
102
+ processed_args.add(py_param.name)
100
103
 
101
104
  # the remaining args are passed as keywords
102
- kw_param_names = set(bound_args.keys()) - set(list(fn.signature.py_signature.parameters.keys())[:len(self.args)])
103
- for param_name in kw_param_names:
104
- arg = bound_args[param_name]
105
- if isinstance(arg, Expr):
106
- self.kwargs[param_name] = (len(self.components), None)
107
- self.components.append(arg.copy())
108
- else:
109
- self.kwargs[param_name] = (None, arg)
110
- if fn.signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
111
- self.kwarg_types[param_name] = signature.parameters[param_name].col_type
105
+ for param_name in bound_args.keys():
106
+ if param_name not in processed_args:
107
+ arg = bound_args[param_name]
108
+ if isinstance(arg, Expr):
109
+ self.kwargs[param_name] = (len(self.components), None)
110
+ self.components.append(arg.copy())
111
+ else:
112
+ self.kwargs[param_name] = (None, arg)
113
+ if fn.signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
114
+ self.kwarg_types[param_name] = signature.parameters[param_name].col_type
112
115
 
113
116
  # window function state:
114
117
  # self.components[self.group_by_start_idx:self.group_by_stop_idx] contains group_by exprs
@@ -255,7 +258,7 @@ class FunctionCall(Expr):
255
258
  ('order_by_start_idx', self.order_by_start_idx)
256
259
  ]
257
260
 
258
- def __str__(self) -> str:
261
+ def __repr__(self) -> str:
259
262
  return self.display_str()
260
263
 
261
264
  def display_str(self, inline: bool = True) -> str:
@@ -61,7 +61,7 @@ class InPredicate(Expr):
61
61
  pass
62
62
  return result
63
63
 
64
- def __str__(self) -> str:
64
+ def __repr__(self) -> str:
65
65
  if self.value_list is not None:
66
66
  return f'{self.components[0]}.isin({self.value_list})'
67
67
  return f'{self.components[0]}.isin({self.components[1]})'
@@ -56,7 +56,7 @@ class InlineArray(Expr):
56
56
  self.components.extend(exprs)
57
57
  self.id = self._create_id()
58
58
 
59
- def __str__(self) -> str:
59
+ def __repr__(self) -> str:
60
60
  elem_strs = [str(expr) for expr in self.components]
61
61
  return f'[{", ".join(elem_strs)}]'
62
62
 
@@ -105,7 +105,7 @@ class InlineList(Expr):
105
105
  self.components.extend(exprs)
106
106
  self.id = self._create_id()
107
107
 
108
- def __str__(self) -> str:
108
+ def __repr__(self) -> str:
109
109
  elem_strs = [str(expr) for expr in self.components]
110
110
  return f'[{", ".join(elem_strs)}]'
111
111
 
@@ -153,7 +153,7 @@ class InlineDict(Expr):
153
153
  self.components.extend(exprs)
154
154
  self.id = self._create_id()
155
155
 
156
- def __str__(self) -> str:
156
+ def __repr__(self) -> str:
157
157
  item_strs = list(f"'{key}': {str(expr)}" for key, expr in zip(self.keys, self.components))
158
158
  return '{' + ', '.join(item_strs) + '}'
159
159
 
@@ -18,7 +18,7 @@ class IsNull(Expr):
18
18
  self.components = [e]
19
19
  self.id = self._create_id()
20
20
 
21
- def __str__(self) -> str:
21
+ def __repr__(self) -> str:
22
22
  return f'{str(self.components[0])} == None'
23
23
 
24
24
  def _equals(self, other: IsNull) -> bool:
@@ -69,7 +69,7 @@ class JsonMapper(Expr):
69
69
  return False
70
70
  return self._src_expr.equals(other._src_expr) and self._target_expr.equals(other._target_expr)
71
71
 
72
- def __str__(self) -> str:
72
+ def __repr__(self) -> str:
73
73
  return f'{str(self._src_expr)} >> {str(self._target_expr)}'
74
74
 
75
75
  @property
@@ -42,7 +42,7 @@ class JsonPath(Expr):
42
42
  # this is not a problem, because _create_id() shouldn't be called after init()
43
43
  self.id = self._create_id()
44
44
 
45
- def __str__(self) -> str:
45
+ def __repr__(self) -> str:
46
46
  # else "R": the anchor is RELATIVE_PATH_ROOT
47
47
  return (f'{str(self._anchor) if self._anchor is not None else "R"}'
48
48
  f'{"." if isinstance(self.path_elements[0], str) else ""}{self._json_path()}')
@@ -60,5 +60,5 @@ class MethodRef(Expr):
60
60
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
61
61
  assert False, 'MethodRef cannot be evaluated directly'
62
62
 
63
- def __str__(self) -> str:
63
+ def __repr__(self) -> str:
64
64
  return f'{self.base_expr}.{self.method_name}'
@@ -55,7 +55,7 @@ class RowidRef(Expr):
55
55
  return super()._id_attrs() +\
56
56
  [('normalized_base_id', self.normalized_base_id), ('idx', self.rowid_component_idx)]
57
57
 
58
- def __str__(self) -> str:
58
+ def __repr__(self) -> str:
59
59
  # check if this is the pos column of a component view
60
60
  tbl = self.tbl if self.tbl is not None else catalog.Catalog.get().tbl_versions[(self.tbl_id, None)]
61
61
  if tbl.is_component_view() and self.rowid_component_idx == tbl.store_tbl.pos_col_idx: # type: ignore[attr-defined]
@@ -55,9 +55,12 @@ class SimilarityExpr(Expr):
55
55
  f'Embedding index {self.idx_info.name!r} on column {self.idx_info.col.name!r} was created without the '
56
56
  f"'image_embed' parameter and does not support image queries")
57
57
 
58
- def __str__(self) -> str:
58
+ def __repr__(self) -> str:
59
59
  return f'{self.components[0]}.similarity({self.components[1]})'
60
60
 
61
+ def default_column_name(self) -> str:
62
+ return 'similarity'
63
+
61
64
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
62
65
  if not isinstance(self.components[1], Literal):
63
66
  raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
@@ -17,6 +17,10 @@ class SqlElementCache:
17
17
  for e, el in elements.items():
18
18
  self.cache[e.id] = el
19
19
 
20
+ def extend(self, elements: ExprDict[sql.ColumnElement]):
21
+ for e, el in elements.items():
22
+ self.cache[e.id] = el
23
+
20
24
  def get(self, e: Expr) -> Optional[sql.ColumnElement]:
21
25
  """Returns the sql.ColumnElement for the given Expr, or None if Expr.to_sql() returns None."""
22
26
  try:
@@ -51,5 +51,5 @@ class TypeCast(Expr):
51
51
  assert len(components) == 1
52
52
  return cls(components[0], ts.ColumnType.from_dict(d['new_type']))
53
53
 
54
- def __str__(self) -> str:
55
- return f'{self._underlying}.astype({self.col_type})'
54
+ def __repr__(self) -> str:
55
+ return f'{self._underlying}.astype({self.col_type._to_str(as_schema=True)})'
@@ -33,6 +33,9 @@ class Variable(Expr):
33
33
  def __str__(self) -> str:
34
34
  return self.name
35
35
 
36
+ def __repr__(self) -> str:
37
+ return f"Variable('{self.name}')"
38
+
36
39
  def sql_expr(self, _: SqlElementCache) -> NoReturn:
37
40
  raise NotImplementedError()
38
41
 
@@ -78,6 +78,9 @@ class ExprTemplateFunction(Function):
78
78
  def name(self) -> str:
79
79
  return self.self_name
80
80
 
81
+ def __str__(self) -> str:
82
+ return str(self.expr)
83
+
81
84
  def _as_dict(self) -> dict:
82
85
  if self.self_path is not None:
83
86
  return super()._as_dict()