pixeltable 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +7 -19
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +7 -7
- pixeltable/catalog/globals.py +3 -0
- pixeltable/catalog/insertable_table.py +9 -7
- pixeltable/catalog/table.py +220 -143
- pixeltable/catalog/table_version.py +36 -18
- pixeltable/catalog/table_version_path.py +0 -8
- pixeltable/catalog/view.py +3 -3
- pixeltable/dataframe.py +9 -24
- pixeltable/env.py +107 -36
- pixeltable/exceptions.py +7 -4
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/aggregation_node.py +22 -15
- pixeltable/exec/component_iteration_node.py +62 -41
- pixeltable/exec/data_row_batch.py +7 -7
- pixeltable/exec/exec_node.py +35 -7
- pixeltable/exec/expr_eval_node.py +2 -1
- pixeltable/exec/in_memory_data_node.py +9 -9
- pixeltable/exec/sql_node.py +265 -136
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/data_row.py +30 -19
- pixeltable/exprs/expr.py +15 -14
- pixeltable/exprs/expr_dict.py +55 -0
- pixeltable/exprs/expr_set.py +21 -15
- pixeltable/exprs/function_call.py +21 -8
- pixeltable/exprs/json_path.py +3 -6
- pixeltable/exprs/rowid_ref.py +2 -2
- pixeltable/exprs/sql_element_cache.py +5 -1
- pixeltable/ext/functions/whisperx.py +7 -2
- pixeltable/func/callable_function.py +2 -2
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/query_template_function.py +11 -12
- pixeltable/func/signature.py +17 -15
- pixeltable/func/udf.py +0 -4
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/audio.py +4 -6
- pixeltable/functions/globals.py +86 -42
- pixeltable/functions/huggingface.py +12 -14
- pixeltable/functions/image.py +59 -45
- pixeltable/functions/json.py +0 -1
- pixeltable/functions/mistralai.py +2 -2
- pixeltable/functions/openai.py +22 -25
- pixeltable/functions/string.py +50 -50
- pixeltable/functions/timestamp.py +20 -20
- pixeltable/functions/together.py +26 -12
- pixeltable/functions/video.py +11 -20
- pixeltable/functions/whisper.py +2 -20
- pixeltable/globals.py +57 -56
- pixeltable/index/base.py +2 -2
- pixeltable/index/btree.py +7 -7
- pixeltable/index/embedding_index.py +8 -10
- pixeltable/io/external_store.py +11 -5
- pixeltable/io/globals.py +3 -1
- pixeltable/io/hf_datasets.py +4 -4
- pixeltable/io/label_studio.py +6 -6
- pixeltable/io/parquet.py +14 -13
- pixeltable/iterators/document.py +10 -8
- pixeltable/iterators/video.py +10 -1
- pixeltable/metadata/__init__.py +3 -2
- pixeltable/metadata/converters/convert_14.py +4 -2
- pixeltable/metadata/converters/convert_15.py +1 -1
- pixeltable/metadata/converters/convert_19.py +1 -0
- pixeltable/metadata/converters/convert_20.py +1 -1
- pixeltable/metadata/converters/util.py +9 -8
- pixeltable/metadata/schema.py +32 -21
- pixeltable/plan.py +136 -154
- pixeltable/store.py +51 -36
- pixeltable/tool/create_test_db_dump.py +7 -7
- pixeltable/tool/doc_plugins/griffe.py +3 -34
- pixeltable/tool/mypy_plugin.py +32 -0
- pixeltable/type_system.py +243 -60
- pixeltable/utils/arrow.py +10 -9
- pixeltable/utils/coco.py +4 -4
- pixeltable/utils/documents.py +1 -1
- pixeltable/utils/filecache.py +131 -84
- pixeltable/utils/formatter.py +1 -1
- pixeltable/utils/http_server.py +2 -5
- pixeltable/utils/media_store.py +6 -6
- pixeltable/utils/pytorch.py +10 -11
- pixeltable/utils/sql.py +2 -1
- {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/METADATA +16 -7
- pixeltable-0.2.21.dist-info/RECORD +148 -0
- pixeltable/utils/help.py +0 -11
- pixeltable-0.2.19.dist-info/RECORD +0 -147
- {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
pixeltable/exec/sql_node.py
CHANGED
|
@@ -1,73 +1,175 @@
|
|
|
1
|
-
from typing import List, Optional, Tuple, Iterable, Set
|
|
2
|
-
from uuid import UUID
|
|
3
1
|
import logging
|
|
4
2
|
import warnings
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from typing import Optional, Iterable, Iterator, NamedTuple
|
|
5
|
+
from uuid import UUID
|
|
5
6
|
|
|
6
7
|
import sqlalchemy as sql
|
|
7
8
|
|
|
9
|
+
import pixeltable.catalog as catalog
|
|
10
|
+
import pixeltable.exprs as exprs
|
|
8
11
|
from .data_row_batch import DataRowBatch
|
|
9
12
|
from .exec_node import ExecNode
|
|
10
|
-
import pixeltable.exprs as exprs
|
|
11
|
-
import pixeltable.catalog as catalog
|
|
12
|
-
|
|
13
13
|
|
|
14
14
|
_logger = logging.getLogger('pixeltable')
|
|
15
15
|
|
|
16
|
+
|
|
17
|
+
class OrderByItem(NamedTuple):
|
|
18
|
+
expr: exprs.Expr
|
|
19
|
+
asc: Optional[bool]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
OrderByClause = list[OrderByItem]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
|
|
26
|
+
"""Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
|
|
27
|
+
Two clauses are compatible if for each of their respective items c1[i] and c2[i]
|
|
28
|
+
a) the exprs are identical and
|
|
29
|
+
b) the asc values are identical or at least one is None (None serves as a wildcard)
|
|
30
|
+
"""
|
|
31
|
+
result: OrderByClause = []
|
|
32
|
+
for clause in clauses:
|
|
33
|
+
combined: OrderByClause = []
|
|
34
|
+
for item1, item2 in zip(result, clause):
|
|
35
|
+
if item1.expr.id != item2.expr.id:
|
|
36
|
+
return None
|
|
37
|
+
if item1.asc is not None and item2.asc is not None and item1.asc != item2.asc:
|
|
38
|
+
return None
|
|
39
|
+
asc = item1.asc if item1.asc is not None else item2.asc
|
|
40
|
+
combined.append(OrderByItem(item1.expr, asc))
|
|
41
|
+
|
|
42
|
+
# add remaining ordering of the longer list
|
|
43
|
+
prefix_len = min(len(result), len(clause))
|
|
44
|
+
if len(result) > prefix_len:
|
|
45
|
+
combined.extend(result[prefix_len:])
|
|
46
|
+
elif len(clause) > prefix_len:
|
|
47
|
+
combined.extend(clause[prefix_len:])
|
|
48
|
+
result = combined
|
|
49
|
+
return result
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def print_order_by_clause(clause: OrderByClause) -> str:
|
|
53
|
+
return ', '.join([
|
|
54
|
+
f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
|
|
55
|
+
for item in clause
|
|
56
|
+
])
|
|
57
|
+
|
|
58
|
+
|
|
16
59
|
class SqlNode(ExecNode):
|
|
17
|
-
"""
|
|
60
|
+
"""
|
|
61
|
+
Materializes data from the store via a Select stmt.
|
|
62
|
+
This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
tbl: Optional[catalog.TableVersionPath]
|
|
66
|
+
select_list: exprs.ExprSet
|
|
67
|
+
set_pk: bool
|
|
68
|
+
num_pk_cols: int
|
|
69
|
+
filter: Optional[exprs.Expr]
|
|
70
|
+
filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
|
|
71
|
+
cte: Optional[sql.CTE]
|
|
72
|
+
sql_elements: exprs.SqlElementCache
|
|
73
|
+
limit: Optional[int]
|
|
74
|
+
order_by_clause: OrderByClause
|
|
18
75
|
|
|
19
76
|
def __init__(
|
|
20
|
-
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
77
|
+
self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
|
|
21
78
|
select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
|
|
22
79
|
):
|
|
23
80
|
"""
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
This only provides the select list. The subclass is responsible for the From clause and any additional clauses.
|
|
81
|
+
If row_builder contains references to unstored iter columns, expands the select list to include their
|
|
82
|
+
SQL-materializable subexpressions.
|
|
27
83
|
|
|
28
84
|
Args:
|
|
29
85
|
select_list: output of the query
|
|
30
86
|
set_pk: if True, sets the primary for each DataRow
|
|
31
87
|
"""
|
|
32
88
|
# create Select stmt
|
|
89
|
+
self.sql_elements = sql_elements
|
|
33
90
|
self.tbl = tbl
|
|
34
|
-
|
|
35
|
-
self.sql_exprs = exprs.ExprSet(select_list)
|
|
91
|
+
self.select_list = exprs.ExprSet(select_list)
|
|
36
92
|
# unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
|
|
37
93
|
for iter_arg in row_builder.unstored_iter_args.values():
|
|
38
|
-
sql_subexprs = iter_arg.subexprs(filter=sql_elements.contains, traverse_matches=False)
|
|
94
|
+
sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
|
|
39
95
|
for e in sql_subexprs:
|
|
40
|
-
self.
|
|
41
|
-
super().__init__(row_builder, self.
|
|
96
|
+
self.select_list.add(e)
|
|
97
|
+
super().__init__(row_builder, self.select_list, [], None) # we materialize self.select_list
|
|
42
98
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
rowid_ref.set_tbl(tbl)
|
|
99
|
+
if tbl is not None:
|
|
100
|
+
# minimize the number of tables that need to be joined to the target table
|
|
101
|
+
self.retarget_rowid_refs(tbl, self.select_list)
|
|
47
102
|
|
|
48
|
-
|
|
49
|
-
assert len(sql_select_list) == len(self.sql_exprs)
|
|
50
|
-
assert all(e is not None for e in sql_select_list)
|
|
103
|
+
assert self.sql_elements.contains(self.select_list)
|
|
51
104
|
self.set_pk = set_pk
|
|
52
105
|
self.num_pk_cols = 0
|
|
53
106
|
if set_pk:
|
|
54
107
|
# we also need to retrieve the pk columns
|
|
55
|
-
|
|
56
|
-
self.num_pk_cols = len(pk_columns)
|
|
57
|
-
sql_select_list += pk_columns
|
|
58
|
-
|
|
59
|
-
self.stmt = sql.select(*sql_select_list)
|
|
108
|
+
assert tbl is not None
|
|
109
|
+
self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
|
|
60
110
|
|
|
61
111
|
# additional state
|
|
62
|
-
self.result_cursor
|
|
112
|
+
self.result_cursor = None
|
|
63
113
|
# the filter is provided by the subclass
|
|
64
|
-
self.filter
|
|
65
|
-
self.filter_eval_ctx
|
|
114
|
+
self.filter = None
|
|
115
|
+
self.filter_eval_ctx = None
|
|
116
|
+
self.cte = None
|
|
117
|
+
self.limit = None
|
|
118
|
+
self.order_by_clause = []
|
|
119
|
+
|
|
120
|
+
def _create_stmt(self) -> sql.Select:
|
|
121
|
+
"""Create Select from local state"""
|
|
122
|
+
|
|
123
|
+
assert self.sql_elements.contains(self.select_list)
|
|
124
|
+
sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
|
|
125
|
+
if self.set_pk:
|
|
126
|
+
sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
|
|
127
|
+
stmt = sql.select(*sql_select_list)
|
|
128
|
+
|
|
129
|
+
order_by_clause: list[sql.ClauseElement] = []
|
|
130
|
+
for e, asc in self.order_by_clause:
|
|
131
|
+
if isinstance(e, exprs.SimilarityExpr):
|
|
132
|
+
order_by_clause.append(e.as_order_by_clause(asc))
|
|
133
|
+
else:
|
|
134
|
+
order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
|
|
135
|
+
stmt = stmt.order_by(*order_by_clause)
|
|
136
|
+
|
|
137
|
+
if self.filter is None and self.limit is not None:
|
|
138
|
+
# if we don't have a Python filter, we can apply the limit to stmt
|
|
139
|
+
stmt = stmt.limit(self.limit)
|
|
140
|
+
|
|
141
|
+
return stmt
|
|
142
|
+
|
|
143
|
+
def _ordering_tbl_ids(self) -> set[UUID]:
|
|
144
|
+
return exprs.Expr.list_tbl_ids(e for e, _ in self.order_by_clause)
|
|
145
|
+
|
|
146
|
+
def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
|
|
147
|
+
"""
|
|
148
|
+
Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
|
|
149
|
+
|
|
150
|
+
Returns:
|
|
151
|
+
(CTE, dict from Expr to output column)
|
|
152
|
+
"""
|
|
153
|
+
if self.filter is not None:
|
|
154
|
+
# the filter needs to run in Python
|
|
155
|
+
return None
|
|
156
|
+
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
157
|
+
if self.cte is None:
|
|
158
|
+
self.cte = self._create_stmt().cte()
|
|
159
|
+
assert len(self.cte.c) == len(self.select_list)
|
|
160
|
+
return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
|
|
161
|
+
|
|
162
|
+
@classmethod
|
|
163
|
+
def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
|
|
164
|
+
"""Change rowid refs to point to target"""
|
|
165
|
+
for e in expr_seq:
|
|
166
|
+
if isinstance(e, exprs.RowidRef):
|
|
167
|
+
e.set_tbl(target)
|
|
66
168
|
|
|
67
169
|
@classmethod
|
|
68
170
|
def create_from_clause(
|
|
69
|
-
cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[
|
|
70
|
-
exact_version_only: Optional[
|
|
171
|
+
cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[set[UUID]] = None,
|
|
172
|
+
exact_version_only: Optional[set[UUID]] = None
|
|
71
173
|
) -> sql.Select:
|
|
72
174
|
"""Add From clause to stmt for tables/views referenced by materialized_exprs
|
|
73
175
|
Args:
|
|
@@ -85,7 +187,7 @@ class SqlNode(ExecNode):
|
|
|
85
187
|
exact_version_only = {}
|
|
86
188
|
candidates = tbl.get_tbl_versions()
|
|
87
189
|
assert len(candidates) > 0
|
|
88
|
-
joined_tbls:
|
|
190
|
+
joined_tbls: list[catalog.TableVersion] = [candidates[0]]
|
|
89
191
|
for tbl in candidates[1:]:
|
|
90
192
|
if tbl.id in refd_tbl_ids:
|
|
91
193
|
joined_tbls.append(tbl)
|
|
@@ -111,72 +213,97 @@ class SqlNode(ExecNode):
|
|
|
111
213
|
prev_tbl = tbl
|
|
112
214
|
return stmt
|
|
113
215
|
|
|
114
|
-
def
|
|
216
|
+
def add_order_by(self, ordering: OrderByClause) -> None:
|
|
217
|
+
"""Add Order By clause to stmt"""
|
|
218
|
+
if self.tbl is not None:
|
|
219
|
+
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
220
|
+
# the number of tables that need to be joined to the target table
|
|
221
|
+
self.retarget_rowid_refs(self.tbl, [e for e, _ in ordering])
|
|
222
|
+
combined = combine_order_by_clauses([self.order_by_clause, ordering])
|
|
223
|
+
assert combined is not None
|
|
224
|
+
self.order_by_clause = combined
|
|
225
|
+
|
|
226
|
+
def set_limit(self, limit: int) -> None:
|
|
227
|
+
self.limit = limit
|
|
228
|
+
|
|
229
|
+
def _log_explain(self, stmt: sql.Select) -> None:
|
|
115
230
|
try:
|
|
116
231
|
# don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
|
|
117
|
-
stmt_str = str(
|
|
232
|
+
stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
118
233
|
explain_result = self.ctx.conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
|
|
119
234
|
explain_str = '\n'.join([str(row) for row in explain_result])
|
|
120
235
|
_logger.debug(f'SqlScanNode explain:\n{explain_str}')
|
|
121
236
|
except Exception as e:
|
|
122
237
|
_logger.warning(f'EXPLAIN failed')
|
|
123
238
|
|
|
124
|
-
def
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
self.
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
self.
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
self.has_more_rows = False
|
|
149
|
-
break
|
|
239
|
+
def __iter__(self) -> Iterator[DataRowBatch]:
|
|
240
|
+
# run the query; do this here rather than in _open(), exceptions are only expected during iteration
|
|
241
|
+
assert self.ctx.conn is not None
|
|
242
|
+
try:
|
|
243
|
+
with warnings.catch_warnings(record=True) as w:
|
|
244
|
+
stmt = self._create_stmt()
|
|
245
|
+
try:
|
|
246
|
+
# log stmt, if possible
|
|
247
|
+
stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
248
|
+
_logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
|
|
249
|
+
except Exception as e:
|
|
250
|
+
pass
|
|
251
|
+
self._log_explain(stmt)
|
|
252
|
+
|
|
253
|
+
result_cursor = self.ctx.conn.execute(stmt)
|
|
254
|
+
for warning in w:
|
|
255
|
+
pass
|
|
256
|
+
except Exception as e:
|
|
257
|
+
raise e
|
|
258
|
+
|
|
259
|
+
tbl_version = self.tbl.tbl_version if self.tbl is not None else None
|
|
260
|
+
output_batch = DataRowBatch(tbl_version, self.row_builder)
|
|
261
|
+
output_row: Optional[exprs.DataRow] = None
|
|
262
|
+
num_rows_returned = 0
|
|
150
263
|
|
|
151
|
-
|
|
152
|
-
|
|
264
|
+
for sql_row in result_cursor:
|
|
265
|
+
output_row = output_batch.add_row(output_row)
|
|
266
|
+
|
|
267
|
+
# populate output_row
|
|
153
268
|
if self.num_pk_cols > 0:
|
|
154
269
|
output_row.set_pk(tuple(sql_row[-self.num_pk_cols:]))
|
|
155
270
|
# copy the output of the SQL query into the output row
|
|
156
|
-
for i, e in enumerate(self.
|
|
271
|
+
for i, e in enumerate(self.select_list):
|
|
157
272
|
slot_idx = e.slot_idx
|
|
158
|
-
|
|
273
|
+
# certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
|
|
274
|
+
if isinstance(sql_row[i], Decimal):
|
|
275
|
+
if e.col_type.is_int_type():
|
|
276
|
+
output_row[slot_idx] = int(sql_row[i])
|
|
277
|
+
elif e.col_type.is_float_type():
|
|
278
|
+
output_row[slot_idx] = float(sql_row[i])
|
|
279
|
+
else:
|
|
280
|
+
raise RuntimeError(f'Unexpected Decimal value for {e}')
|
|
281
|
+
else:
|
|
282
|
+
output_row[slot_idx] = sql_row[i]
|
|
283
|
+
|
|
159
284
|
if self.filter is not None:
|
|
285
|
+
# evaluate filter
|
|
160
286
|
self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
output_row.clear()
|
|
287
|
+
if self.filter is not None and not output_row[self.filter.slot_idx]:
|
|
288
|
+
# we re-use this row for the next sql row since it didn't pass the filter
|
|
289
|
+
output_row = output_batch.pop_row()
|
|
290
|
+
output_row.clear()
|
|
291
|
+
else:
|
|
292
|
+
# reset output_row in order to add new one
|
|
293
|
+
output_row = None
|
|
294
|
+
num_rows_returned += 1
|
|
170
295
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
296
|
+
if self.limit is not None and num_rows_returned == self.limit:
|
|
297
|
+
break
|
|
298
|
+
|
|
299
|
+
if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
|
|
300
|
+
_logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
|
|
301
|
+
yield output_batch
|
|
302
|
+
output_batch = DataRowBatch(tbl_version, self.row_builder)
|
|
175
303
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
return output_batch
|
|
304
|
+
if len(output_batch) > 0:
|
|
305
|
+
_logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
|
|
306
|
+
yield output_batch
|
|
180
307
|
|
|
181
308
|
def _close(self) -> None:
|
|
182
309
|
if self.result_cursor is not None:
|
|
@@ -189,12 +316,14 @@ class SqlScanNode(SqlNode):
|
|
|
189
316
|
|
|
190
317
|
Supports filtering and ordering.
|
|
191
318
|
"""
|
|
319
|
+
where_clause: Optional[exprs.Expr]
|
|
320
|
+
exact_version_only: list[catalog.TableVersion]
|
|
321
|
+
|
|
192
322
|
def __init__(
|
|
193
323
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
194
324
|
select_list: Iterable[exprs.Expr],
|
|
195
325
|
where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
|
|
196
|
-
|
|
197
|
-
limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
|
|
326
|
+
set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
198
327
|
):
|
|
199
328
|
"""
|
|
200
329
|
Args:
|
|
@@ -208,52 +337,29 @@ class SqlScanNode(SqlNode):
|
|
|
208
337
|
sql_elements = exprs.SqlElementCache()
|
|
209
338
|
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
|
|
210
339
|
# create Select stmt
|
|
211
|
-
if order_by_items is None:
|
|
212
|
-
order_by_items = []
|
|
213
340
|
if exact_version_only is None:
|
|
214
341
|
exact_version_only = []
|
|
215
342
|
target = tbl.tbl_version # the stored table we're scanning
|
|
216
343
|
self.filter = filter
|
|
217
344
|
self.filter_eval_ctx = \
|
|
218
345
|
row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
|
|
219
|
-
self.limit = limit
|
|
220
346
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
self.stmt = self.create_from_clause(
|
|
224
|
-
tbl, self.stmt, refd_tbl_ids, exact_version_only={t.id for t in exact_version_only})
|
|
225
|
-
|
|
226
|
-
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
227
|
-
# the number of tables that need to be joined to the target table
|
|
228
|
-
for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
|
|
229
|
-
rowid_ref.set_tbl(tbl)
|
|
230
|
-
order_by_clause: List[sql.ClauseElement] = []
|
|
231
|
-
for e, asc in order_by_items:
|
|
232
|
-
if isinstance(e, exprs.SimilarityExpr):
|
|
233
|
-
order_by_clause.append(e.as_order_by_clause(asc))
|
|
234
|
-
else:
|
|
235
|
-
order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
|
|
347
|
+
self.where_clause = where_clause
|
|
348
|
+
self.exact_version_only = exact_version_only
|
|
236
349
|
|
|
237
|
-
|
|
238
|
-
|
|
350
|
+
def _create_stmt(self) -> sql.Select:
|
|
351
|
+
stmt = super()._create_stmt()
|
|
352
|
+
where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
|
|
353
|
+
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
|
|
354
|
+
stmt = self.create_from_clause(
|
|
355
|
+
self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
|
|
356
|
+
|
|
357
|
+
if self.where_clause is not None:
|
|
358
|
+
sql_where_clause = self.sql_elements.get(self.where_clause)
|
|
239
359
|
assert sql_where_clause is not None
|
|
240
|
-
|
|
241
|
-
if len(order_by_clause) > 0:
|
|
242
|
-
self.stmt = self.stmt.order_by(*order_by_clause)
|
|
243
|
-
elif target.id in row_builder.unstored_iter_args:
|
|
244
|
-
# we are referencing unstored iter columns from this view and try to order by our primary key,
|
|
245
|
-
# which ensures that iterators will see monotonically increasing pos values
|
|
246
|
-
self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
|
|
247
|
-
if limit != 0 and self.filter is None:
|
|
248
|
-
# if we need to do post-SQL filtering, we can't use LIMIT
|
|
249
|
-
self.stmt = self.stmt.limit(limit)
|
|
360
|
+
stmt = stmt.where(sql_where_clause)
|
|
250
361
|
|
|
251
|
-
|
|
252
|
-
# log stmt, if possible
|
|
253
|
-
stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
254
|
-
_logger.debug(f'SqlScanNode stmt:\n{stmt_str}')
|
|
255
|
-
except Exception as e:
|
|
256
|
-
pass
|
|
362
|
+
return stmt
|
|
257
363
|
|
|
258
364
|
|
|
259
365
|
class SqlLookupNode(SqlNode):
|
|
@@ -261,8 +367,7 @@ class SqlLookupNode(SqlNode):
|
|
|
261
367
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
262
368
|
"""
|
|
263
369
|
|
|
264
|
-
|
|
265
|
-
where_clause: sql.ColumnElement[bool]
|
|
370
|
+
where_clause: sql.ColumnElement
|
|
266
371
|
|
|
267
372
|
def __init__(
|
|
268
373
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
@@ -276,21 +381,45 @@ class SqlLookupNode(SqlNode):
|
|
|
276
381
|
"""
|
|
277
382
|
sql_elements = exprs.SqlElementCache()
|
|
278
383
|
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
|
|
279
|
-
target = tbl.tbl_version # the stored table we're scanning
|
|
280
|
-
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
|
|
281
|
-
self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
|
|
282
384
|
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
283
385
|
self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
284
|
-
self.stmt = self.stmt.where(self.where_clause)
|
|
285
386
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
387
|
+
def _create_stmt(self) -> sql.Select:
|
|
388
|
+
stmt = super()._create_stmt()
|
|
389
|
+
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.select_list) | self._ordering_tbl_ids()
|
|
390
|
+
stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
|
|
391
|
+
stmt = stmt.where(self.where_clause)
|
|
392
|
+
return stmt
|
|
290
393
|
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
394
|
+
class SqlAggregationNode(SqlNode):
|
|
395
|
+
"""
|
|
396
|
+
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
397
|
+
"""
|
|
398
|
+
|
|
399
|
+
group_by_items: Optional[list[exprs.Expr]]
|
|
400
|
+
|
|
401
|
+
def __init__(
|
|
402
|
+
self, row_builder: exprs.RowBuilder,
|
|
403
|
+
input: SqlNode,
|
|
404
|
+
select_list: Iterable[exprs.Expr],
|
|
405
|
+
group_by_items: Optional[list[exprs.Expr]] = None,
|
|
406
|
+
limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
407
|
+
):
|
|
408
|
+
"""
|
|
409
|
+
Args:
|
|
410
|
+
select_list: can contain calls to AggregateFunctions
|
|
411
|
+
group_by_items: list of expressions to group by
|
|
412
|
+
limit: max number of rows to return: None = no limit
|
|
413
|
+
"""
|
|
414
|
+
_, input_col_map = input.to_cte()
|
|
415
|
+
sql_elements = exprs.SqlElementCache(input_col_map)
|
|
416
|
+
super().__init__(None, row_builder, select_list, sql_elements)
|
|
417
|
+
self.group_by_items = group_by_items
|
|
418
|
+
|
|
419
|
+
def _create_stmt(self) -> sql.Select:
|
|
420
|
+
stmt = super()._create_stmt()
|
|
421
|
+
if self.group_by_items is not None:
|
|
422
|
+
sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
|
|
423
|
+
assert all(e is not None for e in sql_group_by_items)
|
|
424
|
+
stmt = stmt.group_by(*sql_group_by_items)
|
|
425
|
+
return stmt
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ from .comparison import Comparison
|
|
|
6
6
|
from .compound_predicate import CompoundPredicate
|
|
7
7
|
from .data_row import DataRow
|
|
8
8
|
from .expr import Expr
|
|
9
|
+
from .expr_dict import ExprDict
|
|
9
10
|
from .expr_set import ExprSet
|
|
10
11
|
from .function_call import FunctionCall
|
|
11
12
|
from .in_predicate import InPredicate
|
pixeltable/exprs/data_row.py
CHANGED
|
@@ -33,29 +33,40 @@ class DataRow:
|
|
|
33
33
|
- ImageType: PIL.Image.Image
|
|
34
34
|
- VideoType: local path if available, otherwise url
|
|
35
35
|
"""
|
|
36
|
-
def __init__(self, size: int, img_slot_idxs: List[int], media_slot_idxs: List[int], array_slot_idxs: List[int]):
|
|
37
|
-
self.vals: List[Any] = [None] * size # either cell values or exceptions
|
|
38
|
-
self.has_val = [False] * size
|
|
39
|
-
self.excs: List[Optional[Exception]] = [None] * size
|
|
40
36
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
self.array_slot_idxs = array_slot_idxs
|
|
37
|
+
vals: list[Any]
|
|
38
|
+
has_val: list[bool]
|
|
39
|
+
excs: list[Optional[Exception]]
|
|
45
40
|
|
|
46
|
-
|
|
47
|
-
|
|
41
|
+
# control structures that are shared across all DataRows in a batch
|
|
42
|
+
img_slot_idxs: list[int]
|
|
43
|
+
media_slot_idxs: list[int]
|
|
44
|
+
array_slot_idxs: list[int]
|
|
48
45
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
# - None if vals[i] is not media type
|
|
52
|
-
# - not None if file_paths[i] is not None
|
|
53
|
-
self.file_urls: List[Optional[str]] = [None] * size
|
|
46
|
+
# the primary key of a store row is a sequence of ints (the number is different for table vs view)
|
|
47
|
+
pk: Optional[tuple[int, ...]]
|
|
54
48
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
49
|
+
# file_urls:
|
|
50
|
+
# - stored url of file for media in vals[i]
|
|
51
|
+
# - None if vals[i] is not media type
|
|
52
|
+
# - not None if file_paths[i] is not None
|
|
53
|
+
file_urls: list[Optional[str]]
|
|
54
|
+
|
|
55
|
+
# file_paths:
|
|
56
|
+
# - local path of media file in vals[i]; points to the file cache if file_urls[i] is remote
|
|
57
|
+
# - None if vals[i] is not a media type or if there is no local file yet for file_urls[i]
|
|
58
|
+
file_paths: list[Optional[str]]
|
|
59
|
+
|
|
60
|
+
def __init__(self, size: int, img_slot_idxs: List[int], media_slot_idxs: List[int], array_slot_idxs: List[int]):
|
|
61
|
+
self.vals = [None] * size
|
|
62
|
+
self.has_val = [False] * size
|
|
63
|
+
self.excs = [None] * size
|
|
64
|
+
self.img_slot_idxs = img_slot_idxs
|
|
65
|
+
self.media_slot_idxs = media_slot_idxs
|
|
66
|
+
self.array_slot_idxs = array_slot_idxs
|
|
67
|
+
self.pk = None
|
|
68
|
+
self.file_urls = [None] * size
|
|
69
|
+
self.file_paths = [None] * size
|
|
59
70
|
|
|
60
71
|
def clear(self) -> None:
|
|
61
72
|
size = len(self.vals)
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -7,11 +7,11 @@ import inspect
|
|
|
7
7
|
import json
|
|
8
8
|
import sys
|
|
9
9
|
import typing
|
|
10
|
-
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TypeVar, Union, overload
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TypeVar, Union, overload, Iterable
|
|
11
11
|
from uuid import UUID
|
|
12
12
|
|
|
13
13
|
import sqlalchemy as sql
|
|
14
|
-
from typing_extensions import Self
|
|
14
|
+
from typing_extensions import _AnnotatedAlias, Self
|
|
15
15
|
|
|
16
16
|
import pixeltable
|
|
17
17
|
import pixeltable.catalog as catalog
|
|
@@ -281,9 +281,10 @@ class Expr(abc.ABC):
|
|
|
281
281
|
"""
|
|
282
282
|
Iterate over all subexprs, including self.
|
|
283
283
|
"""
|
|
284
|
-
is_match =
|
|
285
|
-
|
|
286
|
-
|
|
284
|
+
is_match = isinstance(self, expr_class) if expr_class is not None else True
|
|
285
|
+
# apply filter after checking for expr_class
|
|
286
|
+
if filter is not None and is_match:
|
|
287
|
+
is_match = filter(self)
|
|
287
288
|
if not is_match or traverse_matches:
|
|
288
289
|
for c in self.components:
|
|
289
290
|
yield from c.subexprs(expr_class=expr_class, filter=filter, traverse_matches=traverse_matches)
|
|
@@ -292,7 +293,7 @@ class Expr(abc.ABC):
|
|
|
292
293
|
|
|
293
294
|
@overload
|
|
294
295
|
def list_subexprs(
|
|
295
|
-
expr_list:
|
|
296
|
+
expr_list: Iterable[Expr], *, filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
|
|
296
297
|
) -> Iterator[Expr]: ...
|
|
297
298
|
|
|
298
299
|
@overload
|
|
@@ -312,13 +313,11 @@ class Expr(abc.ABC):
|
|
|
312
313
|
|
|
313
314
|
def _contains(self, cls: Optional[type[Expr]] = None, filter: Optional[Callable[[Expr], bool]] = None) -> bool:
|
|
314
315
|
"""
|
|
315
|
-
Returns True if any subexpr is an instance of cls.
|
|
316
|
+
Returns True if any subexpr is an instance of cls and/or matches filter.
|
|
316
317
|
"""
|
|
317
|
-
assert
|
|
318
|
-
if cls is not None:
|
|
319
|
-
filter = lambda e: isinstance(e, cls)
|
|
318
|
+
assert cls is not None or filter is not None
|
|
320
319
|
try:
|
|
321
|
-
_ = next(self.subexprs(filter=filter, traverse_matches=False))
|
|
320
|
+
_ = next(self.subexprs(expr_class=cls, filter=filter, traverse_matches=False))
|
|
322
321
|
return True
|
|
323
322
|
except StopIteration:
|
|
324
323
|
return False
|
|
@@ -457,11 +456,13 @@ class Expr(abc.ABC):
|
|
|
457
456
|
else:
|
|
458
457
|
return InPredicate(self, value_set_literal=value_set)
|
|
459
458
|
|
|
460
|
-
def astype(self, new_type: ts.ColumnType) -> 'pixeltable.exprs.TypeCast':
|
|
459
|
+
def astype(self, new_type: Union[ts.ColumnType, type, _AnnotatedAlias]) -> 'pixeltable.exprs.TypeCast':
|
|
461
460
|
from pixeltable.exprs import TypeCast
|
|
462
|
-
return TypeCast(self, new_type)
|
|
461
|
+
return TypeCast(self, ts.ColumnType.normalize_type(new_type))
|
|
463
462
|
|
|
464
|
-
def apply(self, fn: Callable, *, col_type:
|
|
463
|
+
def apply(self, fn: Callable, *, col_type: Union[ts.ColumnType, type, _AnnotatedAlias, None] = None) -> 'pixeltable.exprs.FunctionCall':
|
|
464
|
+
if col_type is not None:
|
|
465
|
+
col_type = ts.ColumnType.normalize_type(col_type)
|
|
465
466
|
function = self._make_applicator_function(fn, col_type)
|
|
466
467
|
# Return a `FunctionCall` obtained by passing this `Expr` to the new `function`.
|
|
467
468
|
return function(self)
|