pixeltable 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/table.py +118 -44
- pixeltable/catalog/view.py +2 -2
- pixeltable/dataframe.py +240 -92
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/exec_node.py +6 -7
- pixeltable/exec/sql_node.py +91 -44
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +1 -1
- pixeltable/exprs/array_slice.py +1 -1
- pixeltable/exprs/column_property_ref.py +1 -1
- pixeltable/exprs/column_ref.py +29 -2
- pixeltable/exprs/comparison.py +1 -1
- pixeltable/exprs/compound_predicate.py +1 -1
- pixeltable/exprs/expr.py +11 -5
- pixeltable/exprs/expr_set.py +8 -0
- pixeltable/exprs/function_call.py +14 -11
- pixeltable/exprs/in_predicate.py +1 -1
- pixeltable/exprs/inline_expr.py +3 -3
- pixeltable/exprs/is_null.py +1 -1
- pixeltable/exprs/json_mapper.py +1 -1
- pixeltable/exprs/json_path.py +1 -1
- pixeltable/exprs/method_ref.py +1 -1
- pixeltable/exprs/rowid_ref.py +1 -1
- pixeltable/exprs/similarity_expr.py +1 -1
- pixeltable/exprs/sql_element_cache.py +4 -0
- pixeltable/exprs/type_cast.py +2 -2
- pixeltable/exprs/variable.py +3 -0
- pixeltable/func/expr_template_function.py +3 -0
- pixeltable/functions/ollama.py +4 -4
- pixeltable/globals.py +4 -1
- pixeltable/io/__init__.py +1 -1
- pixeltable/io/parquet.py +39 -19
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_22.py +17 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/plan.py +128 -50
- pixeltable/store.py +1 -1
- pixeltable/type_system.py +1 -1
- pixeltable/utils/arrow.py +8 -3
- pixeltable/utils/description_helper.py +89 -0
- {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/METADATA +26 -10
- {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/RECORD +46 -44
- {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/WHEEL +1 -1
- {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/entry_points.txt +0 -0
pixeltable/exec/sql_node.py
CHANGED
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import warnings
|
|
3
3
|
from decimal import Decimal
|
|
4
|
-
from typing import Iterable, Iterator, NamedTuple, Optional
|
|
4
|
+
from typing import Iterable, Iterator, NamedTuple, Optional, TYPE_CHECKING, Sequence
|
|
5
5
|
from uuid import UUID
|
|
6
6
|
|
|
7
7
|
import sqlalchemy as sql
|
|
8
8
|
|
|
9
9
|
import pixeltable.catalog as catalog
|
|
10
10
|
import pixeltable.exprs as exprs
|
|
11
|
-
|
|
12
11
|
from .data_row_batch import DataRowBatch
|
|
13
12
|
from .exec_node import ExecNode
|
|
14
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import pixeltable.plan
|
|
16
|
+
|
|
15
17
|
_logger = logging.getLogger('pixeltable')
|
|
16
18
|
|
|
17
19
|
|
|
@@ -67,12 +69,17 @@ class SqlNode(ExecNode):
|
|
|
67
69
|
select_list: exprs.ExprSet
|
|
68
70
|
set_pk: bool
|
|
69
71
|
num_pk_cols: int
|
|
70
|
-
|
|
71
|
-
|
|
72
|
+
py_filter: Optional[exprs.Expr] # a predicate that can only be run in Python
|
|
73
|
+
py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
|
|
72
74
|
cte: Optional[sql.CTE]
|
|
73
75
|
sql_elements: exprs.SqlElementCache
|
|
74
|
-
|
|
76
|
+
|
|
77
|
+
# where_clause/-_element: allow subclass to set one or the other (but not both)
|
|
78
|
+
where_clause: Optional[exprs.Expr]
|
|
79
|
+
where_clause_element: Optional[sql.ColumnElement]
|
|
80
|
+
|
|
75
81
|
order_by_clause: OrderByClause
|
|
82
|
+
limit: Optional[int]
|
|
76
83
|
|
|
77
84
|
def __init__(
|
|
78
85
|
self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
|
|
@@ -89,6 +96,7 @@ class SqlNode(ExecNode):
|
|
|
89
96
|
# create Select stmt
|
|
90
97
|
self.sql_elements = sql_elements
|
|
91
98
|
self.tbl = tbl
|
|
99
|
+
assert all(not isinstance(e, exprs.Literal) for e in select_list) # we're never asked to materialize literals
|
|
92
100
|
self.select_list = exprs.ExprSet(select_list)
|
|
93
101
|
# unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
|
|
94
102
|
for iter_arg in row_builder.unstored_iter_args.values():
|
|
@@ -112,10 +120,12 @@ class SqlNode(ExecNode):
|
|
|
112
120
|
# additional state
|
|
113
121
|
self.result_cursor = None
|
|
114
122
|
# the filter is provided by the subclass
|
|
115
|
-
self.
|
|
116
|
-
self.
|
|
123
|
+
self.py_filter = None
|
|
124
|
+
self.py_filter_eval_ctx = None
|
|
117
125
|
self.cte = None
|
|
118
126
|
self.limit = None
|
|
127
|
+
self.where_clause = None
|
|
128
|
+
self.where_clause_element = None
|
|
119
129
|
self.order_by_clause = []
|
|
120
130
|
|
|
121
131
|
def _create_stmt(self) -> sql.Select:
|
|
@@ -124,9 +134,16 @@ class SqlNode(ExecNode):
|
|
|
124
134
|
assert self.sql_elements.contains_all(self.select_list)
|
|
125
135
|
sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
|
|
126
136
|
if self.set_pk:
|
|
137
|
+
assert self.tbl is not None
|
|
127
138
|
sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
|
|
128
139
|
stmt = sql.select(*sql_select_list)
|
|
129
140
|
|
|
141
|
+
where_clause_element = (
|
|
142
|
+
self.sql_elements.get(self.where_clause) if self.where_clause is not None else self.where_clause_element
|
|
143
|
+
)
|
|
144
|
+
if where_clause_element is not None:
|
|
145
|
+
stmt = stmt.where(where_clause_element)
|
|
146
|
+
|
|
130
147
|
order_by_clause: list[sql.ColumnElement] = []
|
|
131
148
|
for e, asc in self.order_by_clause:
|
|
132
149
|
if isinstance(e, exprs.SimilarityExpr):
|
|
@@ -135,7 +152,7 @@ class SqlNode(ExecNode):
|
|
|
135
152
|
order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
|
|
136
153
|
stmt = stmt.order_by(*order_by_clause)
|
|
137
154
|
|
|
138
|
-
if self.
|
|
155
|
+
if self.py_filter is None and self.limit is not None:
|
|
139
156
|
# if we don't have a Python filter, we can apply the limit to stmt
|
|
140
157
|
stmt = stmt.limit(self.limit)
|
|
141
158
|
|
|
@@ -151,7 +168,7 @@ class SqlNode(ExecNode):
|
|
|
151
168
|
Returns:
|
|
152
169
|
(CTE, dict from Expr to output column)
|
|
153
170
|
"""
|
|
154
|
-
if self.
|
|
171
|
+
if self.py_filter is not None:
|
|
155
172
|
# the filter needs to run in Python
|
|
156
173
|
return None
|
|
157
174
|
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
@@ -215,8 +232,17 @@ class SqlNode(ExecNode):
|
|
|
215
232
|
prev_tbl = tbl
|
|
216
233
|
return stmt
|
|
217
234
|
|
|
218
|
-
def
|
|
219
|
-
|
|
235
|
+
def set_where(self, where_clause: exprs.Expr) -> None:
    """Set the Where clause of this node as an expression.

    The expression is stored as-is; _create_stmt() looks up its SQL form in self.sql_elements when the
    statement is built.

    Args:
        where_clause: the predicate to apply
    """
    # where_clause and where_clause_element are alternatives: only one of them may be set
    assert self.where_clause_element is None
    self.where_clause = where_clause
|
|
238
|
+
|
|
239
|
+
def set_py_filter(self, py_filter: exprs.Expr) -> None:
    """Set a predicate that can only be run in Python.

    Asserts that no Python filter has been set on this node yet.

    Args:
        py_filter: the predicate to evaluate
    """
    assert self.py_filter is None
    self.py_filter = py_filter
    # the eval ctx for the filter excludes everything that is already part of the select list
    self.py_filter_eval_ctx = self.row_builder.create_eval_ctx([py_filter], exclude=self.select_list)
|
|
243
|
+
|
|
244
|
+
def set_order_by(self, ordering: OrderByClause) -> None:
|
|
245
|
+
"""Add Order By clause"""
|
|
220
246
|
if self.tbl is not None:
|
|
221
247
|
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
222
248
|
# the number of tables that need to be joined to the target table
|
|
@@ -280,10 +306,10 @@ class SqlNode(ExecNode):
|
|
|
280
306
|
else:
|
|
281
307
|
output_row[slot_idx] = sql_row[i]
|
|
282
308
|
|
|
283
|
-
if self.
|
|
309
|
+
if self.py_filter is not None:
|
|
284
310
|
# evaluate filter
|
|
285
|
-
self.row_builder.eval(output_row, self.
|
|
286
|
-
if self.
|
|
311
|
+
self.row_builder.eval(output_row, self.py_filter_eval_ctx, profile=self.ctx.profile)
|
|
312
|
+
if self.py_filter is not None and not output_row[self.py_filter.slot_idx]:
|
|
287
313
|
# we re-use this row for the next sql row since it didn't pass the filter
|
|
288
314
|
output_row = output_batch.pop_row()
|
|
289
315
|
output_row.clear()
|
|
@@ -315,21 +341,16 @@ class SqlScanNode(SqlNode):
|
|
|
315
341
|
|
|
316
342
|
Supports filtering and ordering.
|
|
317
343
|
"""
|
|
318
|
-
where_clause: Optional[exprs.Expr]
|
|
319
344
|
exact_version_only: list[catalog.TableVersion]
|
|
320
345
|
|
|
321
346
|
def __init__(
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
347
|
+
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
348
|
+
select_list: Iterable[exprs.Expr],
|
|
349
|
+
set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
326
350
|
):
|
|
327
351
|
"""
|
|
328
352
|
Args:
|
|
329
353
|
select_list: output of the query
|
|
330
|
-
sql_where_clause: SQL Where clause
|
|
331
|
-
filter: additional Where-clause predicate that can't be evaluated via SQL
|
|
332
|
-
limit: max number of rows to return: 0 = no limit
|
|
333
354
|
set_pk: if True, sets the primary key for each DataRow
|
|
334
355
|
exact_version_only: tables for which we only want to see rows created at the current version
|
|
335
356
|
"""
|
|
@@ -338,12 +359,7 @@ class SqlScanNode(SqlNode):
|
|
|
338
359
|
# create Select stmt
|
|
339
360
|
if exact_version_only is None:
|
|
340
361
|
exact_version_only = []
|
|
341
|
-
target = tbl.tbl_version # the stored table we're scanning
|
|
342
|
-
self.filter = filter
|
|
343
|
-
self.filter_eval_ctx = \
|
|
344
|
-
row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
|
|
345
362
|
|
|
346
|
-
self.where_clause = where_clause
|
|
347
363
|
self.exact_version_only = exact_version_only
|
|
348
364
|
|
|
349
365
|
def _create_stmt(self) -> sql.Select:
|
|
@@ -352,12 +368,6 @@ class SqlScanNode(SqlNode):
|
|
|
352
368
|
refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
|
|
353
369
|
stmt = self.create_from_clause(
|
|
354
370
|
self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
|
|
355
|
-
|
|
356
|
-
if self.where_clause is not None:
|
|
357
|
-
sql_where_clause = self.sql_elements.get(self.where_clause)
|
|
358
|
-
assert sql_where_clause is not None
|
|
359
|
-
stmt = stmt.where(sql_where_clause)
|
|
360
|
-
|
|
361
371
|
return stmt
|
|
362
372
|
|
|
363
373
|
|
|
@@ -366,11 +376,9 @@ class SqlLookupNode(SqlNode):
|
|
|
366
376
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
367
377
|
"""
|
|
368
378
|
|
|
369
|
-
where_clause: sql.ColumnElement
|
|
370
|
-
|
|
371
379
|
def __init__(
|
|
372
|
-
|
|
373
|
-
|
|
380
|
+
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
381
|
+
select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
|
|
374
382
|
):
|
|
375
383
|
"""
|
|
376
384
|
Args:
|
|
@@ -381,15 +389,15 @@ class SqlLookupNode(SqlNode):
|
|
|
381
389
|
sql_elements = exprs.SqlElementCache()
|
|
382
390
|
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
|
|
383
391
|
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
384
|
-
self.
|
|
392
|
+
self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
385
393
|
|
|
386
394
|
def _create_stmt(self) -> sql.Select:
|
|
387
395
|
stmt = super()._create_stmt()
|
|
388
396
|
refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
|
|
389
397
|
stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
|
|
390
|
-
stmt = stmt.where(self.where_clause)
|
|
391
398
|
return stmt
|
|
392
399
|
|
|
400
|
+
|
|
393
401
|
class SqlAggregationNode(SqlNode):
|
|
394
402
|
"""
|
|
395
403
|
Materializes data from the store via a Select stmt that aggregates the output of an input SqlNode, optionally with a Group By clause
|
|
@@ -398,11 +406,11 @@ class SqlAggregationNode(SqlNode):
|
|
|
398
406
|
group_by_items: Optional[list[exprs.Expr]]
|
|
399
407
|
|
|
400
408
|
def __init__(
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
409
|
+
self, row_builder: exprs.RowBuilder,
|
|
410
|
+
input: SqlNode,
|
|
411
|
+
select_list: Iterable[exprs.Expr],
|
|
412
|
+
group_by_items: Optional[list[exprs.Expr]] = None,
|
|
413
|
+
limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
406
414
|
):
|
|
407
415
|
"""
|
|
408
416
|
Args:
|
|
@@ -422,3 +430,42 @@ class SqlAggregationNode(SqlNode):
|
|
|
422
430
|
assert all(e is not None for e in sql_group_by_items)
|
|
423
431
|
stmt = stmt.group_by(*sql_group_by_items)
|
|
424
432
|
return stmt
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class SqlJoinNode(SqlNode):
    """
    Materializes data from the store via a Select ... From ... that contains joins

    Every input node is turned into a CTE. The CTEs are joined left to right: join_clauses[i] describes
    the join between what has been joined so far and input i + 1.
    """
    input_ctes: list[sql.CTE]
    join_clauses: list['pixeltable.plan.JoinClause']

    def __init__(
        self, row_builder: exprs.RowBuilder,
        inputs: Sequence[SqlNode], join_clauses: list['pixeltable.plan.JoinClause'], select_list: Iterable[exprs.Expr]
    ):
        """
        Args:
            row_builder: passed through to SqlNode
            inputs: the nodes to join; at least two are required
            join_clauses: one clause per join, ie, len(inputs) - 1 clauses
            select_list: output of the query
        """
        assert len(inputs) > 1
        assert len(inputs) == len(join_clauses) + 1
        self.input_ctes = []
        self.join_clauses = join_clauses
        sql_elements = exprs.SqlElementCache()
        for input_node in inputs:
            cte_info = input_node.to_cte()
            # to_cte() returns None for a node that needs to run a filter in Python; fail with a clear
            # message instead of an unpacking error
            assert cte_info is not None, 'SqlJoinNode inputs must be convertible to CTEs'
            input_cte, input_col_map = cte_info
            self.input_ctes.append(input_cte)
            sql_elements.extend(input_col_map)
        super().__init__(None, row_builder, select_list, sql_elements)

    def _create_stmt(self) -> sql.Select:
        """Extend the base Select stmt with a From clause that joins all input CTEs."""
        from pixeltable import plan
        stmt = super()._create_stmt()
        stmt = stmt.select_from(self.input_ctes[0])
        # join_clauses[i] pairs up with input_ctes[i + 1]
        for join_clause, right_cte in zip(self.join_clauses, self.input_ctes[1:]):
            join_type = join_clause.join_type
            # a cross join has no predicate: join on a constant True instead
            on_clause = (
                self.sql_elements.get(join_clause.join_predicate) if join_type != plan.JoinType.CROSS
                else sql.sql.expression.literal(True)
            )
            # compare the clause's join type (not the clause object itself) against FULL_OUTER
            is_full = join_type == plan.JoinType.FULL_OUTER
            is_outer = is_full or join_type == plan.JoinType.LEFT
            stmt = stmt.join(right_cte, onclause=on_clause, isouter=is_outer, full=is_full)
        return stmt
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -35,7 +35,7 @@ class ArithmeticExpr(Expr):
|
|
|
35
35
|
|
|
36
36
|
self.id = self._create_id()
|
|
37
37
|
|
|
38
|
-
def
|
|
38
|
+
def __repr__(self) -> str:
|
|
39
39
|
# add parentheses around operands that are ArithmeticExprs to express precedence
|
|
40
40
|
op1_str = f'({self._op1})' if isinstance(self._op1, ArithmeticExpr) else str(self._op1)
|
|
41
41
|
op2_str = f'({self._op2})' if isinstance(self._op2, ArithmeticExpr) else str(self._op2)
|
pixeltable/exprs/array_slice.py
CHANGED
pixeltable/exprs/column_ref.py
CHANGED
|
@@ -5,10 +5,12 @@ from uuid import UUID
|
|
|
5
5
|
|
|
6
6
|
import sqlalchemy as sql
|
|
7
7
|
|
|
8
|
+
import pixeltable as pxt
|
|
8
9
|
import pixeltable.catalog as catalog
|
|
9
10
|
import pixeltable.exceptions as excs
|
|
10
11
|
import pixeltable.iterators as iters
|
|
11
12
|
|
|
13
|
+
from ..utils.description_helper import DescriptionHelper
|
|
12
14
|
from .data_row import DataRow
|
|
13
15
|
from .expr import Expr
|
|
14
16
|
from .row_builder import RowBuilder
|
|
@@ -126,6 +128,22 @@ class ColumnRef(Expr):
|
|
|
126
128
|
def _equals(self, other: ColumnRef) -> bool:
|
|
127
129
|
return self.col == other.col and self.perform_validation == other.perform_validation
|
|
128
130
|
|
|
131
|
+
def _df(self) -> 'pxt.dataframe.DataFrame':
|
|
132
|
+
tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
|
|
133
|
+
return tbl.select(self)
|
|
134
|
+
|
|
135
|
+
def show(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
|
|
136
|
+
return self._df().show(*args, **kwargs)
|
|
137
|
+
|
|
138
|
+
def head(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
|
|
139
|
+
return self._df().head(*args, **kwargs)
|
|
140
|
+
|
|
141
|
+
def tail(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
|
|
142
|
+
return self._df().tail(*args, **kwargs)
|
|
143
|
+
|
|
144
|
+
def count(self) -> int:
|
|
145
|
+
return self._df().count()
|
|
146
|
+
|
|
129
147
|
def __str__(self) -> str:
|
|
130
148
|
if self.col.name is None:
|
|
131
149
|
return f'<unnamed column {self.col.id}>'
|
|
@@ -133,11 +151,20 @@ class ColumnRef(Expr):
|
|
|
133
151
|
return self.col.name
|
|
134
152
|
|
|
135
153
|
def __repr__(self) -> str:
|
|
136
|
-
return
|
|
154
|
+
return self._descriptors().to_string()
|
|
137
155
|
|
|
138
156
|
def _repr_html_(self) -> str:
|
|
157
|
+
return self._descriptors().to_html()
|
|
158
|
+
|
|
159
|
+
def _descriptors(self) -> DescriptionHelper:
|
|
139
160
|
tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
|
|
140
|
-
|
|
161
|
+
helper = DescriptionHelper()
|
|
162
|
+
helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
|
|
163
|
+
helper.append(tbl._col_descriptor([self.col.name]))
|
|
164
|
+
idxs = tbl._index_descriptor([self.col.name])
|
|
165
|
+
if len(idxs) > 0:
|
|
166
|
+
helper.append(idxs)
|
|
167
|
+
return helper
|
|
141
168
|
|
|
142
169
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
143
170
|
return None if self.perform_validation else self.col.sa_col
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -30,7 +30,7 @@ class CompoundPredicate(Expr):
|
|
|
30
30
|
|
|
31
31
|
self.id = self._create_id()
|
|
32
32
|
|
|
33
|
-
def
|
|
33
|
+
def __repr__(self) -> str:
|
|
34
34
|
if self.operator == LogicalOperator.NOT:
|
|
35
35
|
return f'~({self.components[0]})'
|
|
36
36
|
return f' {self.operator} '.join([f'({e})' for e in self.components])
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -216,12 +216,12 @@ class Expr(abc.ABC):
|
|
|
216
216
|
return result
|
|
217
217
|
result = result.substitute({ref: ref.col.value_expr for ref in target_col_refs})
|
|
218
218
|
|
|
219
|
-
def is_bound_by(self,
|
|
220
|
-
"""Returns True if this expr can be evaluated in the context of
|
|
219
|
+
def is_bound_by(self, tbls: list[catalog.TableVersionPath]) -> bool:
    """Returns True if this expr can be evaluated in the context of tbls.

    This is the case when every column referenced by this expr is a column of at least one of the
    tables in tbls. An expr without any column references is always bound.
    """
    from .column_ref import ColumnRef
    return all(
        any(tbl.has_column(col_ref.col) for tbl in tbls)
        for col_ref in self.subexprs(ColumnRef)
    )
|
|
227
227
|
|
|
@@ -235,7 +235,7 @@ class Expr(abc.ABC):
|
|
|
235
235
|
self.components[i] = self.components[i]._retarget(tbl_versions)
|
|
236
236
|
return self
|
|
237
237
|
|
|
238
|
-
def
|
|
238
|
+
def __repr__(self) -> str:
|
|
239
239
|
return f'<Expression of type {type(self)}>'
|
|
240
240
|
|
|
241
241
|
def display_str(self, inline: bool = True) -> str:
|
|
@@ -450,7 +450,13 @@ class Expr(abc.ABC):
|
|
|
450
450
|
|
|
451
451
|
def astype(self, new_type: Union[ts.ColumnType, type, _AnnotatedAlias]) -> 'exprs.TypeCast':
|
|
452
452
|
from pixeltable.exprs import TypeCast
|
|
453
|
-
|
|
453
|
+
# Interpret the type argument the same way we would if given in a schema
|
|
454
|
+
col_type = ts.ColumnType.normalize_type(new_type, nullable_default=True, allow_builtin_types=False)
|
|
455
|
+
if not self.col_type.nullable:
|
|
456
|
+
# This expression is non-nullable; we can prove that the output is non-nullable, regardless of
|
|
457
|
+
# whether new_type is given as nullable.
|
|
458
|
+
col_type = col_type.copy(nullable=False)
|
|
459
|
+
return TypeCast(self, col_type)
|
|
454
460
|
|
|
455
461
|
def apply(self, fn: Callable, *, col_type: Union[ts.ColumnType, type, _AnnotatedAlias, None] = None) -> 'exprs.FunctionCall':
|
|
456
462
|
if col_type is not None:
|
pixeltable/exprs/expr_set.py
CHANGED
|
@@ -60,6 +60,14 @@ class ExprSet(Generic[T]):
|
|
|
60
60
|
def __le__(self, other: ExprSet[T]) -> bool:
|
|
61
61
|
return other.issuperset(self)
|
|
62
62
|
|
|
63
|
+
def union(self, *others: Iterable[T]) -> ExprSet[T]:
    """Return a new ExprSet with the exprs of this set plus those of every iterable in others."""
    # start from our own exprs, then fold the others into the new set (update() is called on result only)
    result = ExprSet(self.exprs.values())
    result.update(*others)
    return result
|
|
67
|
+
|
|
68
|
+
def __or__(self, other: ExprSet[T]) -> ExprSet[T]:
    """Operator form of union(): a | b."""
    return self.union(other)
|
|
70
|
+
|
|
63
71
|
def difference(self, *others: Iterable[T]) -> ExprSet[T]:
|
|
64
72
|
id_diff = set(self.exprs.keys()).difference(e.id for other_set in others for e in other_set)
|
|
65
73
|
return ExprSet(self.exprs[id] for id in id_diff)
|
|
@@ -85,7 +85,9 @@ class FunctionCall(Expr):
|
|
|
85
85
|
# we record the types of non-variable parameters for runtime type checks
|
|
86
86
|
self.arg_types = []
|
|
87
87
|
self.kwarg_types = {}
|
|
88
|
+
|
|
88
89
|
# the prefix of parameters that are bound can be passed by position
|
|
90
|
+
processed_args: set[str] = set()
|
|
89
91
|
for py_param in fn.signature.py_signature.parameters.values():
|
|
90
92
|
if py_param.name not in bound_args or py_param.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
91
93
|
break
|
|
@@ -97,18 +99,19 @@ class FunctionCall(Expr):
|
|
|
97
99
|
self.args.append((None, arg))
|
|
98
100
|
if py_param.kind != inspect.Parameter.VAR_POSITIONAL and py_param.kind != inspect.Parameter.VAR_KEYWORD:
|
|
99
101
|
self.arg_types.append(signature.parameters[py_param.name].col_type)
|
|
102
|
+
processed_args.add(py_param.name)
|
|
100
103
|
|
|
101
104
|
# the remaining args are passed as keywords
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
105
|
+
for param_name in bound_args.keys():
|
|
106
|
+
if param_name not in processed_args:
|
|
107
|
+
arg = bound_args[param_name]
|
|
108
|
+
if isinstance(arg, Expr):
|
|
109
|
+
self.kwargs[param_name] = (len(self.components), None)
|
|
110
|
+
self.components.append(arg.copy())
|
|
111
|
+
else:
|
|
112
|
+
self.kwargs[param_name] = (None, arg)
|
|
113
|
+
if fn.signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
|
|
114
|
+
self.kwarg_types[param_name] = signature.parameters[param_name].col_type
|
|
112
115
|
|
|
113
116
|
# window function state:
|
|
114
117
|
# self.components[self.group_by_start_idx:self.group_by_stop_idx] contains group_by exprs
|
|
@@ -255,7 +258,7 @@ class FunctionCall(Expr):
|
|
|
255
258
|
('order_by_start_idx', self.order_by_start_idx)
|
|
256
259
|
]
|
|
257
260
|
|
|
258
|
-
def
|
|
261
|
+
def __repr__(self) -> str:
|
|
259
262
|
return self.display_str()
|
|
260
263
|
|
|
261
264
|
def display_str(self, inline: bool = True) -> str:
|
pixeltable/exprs/in_predicate.py
CHANGED
|
@@ -61,7 +61,7 @@ class InPredicate(Expr):
|
|
|
61
61
|
pass
|
|
62
62
|
return result
|
|
63
63
|
|
|
64
|
-
def
|
|
64
|
+
def __repr__(self) -> str:
|
|
65
65
|
if self.value_list is not None:
|
|
66
66
|
return f'{self.components[0]}.isin({self.value_list})'
|
|
67
67
|
return f'{self.components[0]}.isin({self.components[1]})'
|
pixeltable/exprs/inline_expr.py
CHANGED
|
@@ -56,7 +56,7 @@ class InlineArray(Expr):
|
|
|
56
56
|
self.components.extend(exprs)
|
|
57
57
|
self.id = self._create_id()
|
|
58
58
|
|
|
59
|
-
def
|
|
59
|
+
def __repr__(self) -> str:
|
|
60
60
|
elem_strs = [str(expr) for expr in self.components]
|
|
61
61
|
return f'[{", ".join(elem_strs)}]'
|
|
62
62
|
|
|
@@ -105,7 +105,7 @@ class InlineList(Expr):
|
|
|
105
105
|
self.components.extend(exprs)
|
|
106
106
|
self.id = self._create_id()
|
|
107
107
|
|
|
108
|
-
def
|
|
108
|
+
def __repr__(self) -> str:
|
|
109
109
|
elem_strs = [str(expr) for expr in self.components]
|
|
110
110
|
return f'[{", ".join(elem_strs)}]'
|
|
111
111
|
|
|
@@ -153,7 +153,7 @@ class InlineDict(Expr):
|
|
|
153
153
|
self.components.extend(exprs)
|
|
154
154
|
self.id = self._create_id()
|
|
155
155
|
|
|
156
|
-
def
|
|
156
|
+
def __repr__(self) -> str:
|
|
157
157
|
item_strs = list(f"'{key}': {str(expr)}" for key, expr in zip(self.keys, self.components))
|
|
158
158
|
return '{' + ', '.join(item_strs) + '}'
|
|
159
159
|
|
pixeltable/exprs/is_null.py
CHANGED
pixeltable/exprs/json_mapper.py
CHANGED
|
@@ -69,7 +69,7 @@ class JsonMapper(Expr):
|
|
|
69
69
|
return False
|
|
70
70
|
return self._src_expr.equals(other._src_expr) and self._target_expr.equals(other._target_expr)
|
|
71
71
|
|
|
72
|
-
def
|
|
72
|
+
def __repr__(self) -> str:
|
|
73
73
|
return f'{str(self._src_expr)} >> {str(self._target_expr)}'
|
|
74
74
|
|
|
75
75
|
@property
|
pixeltable/exprs/json_path.py
CHANGED
|
@@ -42,7 +42,7 @@ class JsonPath(Expr):
|
|
|
42
42
|
# this is not a problem, because _create_id() shouldn't be called after init()
|
|
43
43
|
self.id = self._create_id()
|
|
44
44
|
|
|
45
|
-
def
|
|
45
|
+
def __repr__(self) -> str:
|
|
46
46
|
# else "R": the anchor is RELATIVE_PATH_ROOT
|
|
47
47
|
return (f'{str(self._anchor) if self._anchor is not None else "R"}'
|
|
48
48
|
f'{"." if isinstance(self.path_elements[0], str) else ""}{self._json_path()}')
|
pixeltable/exprs/method_ref.py
CHANGED
pixeltable/exprs/rowid_ref.py
CHANGED
|
@@ -55,7 +55,7 @@ class RowidRef(Expr):
|
|
|
55
55
|
return super()._id_attrs() +\
|
|
56
56
|
[('normalized_base_id', self.normalized_base_id), ('idx', self.rowid_component_idx)]
|
|
57
57
|
|
|
58
|
-
def
|
|
58
|
+
def __repr__(self) -> str:
|
|
59
59
|
# check if this is the pos column of a component view
|
|
60
60
|
tbl = self.tbl if self.tbl is not None else catalog.Catalog.get().tbl_versions[(self.tbl_id, None)]
|
|
61
61
|
if tbl.is_component_view() and self.rowid_component_idx == tbl.store_tbl.pos_col_idx: # type: ignore[attr-defined]
|
|
@@ -55,7 +55,7 @@ class SimilarityExpr(Expr):
|
|
|
55
55
|
f'Embedding index {self.idx_info.name!r} on column {self.idx_info.col.name!r} was created without the '
|
|
56
56
|
f"'image_embed' parameter and does not support image queries")
|
|
57
57
|
|
|
58
|
-
def
|
|
58
|
+
def __repr__(self) -> str:
|
|
59
59
|
return f'{self.components[0]}.similarity({self.components[1]})'
|
|
60
60
|
|
|
61
61
|
def default_column_name(self) -> str:
|
|
@@ -17,6 +17,10 @@ class SqlElementCache:
|
|
|
17
17
|
for e, el in elements.items():
|
|
18
18
|
self.cache[e.id] = el
|
|
19
19
|
|
|
20
|
+
def extend(self, elements: 'ExprDict[sql.ColumnElement]') -> None:
    """Add the given sql.ColumnElements to the cache.

    Each element is recorded under the id of its Expr; an entry that already exists for that id is
    overwritten.

    Args:
        elements: the Exprs and their sql.ColumnElements
    """
    self.cache.update((e.id, el) for e, el in elements.items())
|
|
23
|
+
|
|
20
24
|
def get(self, e: Expr) -> Optional[sql.ColumnElement]:
|
|
21
25
|
"""Returns the sql.ColumnElement for the given Expr, or None if Expr.to_sql() returns None."""
|
|
22
26
|
try:
|
pixeltable/exprs/type_cast.py
CHANGED
|
@@ -51,5 +51,5 @@ class TypeCast(Expr):
|
|
|
51
51
|
assert len(components) == 1
|
|
52
52
|
return cls(components[0], ts.ColumnType.from_dict(d['new_type']))
|
|
53
53
|
|
|
54
|
-
def
|
|
55
|
-
return f'{self._underlying}.astype({self.col_type})'
|
|
54
|
+
def __repr__(self) -> str:
|
|
55
|
+
return f'{self._underlying}.astype({self.col_type._to_str(as_schema=True)})'
|
pixeltable/exprs/variable.py
CHANGED
pixeltable/functions/ollama.py
CHANGED
|
@@ -68,7 +68,7 @@ def generate(
|
|
|
68
68
|
raw=raw,
|
|
69
69
|
format=format,
|
|
70
70
|
options=options,
|
|
71
|
-
) # type: ignore[call-overload]
|
|
71
|
+
).dict() # type: ignore[call-overload]
|
|
72
72
|
|
|
73
73
|
|
|
74
74
|
@pxt.udf
|
|
@@ -103,7 +103,7 @@ def chat(
|
|
|
103
103
|
tools=tools,
|
|
104
104
|
format=format,
|
|
105
105
|
options=options,
|
|
106
|
-
) # type: ignore[call-overload]
|
|
106
|
+
).dict() # type: ignore[call-overload]
|
|
107
107
|
|
|
108
108
|
|
|
109
109
|
@pxt.udf(batch_size=16)
|
|
@@ -135,8 +135,8 @@ def embed(
|
|
|
135
135
|
model=model,
|
|
136
136
|
input=input,
|
|
137
137
|
truncate=truncate,
|
|
138
|
-
options=options,
|
|
139
|
-
)
|
|
138
|
+
options=options,
|
|
139
|
+
).dict()
|
|
140
140
|
return [np.array(data, dtype=np.float64) for data in results['embeddings']]
|
|
141
141
|
|
|
142
142
|
|