pixeltable 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +7 -19
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +7 -7
- pixeltable/catalog/column.py +37 -11
- pixeltable/catalog/globals.py +21 -0
- pixeltable/catalog/insertable_table.py +6 -4
- pixeltable/catalog/table.py +227 -148
- pixeltable/catalog/table_version.py +66 -28
- pixeltable/catalog/table_version_path.py +0 -8
- pixeltable/catalog/view.py +18 -19
- pixeltable/dataframe.py +16 -32
- pixeltable/env.py +6 -1
- pixeltable/exec/__init__.py +1 -2
- pixeltable/exec/aggregation_node.py +27 -17
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/data_row_batch.py +9 -26
- pixeltable/exec/exec_node.py +36 -7
- pixeltable/exec/expr_eval_node.py +19 -11
- pixeltable/exec/in_memory_data_node.py +14 -11
- pixeltable/exec/sql_node.py +266 -138
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +3 -1
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +93 -14
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +56 -36
- pixeltable/exprs/expr.py +65 -63
- pixeltable/exprs/expr_dict.py +55 -0
- pixeltable/exprs/expr_set.py +26 -15
- pixeltable/exprs/function_call.py +53 -24
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +5 -10
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +14 -13
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +12 -6
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/ext/functions/whisperx.py +7 -2
- pixeltable/func/aggregate_function.py +1 -1
- pixeltable/func/callable_function.py +2 -2
- pixeltable/func/function.py +11 -10
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/query_template_function.py +11 -12
- pixeltable/func/signature.py +17 -15
- pixeltable/func/udf.py +0 -4
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/audio.py +4 -6
- pixeltable/functions/globals.py +84 -42
- pixeltable/functions/huggingface.py +31 -34
- pixeltable/functions/image.py +59 -45
- pixeltable/functions/json.py +0 -1
- pixeltable/functions/llama_cpp.py +106 -0
- pixeltable/functions/mistralai.py +2 -2
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/openai.py +22 -25
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +59 -50
- pixeltable/functions/timestamp.py +20 -20
- pixeltable/functions/together.py +2 -2
- pixeltable/functions/video.py +11 -20
- pixeltable/functions/whisper.py +2 -20
- pixeltable/globals.py +65 -74
- pixeltable/index/base.py +2 -2
- pixeltable/index/btree.py +20 -7
- pixeltable/index/embedding_index.py +12 -14
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/external_store.py +11 -5
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +98 -2
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +6 -6
- pixeltable/io/parquet.py +14 -13
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +10 -8
- pixeltable/iterators/video.py +126 -60
- pixeltable/metadata/__init__.py +4 -3
- pixeltable/metadata/converters/convert_14.py +4 -2
- pixeltable/metadata/converters/convert_15.py +1 -1
- pixeltable/metadata/converters/convert_19.py +1 -0
- pixeltable/metadata/converters/convert_20.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +54 -12
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +40 -21
- pixeltable/plan.py +149 -165
- pixeltable/py.typed +0 -0
- pixeltable/store.py +57 -37
- pixeltable/tool/create_test_db_dump.py +6 -6
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/doc_plugins/griffe.py +3 -34
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +55 -0
- pixeltable/type_system.py +260 -61
- pixeltable/utils/arrow.py +10 -9
- pixeltable/utils/coco.py +4 -4
- pixeltable/utils/documents.py +16 -2
- pixeltable/utils/filecache.py +9 -9
- pixeltable/utils/formatter.py +10 -11
- pixeltable/utils/http_server.py +2 -5
- pixeltable/utils/media_store.py +6 -6
- pixeltable/utils/pytorch.py +10 -11
- pixeltable/utils/sql.py +2 -1
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
- pixeltable-0.2.22.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable/utils/help.py +0 -11
- pixeltable-0.2.20.dist-info/RECORD +0 -147
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
pixeltable/exec/sql_node.py
CHANGED
|
@@ -1,73 +1,176 @@
|
|
|
1
|
-
from typing import List, Optional, Tuple, Iterable, Set
|
|
2
|
-
from uuid import UUID
|
|
3
1
|
import logging
|
|
4
2
|
import warnings
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from typing import Iterable, Iterator, NamedTuple, Optional
|
|
5
|
+
from uuid import UUID
|
|
5
6
|
|
|
6
7
|
import sqlalchemy as sql
|
|
7
8
|
|
|
8
|
-
from .data_row_batch import DataRowBatch
|
|
9
|
-
from .exec_node import ExecNode
|
|
10
|
-
import pixeltable.exprs as exprs
|
|
11
9
|
import pixeltable.catalog as catalog
|
|
10
|
+
import pixeltable.exprs as exprs
|
|
12
11
|
|
|
12
|
+
from .data_row_batch import DataRowBatch
|
|
13
|
+
from .exec_node import ExecNode
|
|
13
14
|
|
|
14
15
|
_logger = logging.getLogger('pixeltable')
|
|
15
16
|
|
|
17
|
+
|
|
18
|
+
class OrderByItem(NamedTuple):
    """One entry of an Order By clause: the expr to order by plus the requested direction."""

    # the expr whose value determines the ordering
    expr: exprs.Expr
    # False: descending; True: ascending; None: no direction requested (combine_order_by_clauses() treats
    # None as a wildcard, and SqlNode._create_stmt() only emits DESC for False)
    asc: Optional[bool]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
OrderByClause = list[OrderByItem]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
    """Merges 'clauses' into a single clause that satisfies all of them, or returns None if that's impossible.

    Clauses are merged position by position. The items at position i of two clauses can be merged if
    a) they have identical exprs (same expr id) and
    b) their asc values are identical or at least one of them is None (None is a wildcard and yields to the
       other value)
    If one clause is longer than the other, the surplus items of the longer clause are appended unchanged.

    Returns:
        the merged clause (empty if 'clauses' is empty) or None if the clauses are incompatible
    """
    merged: OrderByClause = []
    for clause in clauses:
        shared_len = min(len(merged), len(clause))
        next_merged: OrderByClause = []
        for ours, theirs in zip(merged, clause):
            if ours.expr.id != theirs.expr.id:
                return None
            both_directions_set = ours.asc is not None and theirs.asc is not None
            if both_directions_set and ours.asc != theirs.asc:
                return None
            direction = theirs.asc if ours.asc is None else ours.asc
            next_merged.append(OrderByItem(ours.expr, direction))

        # at most one of these two slices is non-empty: the tail of whichever clause is longer
        next_merged.extend(merged[shared_len:])
        next_merged.extend(clause[shared_len:])
        merged = next_merged
    return merged
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def print_order_by_clause(clause: OrderByClause) -> str:
    """Returns a printable, comma-separated rendering of 'clause'.

    Each item is shown as '(<expr>)'; an explicitly set direction is appended inside the parentheses as
    ', asc=True' or ', asc=False'.
    """
    rendered: list[str] = []
    for item in clause:
        if item.asc is True:
            direction = ', asc=True'
        elif item.asc is False:
            direction = ', asc=False'
        else:
            direction = ''
        rendered.append(f'({item.expr}{direction})')
    return ', '.join(rendered)
|
|
58
|
+
|
|
59
|
+
|
|
16
60
|
class SqlNode(ExecNode):
|
|
17
|
-
"""
|
|
61
|
+
"""
|
|
62
|
+
Materializes data from the store via a Select stmt.
|
|
63
|
+
This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
tbl: Optional[catalog.TableVersionPath]
|
|
67
|
+
select_list: exprs.ExprSet
|
|
68
|
+
set_pk: bool
|
|
69
|
+
num_pk_cols: int
|
|
70
|
+
filter: Optional[exprs.Expr]
|
|
71
|
+
filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
|
|
72
|
+
cte: Optional[sql.CTE]
|
|
73
|
+
sql_elements: exprs.SqlElementCache
|
|
74
|
+
limit: Optional[int]
|
|
75
|
+
order_by_clause: OrderByClause
|
|
18
76
|
|
|
19
77
|
def __init__(
    self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
    select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
):
    """
    Records the select list and the remaining state needed to create the Select stmt later on.

    The select list can end up larger than 'select_list': if row_builder references unstored iter columns,
    the SQL-materializable subexpressions of the iterator arguments are added to it.

    Args:
        tbl: the table to select from; can be None, in which case set_pk must be False
        row_builder: the RowBuilder for the rows produced by this node
        select_list: output of the query
        sql_elements: translates exprs into SQL elements; needs to cover the entire select list
        set_pk: if True, sets the primary key for each DataRow
    """
    self.tbl = tbl
    self.sql_elements = sql_elements
    self.select_list = exprs.ExprSet(select_list)

    # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
    for iter_arg in row_builder.unstored_iter_args.values():
        for sql_subexpr in iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False):
            self.select_list.add(sql_subexpr)
    super().__init__(row_builder, self.select_list, [], None)  # we materialize self.select_list

    if tbl is not None:
        # minimize the number of tables that need to be joined to the target table
        self.retarget_rowid_refs(tbl, self.select_list)
    assert self.sql_elements.contains_all(self.select_list)

    self.set_pk = set_pk
    if set_pk:
        # we also need to retrieve the pk columns
        assert tbl is not None
        self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
    else:
        self.num_pk_cols = 0

    # additional state
    self.result_cursor = None
    # the filter is provided by the subclass
    self.filter = None
    self.filter_eval_ctx = None
    self.cte = None
    self.limit = None
    self.order_by_clause = []
|
|
120
|
+
|
|
121
|
+
def _create_stmt(self) -> sql.Select:
    """Create Select from local state

    The stmt contains the select list (followed by the pk columns of the store table if set_pk is True), the
    Order By clause and, where possible, the limit. This method adds neither a From nor a Where clause.
    """
    assert self.sql_elements.contains_all(self.select_list)
    columns = [self.sql_elements.get(e) for e in self.select_list]
    if self.set_pk:
        # the pk columns go last
        columns.extend(self.tbl.tbl_version.store_tbl.pk_columns())
    stmt = sql.select(*columns)

    ordering: list[sql.ColumnElement] = []
    for order_expr, asc in self.order_by_clause:
        if isinstance(order_expr, exprs.SimilarityExpr):
            # similarity exprs supply their own Order By element
            ordering.append(order_expr.as_order_by_clause(asc))
            continue
        sql_expr = self.sql_elements.get(order_expr)
        # only an explicit asc=False turns into DESC
        ordering.append(sql_expr.desc() if asc is False else sql_expr)
    stmt = stmt.order_by(*ordering)

    # if we don't have a Python filter, we can apply the limit to stmt
    if self.limit is not None and self.filter is None:
        stmt = stmt.limit(self.limit)

    return stmt
|
|
143
|
+
|
|
144
|
+
def _ordering_tbl_ids(self) -> set[UUID]:
    """Returns the result of Expr.all_tbl_ids() for the exprs of the Order By clause."""
    ordering_exprs = (order_expr for order_expr, _ in self.order_by_clause)
    return exprs.Expr.all_tbl_ids(ordering_exprs)
|
|
146
|
+
|
|
147
|
+
def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
    """
    Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column

    The CTE is created on the first call and reused by subsequent calls. As a side effect, this node stops
    retrieving the primary key: the CTE only exposes the select list.

    Returns:
        (CTE, dict from Expr to output column), or None if this node has a filter (which needs to run in
        Python and therefore can't be part of a CTE)
    """
    if self.filter is not None:
        # the filter needs to run in Python
        return None

    # we don't need the PK if we use this SqlNode as a CTE
    self.set_pk = False
    # keep num_pk_cols consistent with set_pk: __iter__() uses it to slice the pk out of each result row and
    # would otherwise interpret trailing select list columns as the pk
    self.num_pk_cols = 0

    if self.cte is None:
        self.cte = self._create_stmt().cte()
    assert len(self.cte.c) == len(self.select_list)
    return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
|
|
162
|
+
|
|
163
|
+
@classmethod
def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
    """Change rowid refs to point to target

    Only the elements of expr_seq themselves are examined; every RowidRef among them is updated in place
    via set_tbl().
    """
    for candidate in expr_seq:
        if not isinstance(candidate, exprs.RowidRef):
            continue
        candidate.set_tbl(target)
|
|
66
169
|
|
|
67
170
|
@classmethod
|
|
68
171
|
def create_from_clause(
|
|
69
|
-
cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[
|
|
70
|
-
exact_version_only: Optional[
|
|
172
|
+
cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[set[UUID]] = None,
|
|
173
|
+
exact_version_only: Optional[set[UUID]] = None
|
|
71
174
|
) -> sql.Select:
|
|
72
175
|
"""Add From clause to stmt for tables/views referenced by materialized_exprs
|
|
73
176
|
Args:
|
|
@@ -80,17 +183,18 @@ class SqlNode(ExecNode):
|
|
|
80
183
|
"""
|
|
81
184
|
# we need to include at least the root
|
|
82
185
|
if refd_tbl_ids is None:
|
|
83
|
-
refd_tbl_ids =
|
|
186
|
+
refd_tbl_ids = set()
|
|
84
187
|
if exact_version_only is None:
|
|
85
|
-
exact_version_only =
|
|
188
|
+
exact_version_only = set()
|
|
86
189
|
candidates = tbl.get_tbl_versions()
|
|
87
190
|
assert len(candidates) > 0
|
|
88
|
-
joined_tbls:
|
|
191
|
+
joined_tbls: list[catalog.TableVersion] = [candidates[0]]
|
|
89
192
|
for tbl in candidates[1:]:
|
|
90
193
|
if tbl.id in refd_tbl_ids:
|
|
91
194
|
joined_tbls.append(tbl)
|
|
92
195
|
|
|
93
196
|
first = True
|
|
197
|
+
prev_tbl: catalog.TableVersion
|
|
94
198
|
for tbl in joined_tbls[::-1]:
|
|
95
199
|
if first:
|
|
96
200
|
stmt = stmt.select_from(tbl.store_tbl.sa_tbl)
|
|
@@ -111,72 +215,94 @@ class SqlNode(ExecNode):
|
|
|
111
215
|
prev_tbl = tbl
|
|
112
216
|
return stmt
|
|
113
217
|
|
|
114
|
-
def
|
|
218
|
+
def add_order_by(self, ordering: OrderByClause) -> None:
    """Merge 'ordering' into this node's Order By clause

    'ordering' needs to be compatible with the existing clause (see combine_order_by_clauses()). The merged
    clause is stored in self.order_by_clause and is picked up when the stmt is created.
    """
    if self.tbl is not None:
        # change rowid refs against a base table to rowid refs against the target table, so that we minimize
        # the number of tables that need to be joined to the target table
        ordering_exprs = [order_expr for order_expr, _ in ordering]
        self.retarget_rowid_refs(self.tbl, ordering_exprs)

    merged = combine_order_by_clauses([self.order_by_clause, ordering])
    assert merged is not None
    self.order_by_clause = merged
|
|
227
|
+
|
|
228
|
+
def set_limit(self, limit: int) -> None:
    """Set the max number of rows this node returns

    The limit is stored in self.limit: _create_stmt() turns it into a SQL Limit clause if there is no Python
    filter, and __iter__() stops once that many rows have been returned.
    """
    self.limit = limit
|
|
230
|
+
|
|
231
|
+
def _log_explain(self, stmt: sql.Select) -> None:
    """Log the output of EXPLAIN for stmt at debug level

    This is best-effort: a failure to compile or explain stmt is logged as a warning and otherwise ignored.
    Nothing is sent to the database unless debug logging is enabled.
    """
    if not _logger.isEnabledFor(logging.DEBUG):
        # the plan would be discarded anyway: don't spend an extra round trip to the database on it
        return
    try:
        # don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
        stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
        explain_result = self.ctx.conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
        explain_str = '\n'.join([str(row) for row in explain_result])
        _logger.debug(f'SqlScanNode explain:\n{explain_str}')
    except Exception as e:
        # include the cause, otherwise the warning is impossible to act on
        _logger.warning(f'EXPLAIN failed: {e}')
|
|
123
240
|
|
|
124
|
-
def
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
241
|
+
def __iter__(self) -> Iterator[DataRowBatch]:
    """Run the query and yield its result as a sequence of DataRowBatches

    For each row returned by the query, the select list values are copied into the slots of a DataRow
    (Decimals are converted to int/float, depending on the type of the expr) and the pk is set if this node
    retrieves pk columns. If a filter is present, it is evaluated in Python and rows that don't pass are dropped.
    A batch is yielded whenever it reaches ctx.batch_size rows (if batch_size > 0); a final, partially filled
    batch is yielded at the end. Iteration stops early once 'limit' rows have been returned.

    Raises:
        RuntimeError: if the query returns a Decimal for an expr that is neither of int nor of float type
    """
    # run the query; do this here rather than in _open(), exceptions are only expected during iteration
    assert self.ctx.conn is not None
    # record=True captures warnings raised while creating/executing the stmt instead of displaying them
    with warnings.catch_warnings(record=True):
        stmt = self._create_stmt()
        try:
            # log stmt, if possible
            stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
            _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
        except Exception:
            # deliberately ignored: logging the stmt is best-effort and must not prevent its execution
            pass
        self._log_explain(stmt)

        # keep the cursor in self.result_cursor (not in a local variable): _close() checks that attribute, and
        # the cursor might not be exhausted when we stop early because of the limit
        self.result_cursor = self.ctx.conn.execute(stmt)

    tbl_version = self.tbl.tbl_version if self.tbl is not None else None
    output_batch = DataRowBatch(tbl_version, self.row_builder)
    # the row to (re-)use for the next sql row; None: add_row() needs to supply a new one
    output_row: Optional[exprs.DataRow] = None
    num_rows_returned = 0

    for sql_row in self.result_cursor:
        output_row = output_batch.add_row(output_row)

        # populate output_row
        if self.num_pk_cols > 0:
            # the pk columns are the last columns of the select list
            output_row.set_pk(tuple(sql_row[-self.num_pk_cols:]))
        # copy the output of the SQL query into the output row
        for i, e in enumerate(self.select_list):
            slot_idx = e.slot_idx
            val = sql_row[i]
            # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
            if isinstance(val, Decimal):
                if e.col_type.is_int_type():
                    output_row[slot_idx] = int(val)
                elif e.col_type.is_float_type():
                    output_row[slot_idx] = float(val)
                else:
                    raise RuntimeError(f'Unexpected Decimal value for {e}')
            else:
                output_row[slot_idx] = val

        if self.filter is not None:
            # evaluate filter
            self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
        if self.filter is not None and not output_row[self.filter.slot_idx]:
            # we re-use this row for the next sql row since it didn't pass the filter
            output_row = output_batch.pop_row()
            output_row.clear()
        else:
            # reset output_row in order to add new one
            output_row = None
            num_rows_returned += 1

        if self.limit is not None and num_rows_returned == self.limit:
            break

        if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
            _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
            yield output_batch
            output_batch = DataRowBatch(tbl_version, self.row_builder)

    if len(output_batch) > 0:
        _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
        yield output_batch
|
|
180
306
|
|
|
181
307
|
def _close(self) -> None:
|
|
182
308
|
if self.result_cursor is not None:
|
|
@@ -189,12 +315,14 @@ class SqlScanNode(SqlNode):
|
|
|
189
315
|
|
|
190
316
|
Supports filtering and ordering.
|
|
191
317
|
"""
|
|
318
|
+
where_clause: Optional[exprs.Expr]
|
|
319
|
+
exact_version_only: list[catalog.TableVersion]
|
|
320
|
+
|
|
192
321
|
def __init__(
|
|
193
322
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
194
323
|
select_list: Iterable[exprs.Expr],
|
|
195
324
|
where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
|
|
196
|
-
|
|
197
|
-
limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
|
|
325
|
+
set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
198
326
|
):
|
|
199
327
|
"""
|
|
200
328
|
Args:
|
|
@@ -208,52 +336,29 @@ class SqlScanNode(SqlNode):
|
|
|
208
336
|
sql_elements = exprs.SqlElementCache()
|
|
209
337
|
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
|
|
210
338
|
# create Select stmt
|
|
211
|
-
if order_by_items is None:
|
|
212
|
-
order_by_items = []
|
|
213
339
|
if exact_version_only is None:
|
|
214
340
|
exact_version_only = []
|
|
215
341
|
target = tbl.tbl_version # the stored table we're scanning
|
|
216
342
|
self.filter = filter
|
|
217
343
|
self.filter_eval_ctx = \
|
|
218
344
|
row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
|
|
219
|
-
self.limit = limit
|
|
220
345
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
self.stmt = self.create_from_clause(
|
|
224
|
-
tbl, self.stmt, refd_tbl_ids, exact_version_only={t.id for t in exact_version_only})
|
|
225
|
-
|
|
226
|
-
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
227
|
-
# the number of tables that need to be joined to the target table
|
|
228
|
-
for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
|
|
229
|
-
rowid_ref.set_tbl(tbl)
|
|
230
|
-
order_by_clause: List[sql.ClauseElement] = []
|
|
231
|
-
for e, asc in order_by_items:
|
|
232
|
-
if isinstance(e, exprs.SimilarityExpr):
|
|
233
|
-
order_by_clause.append(e.as_order_by_clause(asc))
|
|
234
|
-
else:
|
|
235
|
-
order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
|
|
346
|
+
self.where_clause = where_clause
|
|
347
|
+
self.exact_version_only = exact_version_only
|
|
236
348
|
|
|
237
|
-
|
|
238
|
-
|
|
349
|
+
def _create_stmt(self) -> sql.Select:
    """Extend the stmt of the superclass with a From clause and, if present, the Where clause

    The From clause is created for the table ids referenced by the select list, the Where clause and the
    Order By clause.
    """
    stmt = super()._create_stmt()

    if self.where_clause is None:
        where_clause_tbl_ids = set()
    else:
        where_clause_tbl_ids = self.where_clause.tbl_ids()
    refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
    exact_version_ids = {t.id for t in self.exact_version_only}
    stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids, exact_version_only=exact_version_ids)

    if self.where_clause is None:
        return stmt
    sql_where_clause = self.sql_elements.get(self.where_clause)
    assert sql_where_clause is not None
    return stmt.where(sql_where_clause)
|
|
257
362
|
|
|
258
363
|
|
|
259
364
|
class SqlLookupNode(SqlNode):
|
|
@@ -261,8 +366,7 @@ class SqlLookupNode(SqlNode):
|
|
|
261
366
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
262
367
|
"""
|
|
263
368
|
|
|
264
|
-
|
|
265
|
-
where_clause: sql.ColumnElement[bool]
|
|
369
|
+
where_clause: sql.ColumnElement
|
|
266
370
|
|
|
267
371
|
def __init__(
|
|
268
372
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
@@ -276,21 +380,45 @@ class SqlLookupNode(SqlNode):
|
|
|
276
380
|
"""
|
|
277
381
|
sql_elements = exprs.SqlElementCache()
|
|
278
382
|
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
|
|
279
|
-
target = tbl.tbl_version # the stored table we're scanning
|
|
280
|
-
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
|
|
281
|
-
self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
|
|
282
383
|
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
283
384
|
self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
284
|
-
self.stmt = self.stmt.where(self.where_clause)
|
|
285
385
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
386
|
+
def _create_stmt(self) -> sql.Select:
    """Extend the stmt of the superclass with a From clause and the key lookup Where clause

    The From clause is created for the table ids referenced by the select list and the Order By clause.
    """
    base_stmt = super()._create_stmt()
    refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
    stmt_with_from = self.create_from_clause(self.tbl, base_stmt, refd_tbl_ids)
    return stmt_with_from.where(self.where_clause)
|
|
290
392
|
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
393
|
+
class SqlAggregationNode(SqlNode):
    """
    Materializes data via a Select stmt with an optional Group By clause over the output of another SqlNode.

    The input node is turned into a CTE; the select list and the Group By items of this node are resolved against
    the output columns of that CTE.
    """

    group_by_items: Optional[list[exprs.Expr]]

    def __init__(
        self, row_builder: exprs.RowBuilder,
        input: SqlNode,
        select_list: Iterable[exprs.Expr],
        group_by_items: Optional[list[exprs.Expr]] = None,
        limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
    ):
        """
        Args:
            input: the node that produces the rows to aggregate; it must be convertible into a CTE, ie, it
                must not have a filter that runs in Python
            select_list: can contain calls to AggregateFunctions
            group_by_items: list of expressions to group by
            limit: max number of rows to return: None = no limit
                NOTE: this argument is currently not applied by the constructor; use set_limit() instead
            exact_version_only: currently unused
        """
        # to_cte() returns None if the input can't be expressed as a CTE; check that explicitly instead of
        # failing with an obscure unpacking error
        input_cte = input.to_cte()
        assert input_cte is not None, 'SqlAggregationNode: input cannot be converted into a CTE'
        _, input_col_map = input_cte
        # resolve exprs against the output columns of the input CTE
        sql_elements = exprs.SqlElementCache(input_col_map)
        super().__init__(None, row_builder, select_list, sql_elements)
        self.group_by_items = group_by_items

    def _create_stmt(self) -> sql.Select:
        """Extend the stmt of the superclass with the Group By clause, if group_by_items is set"""
        stmt = super()._create_stmt()
        if self.group_by_items is not None:
            sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
            assert all(e is not None for e in sql_group_by_items)
            stmt = stmt.group_by(*sql_group_by_items)
        return stmt
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ from .comparison import Comparison
|
|
|
6
6
|
from .compound_predicate import CompoundPredicate
|
|
7
7
|
from .data_row import DataRow
|
|
8
8
|
from .expr import Expr
|
|
9
|
+
from .expr_dict import ExprDict
|
|
9
10
|
from .expr_set import ExprSet
|
|
10
11
|
from .function_call import FunctionCall
|
|
11
12
|
from .in_predicate import InPredicate
|
|
@@ -6,6 +6,7 @@ import sqlalchemy as sql
|
|
|
6
6
|
|
|
7
7
|
import pixeltable.exceptions as excs
|
|
8
8
|
import pixeltable.type_system as ts
|
|
9
|
+
|
|
9
10
|
from .data_row import DataRow
|
|
10
11
|
from .expr import Expr
|
|
11
12
|
from .globals import ArithmeticOperator
|
|
@@ -86,6 +87,7 @@ class ArithmeticExpr(Expr):
|
|
|
86
87
|
return sql.sql.expression.cast(sql.func.floor(left / right), sql.Integer)
|
|
87
88
|
if self.col_type.is_float_type():
|
|
88
89
|
return sql.sql.expression.cast(sql.func.floor(left / right), sql.Float)
|
|
90
|
+
assert False
|
|
89
91
|
|
|
90
92
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
91
93
|
op1_val = data_row[self._op1.slot_idx]
|
|
@@ -121,7 +123,7 @@ class ArithmeticExpr(Expr):
|
|
|
121
123
|
return {'operator': self.operator.value, **super()._as_dict()}
|
|
122
124
|
|
|
123
125
|
@classmethod
|
|
124
|
-
def _from_dict(cls, d: dict, components: list[Expr]) ->
|
|
126
|
+
def _from_dict(cls, d: dict, components: list[Expr]) -> ArithmeticExpr:
|
|
125
127
|
assert 'operator' in d
|
|
126
128
|
assert len(components) == 2
|
|
127
129
|
return cls(ArithmeticOperator(d['operator']), components[0], components[1])
|
pixeltable/exprs/array_slice.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Any,
|
|
3
|
+
from typing import Any, Optional, Union
|
|
4
4
|
|
|
5
5
|
import sqlalchemy as sql
|
|
6
6
|
|
|
@@ -15,7 +15,7 @@ class ArraySlice(Expr):
|
|
|
15
15
|
"""
|
|
16
16
|
Slice operation on an array, eg, t.array_col[:, 1:2].
|
|
17
17
|
"""
|
|
18
|
-
def __init__(self, arr: Expr, index:
|
|
18
|
+
def __init__(self, arr: Expr, index: tuple[Union[int, slice], ...]):
|
|
19
19
|
assert arr.col_type.is_array_type()
|
|
20
20
|
# determine result type
|
|
21
21
|
super().__init__(arr.col_type)
|
|
@@ -24,7 +24,7 @@ class ArraySlice(Expr):
|
|
|
24
24
|
self.id = self._create_id()
|
|
25
25
|
|
|
26
26
|
def __str__(self) -> str:
|
|
27
|
-
index_strs:
|
|
27
|
+
index_strs: list[str] = []
|
|
28
28
|
for el in self.index:
|
|
29
29
|
if isinstance(el, int):
|
|
30
30
|
index_strs.append(str(el))
|
|
@@ -39,7 +39,7 @@ class ArraySlice(Expr):
|
|
|
39
39
|
def _equals(self, other: ArraySlice) -> bool:
|
|
40
40
|
return self.index == other.index
|
|
41
41
|
|
|
42
|
-
def _id_attrs(self) ->
|
|
42
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
43
43
|
return super()._id_attrs() + [('index', self.index)]
|
|
44
44
|
|
|
45
45
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
@@ -49,8 +49,8 @@ class ArraySlice(Expr):
|
|
|
49
49
|
val = data_row[self._array.slot_idx]
|
|
50
50
|
data_row[self.slot_idx] = val[self.index]
|
|
51
51
|
|
|
52
|
-
def _as_dict(self) ->
|
|
53
|
-
index = []
|
|
52
|
+
def _as_dict(self) -> dict:
|
|
53
|
+
index: list[Any] = []
|
|
54
54
|
for el in self.index:
|
|
55
55
|
if isinstance(el, slice):
|
|
56
56
|
index.append([el.start, el.stop, el.step])
|
|
@@ -59,7 +59,7 @@ class ArraySlice(Expr):
|
|
|
59
59
|
return {'index': index, **super()._as_dict()}
|
|
60
60
|
|
|
61
61
|
@classmethod
|
|
62
|
-
def _from_dict(cls, d:
|
|
62
|
+
def _from_dict(cls, d: dict, components: list[Expr]) -> ArraySlice:
|
|
63
63
|
assert 'index' in d
|
|
64
64
|
index = []
|
|
65
65
|
for el in d['index']:
|