pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +83 -19
- pixeltable/_query.py +1444 -0
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +7 -4
- pixeltable/catalog/catalog.py +2394 -119
- pixeltable/catalog/column.py +225 -104
- pixeltable/catalog/dir.py +38 -9
- pixeltable/catalog/globals.py +53 -34
- pixeltable/catalog/insertable_table.py +265 -115
- pixeltable/catalog/path.py +80 -17
- pixeltable/catalog/schema_object.py +28 -43
- pixeltable/catalog/table.py +1270 -677
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +1270 -751
- pixeltable/catalog/table_version_handle.py +109 -0
- pixeltable/catalog/table_version_path.py +137 -42
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +251 -134
- pixeltable/config.py +215 -0
- pixeltable/env.py +736 -285
- pixeltable/exceptions.py +26 -2
- pixeltable/exec/__init__.py +7 -2
- pixeltable/exec/aggregation_node.py +39 -21
- pixeltable/exec/cache_prefetch_node.py +87 -109
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +25 -28
- pixeltable/exec/data_row_batch.py +11 -46
- pixeltable/exec/exec_context.py +26 -11
- pixeltable/exec/exec_node.py +35 -27
- pixeltable/exec/expr_eval/__init__.py +3 -0
- pixeltable/exec/expr_eval/evaluators.py +365 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
- pixeltable/exec/expr_eval/globals.py +200 -0
- pixeltable/exec/expr_eval/row_buffer.py +74 -0
- pixeltable/exec/expr_eval/schedulers.py +413 -0
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +35 -27
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +44 -29
- pixeltable/exec/sql_node.py +414 -115
- pixeltable/exprs/__init__.py +8 -5
- pixeltable/exprs/arithmetic_expr.py +79 -45
- pixeltable/exprs/array_slice.py +5 -5
- pixeltable/exprs/column_property_ref.py +40 -26
- pixeltable/exprs/column_ref.py +254 -61
- pixeltable/exprs/comparison.py +14 -9
- pixeltable/exprs/compound_predicate.py +9 -10
- pixeltable/exprs/data_row.py +213 -72
- pixeltable/exprs/expr.py +270 -104
- pixeltable/exprs/expr_dict.py +6 -5
- pixeltable/exprs/expr_set.py +20 -11
- pixeltable/exprs/function_call.py +383 -284
- pixeltable/exprs/globals.py +18 -5
- pixeltable/exprs/in_predicate.py +7 -7
- pixeltable/exprs/inline_expr.py +37 -37
- pixeltable/exprs/is_null.py +8 -4
- pixeltable/exprs/json_mapper.py +120 -54
- pixeltable/exprs/json_path.py +90 -60
- pixeltable/exprs/literal.py +61 -16
- pixeltable/exprs/method_ref.py +7 -6
- pixeltable/exprs/object_ref.py +19 -8
- pixeltable/exprs/row_builder.py +238 -75
- pixeltable/exprs/rowid_ref.py +53 -15
- pixeltable/exprs/similarity_expr.py +65 -50
- pixeltable/exprs/sql_element_cache.py +5 -5
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/exprs/type_cast.py +25 -13
- pixeltable/exprs/variable.py +2 -2
- pixeltable/func/__init__.py +9 -5
- pixeltable/func/aggregate_function.py +197 -92
- pixeltable/func/callable_function.py +119 -35
- pixeltable/func/expr_template_function.py +101 -48
- pixeltable/func/function.py +375 -62
- pixeltable/func/function_registry.py +20 -19
- pixeltable/func/globals.py +6 -5
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +151 -35
- pixeltable/func/signature.py +178 -49
- pixeltable/func/tools.py +164 -0
- pixeltable/func/udf.py +176 -53
- pixeltable/functions/__init__.py +44 -4
- pixeltable/functions/anthropic.py +226 -47
- pixeltable/functions/audio.py +148 -11
- pixeltable/functions/bedrock.py +137 -0
- pixeltable/functions/date.py +188 -0
- pixeltable/functions/deepseek.py +113 -0
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +72 -20
- pixeltable/functions/gemini.py +249 -0
- pixeltable/functions/globals.py +208 -53
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1088 -95
- pixeltable/functions/image.py +155 -84
- pixeltable/functions/json.py +8 -11
- pixeltable/functions/llama_cpp.py +31 -19
- pixeltable/functions/math.py +169 -0
- pixeltable/functions/mistralai.py +50 -75
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +29 -36
- pixeltable/functions/openai.py +548 -160
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +15 -14
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +310 -85
- pixeltable/functions/timestamp.py +37 -19
- pixeltable/functions/together.py +77 -120
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +7 -2
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1528 -117
- pixeltable/functions/vision.py +26 -26
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +19 -10
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/functions/yolox.py +112 -0
- pixeltable/globals.py +716 -236
- pixeltable/index/__init__.py +3 -1
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +32 -22
- pixeltable/index/embedding_index.py +155 -92
- pixeltable/io/__init__.py +12 -7
- pixeltable/io/datarows.py +140 -0
- pixeltable/io/external_store.py +83 -125
- pixeltable/io/fiftyone.py +24 -33
- pixeltable/io/globals.py +47 -182
- pixeltable/io/hf_datasets.py +96 -127
- pixeltable/io/label_studio.py +171 -156
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +136 -115
- pixeltable/io/parquet.py +40 -153
- pixeltable/io/table_data_conduit.py +702 -0
- pixeltable/io/utils.py +100 -0
- pixeltable/iterators/__init__.py +8 -4
- pixeltable/iterators/audio.py +207 -0
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +144 -87
- pixeltable/iterators/image.py +17 -38
- pixeltable/iterators/string.py +15 -12
- pixeltable/iterators/video.py +523 -127
- pixeltable/metadata/__init__.py +33 -8
- pixeltable/metadata/converters/convert_10.py +2 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_15.py +15 -11
- pixeltable/metadata/converters/convert_16.py +4 -5
- pixeltable/metadata/converters/convert_17.py +4 -5
- pixeltable/metadata/converters/convert_18.py +4 -6
- pixeltable/metadata/converters/convert_19.py +6 -9
- pixeltable/metadata/converters/convert_20.py +3 -6
- pixeltable/metadata/converters/convert_21.py +6 -8
- pixeltable/metadata/converters/convert_22.py +3 -2
- pixeltable/metadata/converters/convert_23.py +33 -0
- pixeltable/metadata/converters/convert_24.py +55 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/convert_26.py +23 -0
- pixeltable/metadata/converters/convert_27.py +29 -0
- pixeltable/metadata/converters/convert_28.py +13 -0
- pixeltable/metadata/converters/convert_29.py +110 -0
- pixeltable/metadata/converters/convert_30.py +63 -0
- pixeltable/metadata/converters/convert_31.py +11 -0
- pixeltable/metadata/converters/convert_32.py +15 -0
- pixeltable/metadata/converters/convert_33.py +17 -0
- pixeltable/metadata/converters/convert_34.py +21 -0
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +44 -18
- pixeltable/metadata/notes.py +21 -0
- pixeltable/metadata/schema.py +185 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +616 -225
- pixeltable/share/__init__.py +3 -0
- pixeltable/share/packager.py +797 -0
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +349 -0
- pixeltable/store.py +398 -232
- pixeltable/type_system.py +730 -267
- pixeltable/utils/__init__.py +40 -0
- pixeltable/utils/arrow.py +201 -29
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +26 -27
- pixeltable/utils/code.py +4 -4
- pixeltable/utils/console_output.py +46 -0
- pixeltable/utils/coroutine.py +24 -0
- pixeltable/utils/dbms.py +92 -0
- pixeltable/utils/description_helper.py +11 -12
- pixeltable/utils/documents.py +60 -61
- pixeltable/utils/exception_handler.py +36 -0
- pixeltable/utils/filecache.py +38 -22
- pixeltable/utils/formatter.py +88 -51
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +14 -13
- pixeltable/utils/iceberg.py +13 -0
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +20 -20
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +32 -5
- pixeltable/utils/system.py +30 -0
- pixeltable/utils/transactional_directory.py +4 -3
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -36
- pixeltable/catalog/path_dict.py +0 -141
- pixeltable/dataframe.py +0 -894
- pixeltable/exec/expr_eval_node.py +0 -232
- pixeltable/ext/__init__.py +0 -14
- pixeltable/ext/functions/__init__.py +0 -8
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/ext/functions/yolox.py +0 -157
- pixeltable/tool/create_test_db_dump.py +0 -311
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable/utils/media_store.py +0 -76
- pixeltable/utils/s3.py +0 -16
- pixeltable-0.2.26.dist-info/METADATA +0 -400
- pixeltable-0.2.26.dist-info/RECORD +0 -156
- pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/exec/sql_node.py
CHANGED
|
@@ -1,31 +1,34 @@
|
|
|
1
|
+
import datetime
|
|
1
2
|
import logging
|
|
2
3
|
import warnings
|
|
3
4
|
from decimal import Decimal
|
|
4
|
-
from typing import
|
|
5
|
+
from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Sequence
|
|
5
6
|
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
import sqlalchemy as sql
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
|
|
10
|
+
from pixeltable import catalog, exprs
|
|
11
|
+
from pixeltable.env import Env
|
|
12
|
+
|
|
11
13
|
from .data_row_batch import DataRowBatch
|
|
12
14
|
from .exec_node import ExecNode
|
|
13
15
|
|
|
14
16
|
if TYPE_CHECKING:
|
|
15
17
|
import pixeltable.plan
|
|
18
|
+
from pixeltable.plan import SampleClause
|
|
16
19
|
|
|
17
20
|
_logger = logging.getLogger('pixeltable')
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class OrderByItem(NamedTuple):
|
|
21
24
|
expr: exprs.Expr
|
|
22
|
-
asc:
|
|
25
|
+
asc: bool | None
|
|
23
26
|
|
|
24
27
|
|
|
25
28
|
OrderByClause = list[OrderByItem]
|
|
26
29
|
|
|
27
30
|
|
|
28
|
-
def combine_order_by_clauses(clauses: Iterable[OrderByClause]) ->
|
|
31
|
+
def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> OrderByClause | None:
|
|
29
32
|
"""Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
|
|
30
33
|
Two clauses are compatible if for each of their respective items c1[i] and c2[i]
|
|
31
34
|
a) the exprs are identical and
|
|
@@ -53,56 +56,91 @@ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[Order
|
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
def print_order_by_clause(clause: OrderByClause) -> str:
|
|
56
|
-
return ', '.join(
|
|
59
|
+
return ', '.join(
|
|
57
60
|
f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
|
|
58
61
|
for item in clause
|
|
59
|
-
|
|
62
|
+
)
|
|
60
63
|
|
|
61
64
|
|
|
62
65
|
class SqlNode(ExecNode):
|
|
63
66
|
"""
|
|
64
|
-
Materializes data from the store via a
|
|
67
|
+
Materializes data from the store via a SQL statement.
|
|
65
68
|
This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
|
|
69
|
+
The pk columns are not included in the select list.
|
|
70
|
+
If set_pk is True, they are added to the end of the result set when creating the SQL statement
|
|
71
|
+
so they can always be referenced as cols[-num_pk_cols:] in the result set.
|
|
72
|
+
The pk_columns consist of the rowid columns of the target table followed by the version number.
|
|
73
|
+
|
|
74
|
+
If row_builder contains references to unstored iter columns, expands the select list to include their
|
|
75
|
+
SQL-materializable subexpressions.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
select_list: output of the query
|
|
79
|
+
set_pk: if True, sets the primary for each DataRow
|
|
66
80
|
"""
|
|
67
81
|
|
|
68
|
-
tbl:
|
|
82
|
+
tbl: catalog.TableVersionPath | None
|
|
69
83
|
select_list: exprs.ExprSet
|
|
84
|
+
columns: list[catalog.Column] # for which columns to populate DataRow.cell_vals/cell_md
|
|
85
|
+
cell_md_refs: list[exprs.ColumnPropertyRef] # of ColumnRefs which also need DataRow.slot_cellmd for evaluation
|
|
70
86
|
set_pk: bool
|
|
71
87
|
num_pk_cols: int
|
|
72
|
-
py_filter:
|
|
73
|
-
py_filter_eval_ctx:
|
|
74
|
-
cte:
|
|
88
|
+
py_filter: exprs.Expr | None # a predicate that can only be run in Python
|
|
89
|
+
py_filter_eval_ctx: exprs.RowBuilder.EvalCtx | None
|
|
90
|
+
cte: sql.CTE | None
|
|
75
91
|
sql_elements: exprs.SqlElementCache
|
|
76
92
|
|
|
93
|
+
# execution state
|
|
94
|
+
sql_select_list_exprs: exprs.ExprSet
|
|
95
|
+
cellmd_item_idxs: exprs.ExprDict[int] # cellmd expr -> idx in sql select list
|
|
96
|
+
column_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
|
|
97
|
+
column_cellmd_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
|
|
98
|
+
result_cursor: sql.engine.CursorResult | None
|
|
99
|
+
|
|
77
100
|
# where_clause/-_element: allow subclass to set one or the other (but not both)
|
|
78
|
-
where_clause:
|
|
79
|
-
where_clause_element:
|
|
101
|
+
where_clause: exprs.Expr | None
|
|
102
|
+
where_clause_element: sql.ColumnElement | None
|
|
80
103
|
|
|
81
104
|
order_by_clause: OrderByClause
|
|
82
|
-
limit:
|
|
105
|
+
limit: int | None
|
|
83
106
|
|
|
84
107
|
def __init__(
|
|
85
|
-
|
|
86
|
-
|
|
108
|
+
self,
|
|
109
|
+
tbl: catalog.TableVersionPath | None,
|
|
110
|
+
row_builder: exprs.RowBuilder,
|
|
111
|
+
select_list: Iterable[exprs.Expr],
|
|
112
|
+
columns: list[catalog.Column],
|
|
113
|
+
sql_elements: exprs.SqlElementCache,
|
|
114
|
+
cell_md_col_refs: list[exprs.ColumnRef] | None = None,
|
|
115
|
+
set_pk: bool = False,
|
|
87
116
|
):
|
|
88
|
-
"""
|
|
89
|
-
If row_builder contains references to unstored iter columns, expands the select list to include their
|
|
90
|
-
SQL-materializable subexpressions.
|
|
91
|
-
|
|
92
|
-
Args:
|
|
93
|
-
select_list: output of the query
|
|
94
|
-
set_pk: if True, sets the primary for each DataRow
|
|
95
|
-
"""
|
|
96
117
|
# create Select stmt
|
|
97
118
|
self.sql_elements = sql_elements
|
|
98
119
|
self.tbl = tbl
|
|
99
|
-
|
|
120
|
+
self.columns = columns
|
|
121
|
+
if cell_md_col_refs is not None:
|
|
122
|
+
assert all(ref.col.stores_cellmd for ref in cell_md_col_refs)
|
|
123
|
+
self.cell_md_refs = [
|
|
124
|
+
exprs.ColumnPropertyRef(ref, exprs.ColumnPropertyRef.Property.CELLMD) for ref in cell_md_col_refs
|
|
125
|
+
]
|
|
126
|
+
else:
|
|
127
|
+
self.cell_md_refs = []
|
|
100
128
|
self.select_list = exprs.ExprSet(select_list)
|
|
101
|
-
# unstored iter columns: we also need to retrieve whatever is needed to materialize the
|
|
129
|
+
# unstored iter columns: we also need to retrieve whatever is needed to materialize the
|
|
130
|
+
# iter args and stored outputs
|
|
102
131
|
for iter_arg in row_builder.unstored_iter_args.values():
|
|
103
132
|
sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
|
|
104
|
-
|
|
105
|
-
|
|
133
|
+
self.select_list.update(sql_subexprs)
|
|
134
|
+
# We query for unstored outputs only if we're not loading a view; when we're loading a view, we are populating
|
|
135
|
+
# those columns, so we need to keep them out of the select list. This isn't a problem, because view loads never
|
|
136
|
+
# need to call set_pos().
|
|
137
|
+
# TODO: This is necessary because create_view_load_plan passes stored output columns to `RowBuilder` via the
|
|
138
|
+
# `columns` parameter (even though they don't appear in `output_exprs`). This causes them to be recorded as
|
|
139
|
+
# expressions in `RowBuilder`, which creates a conflict if we add them here. If `RowBuilder` is restructured
|
|
140
|
+
# to keep them out of `unique_exprs`, then this conditional can be removed.
|
|
141
|
+
if not row_builder.for_view_load:
|
|
142
|
+
for outputs in row_builder.unstored_iter_outputs.values():
|
|
143
|
+
self.select_list.update(outputs)
|
|
106
144
|
super().__init__(row_builder, self.select_list, [], None) # we materialize self.select_list
|
|
107
145
|
|
|
108
146
|
if tbl is not None:
|
|
@@ -115,9 +153,13 @@ class SqlNode(ExecNode):
|
|
|
115
153
|
if set_pk:
|
|
116
154
|
# we also need to retrieve the pk columns
|
|
117
155
|
assert tbl is not None
|
|
118
|
-
self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
|
|
156
|
+
self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
|
|
157
|
+
assert self.num_pk_cols > 1
|
|
119
158
|
|
|
120
159
|
# additional state
|
|
160
|
+
self.cellmd_item_idxs = exprs.ExprDict()
|
|
161
|
+
self.column_item_idxs = {}
|
|
162
|
+
self.column_cellmd_item_idxs = {}
|
|
121
163
|
self.result_cursor = None
|
|
122
164
|
# the filter is provided by the subclass
|
|
123
165
|
self.py_filter = None
|
|
@@ -128,14 +170,38 @@ class SqlNode(ExecNode):
|
|
|
128
170
|
self.where_clause_element = None
|
|
129
171
|
self.order_by_clause = []
|
|
130
172
|
|
|
131
|
-
|
|
132
|
-
|
|
173
|
+
if self.tbl is not None:
|
|
174
|
+
tv = self.tbl.tbl_version._tbl_version
|
|
175
|
+
if tv is not None:
|
|
176
|
+
assert tv.is_validated
|
|
133
177
|
|
|
134
|
-
|
|
135
|
-
sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
|
|
178
|
+
def _pk_col_items(self) -> list[sql.Column]:
|
|
136
179
|
if self.set_pk:
|
|
180
|
+
# we need to retrieve the pk columns
|
|
137
181
|
assert self.tbl is not None
|
|
138
|
-
|
|
182
|
+
assert self.tbl.tbl_version.get().is_validated
|
|
183
|
+
return self.tbl.tbl_version.get().store_tbl.pk_columns()
|
|
184
|
+
return []
|
|
185
|
+
|
|
186
|
+
def _init_exec_state(self) -> None:
|
|
187
|
+
assert self.sql_elements.contains_all(self.select_list)
|
|
188
|
+
self.sql_select_list_exprs = exprs.ExprSet(self.select_list)
|
|
189
|
+
self.cellmd_item_idxs = exprs.ExprDict((ref, self.sql_select_list_exprs.add(ref)) for ref in self.cell_md_refs)
|
|
190
|
+
column_refs = [exprs.ColumnRef(col) for col in self.columns]
|
|
191
|
+
self.column_item_idxs = {col_ref.col: self.sql_select_list_exprs.add(col_ref) for col_ref in column_refs}
|
|
192
|
+
column_cellmd_refs = [
|
|
193
|
+
exprs.ColumnPropertyRef(col_ref, exprs.ColumnPropertyRef.Property.CELLMD)
|
|
194
|
+
for col_ref in column_refs
|
|
195
|
+
if col_ref.col.stores_cellmd
|
|
196
|
+
]
|
|
197
|
+
self.column_cellmd_item_idxs = {
|
|
198
|
+
cellmd_ref.col_ref.col: self.sql_select_list_exprs.add(cellmd_ref) for cellmd_ref in column_cellmd_refs
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
def _create_stmt(self) -> sql.Select:
|
|
202
|
+
"""Create Select from local state"""
|
|
203
|
+
self._init_exec_state()
|
|
204
|
+
sql_select_list = [self.sql_elements.get(e) for e in self.sql_select_list_exprs] + self._pk_col_items()
|
|
139
205
|
stmt = sql.select(*sql_select_list)
|
|
140
206
|
|
|
141
207
|
where_clause_element = (
|
|
@@ -161,9 +227,10 @@ class SqlNode(ExecNode):
|
|
|
161
227
|
def _ordering_tbl_ids(self) -> set[UUID]:
|
|
162
228
|
return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
|
|
163
229
|
|
|
164
|
-
def to_cte(self) ->
|
|
230
|
+
def to_cte(self, keep_pk: bool = False) -> tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]] | None:
|
|
165
231
|
"""
|
|
166
|
-
|
|
232
|
+
Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
|
|
233
|
+
keep_pk: if True, the PK columns are included in the CTE Select statement
|
|
167
234
|
|
|
168
235
|
Returns:
|
|
169
236
|
(CTE, dict from Expr to output column)
|
|
@@ -171,11 +238,11 @@ class SqlNode(ExecNode):
|
|
|
171
238
|
if self.py_filter is not None:
|
|
172
239
|
# the filter needs to run in Python
|
|
173
240
|
return None
|
|
174
|
-
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
175
241
|
if self.cte is None:
|
|
242
|
+
if not keep_pk:
|
|
243
|
+
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
176
244
|
self.cte = self._create_stmt().cte()
|
|
177
|
-
|
|
178
|
-
return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
|
|
245
|
+
return self.cte, exprs.ExprDict(zip(list(self.select_list) + self.cell_md_refs, self.cte.c)) # skip pk cols
|
|
179
246
|
|
|
180
247
|
@classmethod
|
|
181
248
|
def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
|
|
@@ -186,8 +253,11 @@ class SqlNode(ExecNode):
|
|
|
186
253
|
|
|
187
254
|
@classmethod
|
|
188
255
|
def create_from_clause(
|
|
189
|
-
|
|
190
|
-
|
|
256
|
+
cls,
|
|
257
|
+
tbl: catalog.TableVersionPath,
|
|
258
|
+
stmt: sql.Select,
|
|
259
|
+
refd_tbl_ids: set[UUID] | None = None,
|
|
260
|
+
exact_version_only: set[UUID] | None = None,
|
|
191
261
|
) -> sql.Select:
|
|
192
262
|
"""Add From clause to stmt for tables/views referenced by materialized_exprs
|
|
193
263
|
Args:
|
|
@@ -205,31 +275,35 @@ class SqlNode(ExecNode):
|
|
|
205
275
|
exact_version_only = set()
|
|
206
276
|
candidates = tbl.get_tbl_versions()
|
|
207
277
|
assert len(candidates) > 0
|
|
208
|
-
joined_tbls: list[catalog.
|
|
209
|
-
for
|
|
210
|
-
if
|
|
211
|
-
joined_tbls.append(
|
|
278
|
+
joined_tbls: list[catalog.TableVersionHandle] = [candidates[0]]
|
|
279
|
+
for t in candidates[1:]:
|
|
280
|
+
if t.id in refd_tbl_ids:
|
|
281
|
+
joined_tbls.append(t)
|
|
212
282
|
|
|
213
283
|
first = True
|
|
214
|
-
|
|
215
|
-
for
|
|
284
|
+
prev_tv: catalog.TableVersion | None = None
|
|
285
|
+
for t in joined_tbls[::-1]:
|
|
286
|
+
tv = t.get()
|
|
287
|
+
# _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
|
|
216
288
|
if first:
|
|
217
|
-
stmt = stmt.select_from(
|
|
289
|
+
stmt = stmt.select_from(tv.store_tbl.sa_tbl)
|
|
218
290
|
first = False
|
|
219
291
|
else:
|
|
220
|
-
# join
|
|
221
|
-
prev_tbl_rowid_cols =
|
|
222
|
-
tbl_rowid_cols =
|
|
223
|
-
rowid_clauses =
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
292
|
+
# join tv to prev_tv on prev_tv's rowid cols
|
|
293
|
+
prev_tbl_rowid_cols = prev_tv.store_tbl.rowid_columns()
|
|
294
|
+
tbl_rowid_cols = tv.store_tbl.rowid_columns()
|
|
295
|
+
rowid_clauses = [
|
|
296
|
+
c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
|
|
297
|
+
]
|
|
298
|
+
stmt = stmt.join(tv.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
|
|
299
|
+
|
|
300
|
+
if t.id in exact_version_only:
|
|
301
|
+
stmt = stmt.where(tv.store_tbl.v_min_col == tv.version)
|
|
228
302
|
else:
|
|
229
|
-
stmt = stmt
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
303
|
+
stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_min <= tv.version)
|
|
304
|
+
stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_max > tv.version)
|
|
305
|
+
prev_tv = tv
|
|
306
|
+
|
|
233
307
|
return stmt
|
|
234
308
|
|
|
235
309
|
def set_where(self, where_clause: exprs.Expr) -> None:
|
|
@@ -255,18 +329,18 @@ class SqlNode(ExecNode):
|
|
|
255
329
|
self.limit = limit
|
|
256
330
|
|
|
257
331
|
def _log_explain(self, stmt: sql.Select) -> None:
|
|
332
|
+
conn = Env.get().conn
|
|
258
333
|
try:
|
|
259
334
|
# don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
|
|
260
335
|
stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
261
|
-
explain_result =
|
|
336
|
+
explain_result = conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
|
|
262
337
|
explain_str = '\n'.join([str(row) for row in explain_result])
|
|
263
338
|
_logger.debug(f'SqlScanNode explain:\n{explain_str}')
|
|
264
339
|
except Exception as e:
|
|
265
|
-
_logger.warning(f'EXPLAIN failed')
|
|
340
|
+
_logger.warning(f'EXPLAIN failed with error: {e}')
|
|
266
341
|
|
|
267
|
-
def
|
|
342
|
+
async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
|
|
268
343
|
# run the query; do this here rather than in _open(), exceptions are only expected during iteration
|
|
269
|
-
assert self.ctx.conn is not None
|
|
270
344
|
with warnings.catch_warnings(record=True) as w:
|
|
271
345
|
stmt = self._create_stmt()
|
|
272
346
|
try:
|
|
@@ -274,35 +348,65 @@ class SqlNode(ExecNode):
|
|
|
274
348
|
stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
275
349
|
_logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
|
|
276
350
|
except Exception:
|
|
277
|
-
|
|
351
|
+
# log something if we can't log the compiled stmt
|
|
352
|
+
_logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
|
|
278
353
|
self._log_explain(stmt)
|
|
279
354
|
|
|
280
|
-
|
|
281
|
-
|
|
355
|
+
conn = Env.get().conn
|
|
356
|
+
result_cursor = conn.execute(stmt)
|
|
357
|
+
for _ in w:
|
|
282
358
|
pass
|
|
283
359
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
output_row: Optional[exprs.DataRow] = None
|
|
360
|
+
output_batch = DataRowBatch(self.row_builder)
|
|
361
|
+
output_row: exprs.DataRow | None = None
|
|
287
362
|
num_rows_returned = 0
|
|
363
|
+
is_using_cockroachdb = Env.get().is_using_cockroachdb
|
|
364
|
+
tzinfo = Env.get().default_time_zone
|
|
288
365
|
|
|
289
366
|
for sql_row in result_cursor:
|
|
290
367
|
output_row = output_batch.add_row(output_row)
|
|
291
368
|
|
|
292
369
|
# populate output_row
|
|
370
|
+
|
|
293
371
|
if self.num_pk_cols > 0:
|
|
294
|
-
output_row.set_pk(tuple(sql_row[-self.num_pk_cols:]))
|
|
372
|
+
output_row.set_pk(tuple(sql_row[-self.num_pk_cols :]))
|
|
373
|
+
|
|
374
|
+
# column copies
|
|
375
|
+
for col, item_idx in self.column_item_idxs.items():
|
|
376
|
+
output_row.cell_vals[col.id] = sql_row[item_idx]
|
|
377
|
+
for col, item_idx in self.column_cellmd_item_idxs.items():
|
|
378
|
+
cell_md_dict = sql_row[item_idx]
|
|
379
|
+
output_row.cell_md[col.id] = exprs.CellMd(**cell_md_dict) if cell_md_dict is not None else None
|
|
380
|
+
|
|
381
|
+
# populate DataRow.slot_cellmd, where requested
|
|
382
|
+
for cellmd_ref, item_idx in self.cellmd_item_idxs.items():
|
|
383
|
+
cell_md_dict = sql_row[item_idx]
|
|
384
|
+
output_row.slot_md[cellmd_ref.col_ref.slot_idx] = (
|
|
385
|
+
exprs.CellMd.from_dict(cell_md_dict) if cell_md_dict is not None else None
|
|
386
|
+
)
|
|
387
|
+
|
|
295
388
|
# copy the output of the SQL query into the output row
|
|
296
389
|
for i, e in enumerate(self.select_list):
|
|
297
390
|
slot_idx = e.slot_idx
|
|
298
|
-
# certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
|
|
299
391
|
if isinstance(sql_row[i], Decimal):
|
|
392
|
+
# certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
|
|
300
393
|
if e.col_type.is_int_type():
|
|
301
394
|
output_row[slot_idx] = int(sql_row[i])
|
|
302
395
|
elif e.col_type.is_float_type():
|
|
303
396
|
output_row[slot_idx] = float(sql_row[i])
|
|
304
397
|
else:
|
|
305
398
|
raise RuntimeError(f'Unexpected Decimal value for {e}')
|
|
399
|
+
elif is_using_cockroachdb and isinstance(sql_row[i], datetime.datetime):
|
|
400
|
+
# Ensure that the datetime is timezone-aware and in the session time zone
|
|
401
|
+
# cockroachDB returns timestamps in the session time zone, with numeric offset,
|
|
402
|
+
# convert to the session time zone with the requested tzinfo for DST handling
|
|
403
|
+
if e.col_type.is_timestamp_type():
|
|
404
|
+
if isinstance(sql_row[i].tzinfo, datetime.timezone):
|
|
405
|
+
output_row[slot_idx] = sql_row[i].astimezone(tz=tzinfo)
|
|
406
|
+
else:
|
|
407
|
+
output_row[slot_idx] = sql_row[i]
|
|
408
|
+
else:
|
|
409
|
+
raise RuntimeError(f'Unexpected datetime value for {e}')
|
|
306
410
|
else:
|
|
307
411
|
output_row[slot_idx] = sql_row[i]
|
|
308
412
|
|
|
@@ -324,7 +428,7 @@ class SqlNode(ExecNode):
|
|
|
324
428
|
if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
|
|
325
429
|
_logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
|
|
326
430
|
yield output_batch
|
|
327
|
-
output_batch = DataRowBatch(
|
|
431
|
+
output_batch = DataRowBatch(self.row_builder)
|
|
328
432
|
|
|
329
433
|
if len(output_batch) > 0:
|
|
330
434
|
_logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
|
|
@@ -340,22 +444,35 @@ class SqlScanNode(SqlNode):
|
|
|
340
444
|
Materializes data from the store via a Select stmt.
|
|
341
445
|
|
|
342
446
|
Supports filtering and ordering.
|
|
447
|
+
|
|
448
|
+
Args:
|
|
449
|
+
select_list: output of the query
|
|
450
|
+
set_pk: if True, sets the primary for each DataRow
|
|
451
|
+
exact_version_only: tables for which we only want to see rows created at the current version
|
|
343
452
|
"""
|
|
344
|
-
|
|
453
|
+
|
|
454
|
+
exact_version_only: list[catalog.TableVersionHandle]
|
|
345
455
|
|
|
346
456
|
def __init__(
|
|
347
|
-
self,
|
|
457
|
+
self,
|
|
458
|
+
tbl: catalog.TableVersionPath,
|
|
459
|
+
row_builder: exprs.RowBuilder,
|
|
348
460
|
select_list: Iterable[exprs.Expr],
|
|
349
|
-
|
|
461
|
+
columns: list[catalog.Column],
|
|
462
|
+
cell_md_col_refs: list[exprs.ColumnRef] | None = None,
|
|
463
|
+
set_pk: bool = False,
|
|
464
|
+
exact_version_only: list[catalog.TableVersionHandle] | None = None,
|
|
350
465
|
):
|
|
351
|
-
"""
|
|
352
|
-
Args:
|
|
353
|
-
select_list: output of the query
|
|
354
|
-
set_pk: if True, sets the primary for each DataRow
|
|
355
|
-
exact_version_only: tables for which we only want to see rows created at the current version
|
|
356
|
-
"""
|
|
357
466
|
sql_elements = exprs.SqlElementCache()
|
|
358
|
-
super().__init__(
|
|
467
|
+
super().__init__(
|
|
468
|
+
tbl,
|
|
469
|
+
row_builder,
|
|
470
|
+
select_list,
|
|
471
|
+
columns=columns,
|
|
472
|
+
sql_elements=sql_elements,
|
|
473
|
+
set_pk=set_pk,
|
|
474
|
+
cell_md_col_refs=cell_md_col_refs,
|
|
475
|
+
)
|
|
359
476
|
# create Select stmt
|
|
360
477
|
if exact_version_only is None:
|
|
361
478
|
exact_version_only = []
|
|
@@ -367,27 +484,41 @@ class SqlScanNode(SqlNode):
|
|
|
367
484
|
where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
|
|
368
485
|
refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
|
|
369
486
|
stmt = self.create_from_clause(
|
|
370
|
-
self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only}
|
|
487
|
+
self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only}
|
|
488
|
+
)
|
|
371
489
|
return stmt
|
|
372
490
|
|
|
373
491
|
|
|
374
492
|
class SqlLookupNode(SqlNode):
|
|
375
493
|
"""
|
|
376
494
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
495
|
+
|
|
496
|
+
Args:
|
|
497
|
+
select_list: output of the query
|
|
498
|
+
sa_key_cols: list of key columns in the store table
|
|
499
|
+
key_vals: list of key values to look up
|
|
377
500
|
"""
|
|
378
501
|
|
|
379
502
|
def __init__(
|
|
380
|
-
self,
|
|
381
|
-
|
|
503
|
+
self,
|
|
504
|
+
tbl: catalog.TableVersionPath,
|
|
505
|
+
row_builder: exprs.RowBuilder,
|
|
506
|
+
select_list: Iterable[exprs.Expr],
|
|
507
|
+
columns: list[catalog.Column],
|
|
508
|
+
sa_key_cols: list[sql.Column],
|
|
509
|
+
key_vals: list[tuple],
|
|
510
|
+
cell_md_col_refs: list[exprs.ColumnRef] | None = None,
|
|
382
511
|
):
|
|
383
|
-
"""
|
|
384
|
-
Args:
|
|
385
|
-
select_list: output of the query
|
|
386
|
-
sa_key_cols: list of key columns in the store table
|
|
387
|
-
key_vals: list of key values to look up
|
|
388
|
-
"""
|
|
389
512
|
sql_elements = exprs.SqlElementCache()
|
|
390
|
-
super().__init__(
|
|
513
|
+
super().__init__(
|
|
514
|
+
tbl,
|
|
515
|
+
row_builder,
|
|
516
|
+
select_list,
|
|
517
|
+
columns=columns,
|
|
518
|
+
sql_elements=sql_elements,
|
|
519
|
+
set_pk=True,
|
|
520
|
+
cell_md_col_refs=cell_md_col_refs,
|
|
521
|
+
)
|
|
391
522
|
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
392
523
|
self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
393
524
|
|
|
@@ -401,30 +532,33 @@ class SqlLookupNode(SqlNode):
|
|
|
401
532
|
class SqlAggregationNode(SqlNode):
|
|
402
533
|
"""
|
|
403
534
|
Materializes data from the store via a Select stmt that aggregates (with an optional GROUP BY) the output of an input SqlNode
|
|
535
|
+
|
|
536
|
+
Args:
|
|
537
|
+
select_list: can contain calls to AggregateFunctions
|
|
538
|
+
group_by_items: list of expressions to group by
|
|
539
|
+
limit: max number of rows to return: None = no limit
|
|
404
540
|
"""
|
|
405
541
|
|
|
406
|
-
group_by_items:
|
|
542
|
+
group_by_items: list[exprs.Expr] | None
|
|
543
|
+
input_cte: sql.CTE | None
|
|
407
544
|
|
|
408
545
|
def __init__(
    self,
    row_builder: exprs.RowBuilder,
    input: SqlNode,
    select_list: Iterable[exprs.Expr],
    group_by_items: list[exprs.Expr] | None = None,
    limit: int | None = None,
    exact_version_only: list[catalog.TableVersion] | None = None,
):
    """Set up an aggregation over the CTE produced by `input`.

    The input node is converted into a CTE and its column map seeds the SqlElementCache that is
    handed to SqlNode. `limit` and `exact_version_only` are accepted but not read by this
    constructor.
    """
    # there's no aggregation over json or arrays in SQL
    assert len(input.cell_md_refs) == 0

    input_cte, cte_col_map = input.to_cte()
    self.input_cte = input_cte
    super().__init__(
        None,
        row_builder,
        select_list,
        columns=[],
        sql_elements=exprs.SqlElementCache(cte_col_map),
    )
    self.group_by_items = group_by_items
|
|
425
559
|
|
|
426
560
|
def _create_stmt(self) -> sql.Select:
|
|
427
|
-
stmt = super()._create_stmt()
|
|
561
|
+
stmt = super()._create_stmt().select_from(self.input_cte)
|
|
428
562
|
if self.group_by_items is not None:
|
|
429
563
|
sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
|
|
430
564
|
assert all(e is not None for e in sql_group_by_items)
|
|
@@ -436,12 +570,16 @@ class SqlJoinNode(SqlNode):
|
|
|
436
570
|
"""
|
|
437
571
|
Materializes data from the store via a Select ... From ... that contains joins
|
|
438
572
|
"""
|
|
573
|
+
|
|
439
574
|
input_ctes: list[sql.CTE]
|
|
440
575
|
join_clauses: list['pixeltable.plan.JoinClause']
|
|
441
576
|
|
|
442
577
|
def __init__(
|
|
443
|
-
self,
|
|
444
|
-
|
|
578
|
+
self,
|
|
579
|
+
row_builder: exprs.RowBuilder,
|
|
580
|
+
inputs: Sequence[SqlNode],
|
|
581
|
+
join_clauses: list['pixeltable.plan.JoinClause'],
|
|
582
|
+
select_list: Iterable[exprs.Expr],
|
|
445
583
|
):
|
|
446
584
|
assert len(inputs) > 1
|
|
447
585
|
assert len(inputs) == len(join_clauses) + 1
|
|
@@ -452,20 +590,181 @@ class SqlJoinNode(SqlNode):
|
|
|
452
590
|
input_cte, input_col_map = input_node.to_cte()
|
|
453
591
|
self.input_ctes.append(input_cte)
|
|
454
592
|
sql_elements.extend(input_col_map)
|
|
455
|
-
|
|
593
|
+
cell_md_col_refs = [cell_md_ref.col_ref for input in inputs for cell_md_ref in input.cell_md_refs]
|
|
594
|
+
super().__init__(
|
|
595
|
+
None, row_builder, select_list, columns=[], sql_elements=sql_elements, cell_md_col_refs=cell_md_col_refs
|
|
596
|
+
)
|
|
456
597
|
|
|
457
598
|
def _create_stmt(self) -> sql.Select:
    """Build the Select stmt that joins all input CTEs.

    The first input CTE is the left-most FROM element; join_clauses[i] describes how
    input_ctes[i + 1] is joined onto the statement built so far.

    Returns:
        the Select stmt produced by SqlNode._create_stmt(), extended with the FROM/JOIN clauses
    """
    from pixeltable import plan

    stmt = super()._create_stmt()
    stmt = stmt.select_from(self.input_ctes[0])
    for i, join_clause in enumerate(self.join_clauses):
        join_type = join_clause.join_type
        if join_type == plan.JoinType.CROSS:
            # a cross join has no predicate: join on a constant TRUE
            on_clause = sql.sql.expression.literal(True)
        else:
            on_clause = self.sql_elements.get(join_clause.join_predicate)
        stmt = stmt.join(
            self.input_ctes[i + 1],
            onclause=on_clause,
            isouter=join_type in (plan.JoinType.LEFT, plan.JoinType.FULL_OUTER),
            # compare the clause's join *type*: comparing the JoinClause object itself against the
            # enum is never equal, which silently turned FULL OUTER joins into LEFT OUTER joins
            full=join_type == plan.JoinType.FULL_OUTER,
        )
    return stmt
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
class SqlSampleNode(SqlNode):
    """
    Returns rows sampled from the input node.

    Rows are ordered by a seeded MD5 key computed from their rowid columns, which makes the sample
    reproducible for a given seed.

    Args:
        row_builder: forwarded to SqlNode
        input: SqlNode to sample from
        select_list: forwarded to SqlNode (the statements built here select the columns of the input CTE)
        sample_clause: specifies the sampling method
        stratify_exprs: Analyzer processed list of expressions to stratify by.
    """

    input_cte: sql.CTE | None
    pk_count: int
    # always a list (possibly empty): the statement builders call len() on it unconditionally
    stratify_exprs: list[exprs.Expr]
    sample_clause: 'SampleClause'

    def __init__(
        self,
        row_builder: exprs.RowBuilder,
        input: SqlNode,
        select_list: Iterable[exprs.Expr],
        sample_clause: 'SampleClause',
        stratify_exprs: list[exprs.Expr],
    ):
        assert isinstance(input, SqlNode)
        # keep the pk columns in the CTE: the sampling key is computed from them
        self.input_cte, input_col_map = input.to_cte(keep_pk=True)
        self.pk_count = input.num_pk_cols
        # _create_key_sql() drops the trailing (version) pk column and needs at least one column left
        assert self.pk_count > 1
        sql_elements = exprs.SqlElementCache(input_col_map)
        assert sql_elements.contains_all(stratify_exprs)
        cell_md_col_refs = [cell_md_ref.col_ref for cell_md_ref in input.cell_md_refs]
        super().__init__(
            input.tbl,
            row_builder,
            select_list,
            columns=[],
            sql_elements=sql_elements,
            cell_md_col_refs=cell_md_col_refs,
            set_pk=True,
        )
        self.stratify_exprs = stratify_exprs
        self.sample_clause = sample_clause

    @classmethod
    def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
        """Construct expression which is the ordering key for rows to be sampled

        General SQL form is:
        - MD5(<seed>::text [ + '___' + <rowid_col_val>::text ]*)
          i.e. the seed followed by one '___'-separated term per element of sql_cols
        """
        sql_expr: sql.ColumnElement = seed.cast(sql.String)
        for e in sql_cols:
            # Quotes are required below to guarantee that the string is properly presented in SQL
            sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + e.cast(sql.String)
        sql_expr = sql.func.md5(sql_expr)
        return sql_expr

    def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
        """Create an expression for randomly ordering rows with a given seed"""
        rowid_cols = [*cte.c[-self.pk_count : -1]]  # exclude the version column
        assert len(rowid_cols) > 0
        # If seed is not set in the sample clause, use the random seed given by the execution context
        seed = self.sample_clause.seed if self.sample_clause.seed is not None else self.ctx.random_seed
        return self.key_sql_expr(sql.literal_column(str(seed)), rowid_cols)

    def _create_row_rank_cte(self, sql_strata_exprs: list[sql.ColumnElement]) -> sql.CTE:
        """Create a CTE with all columns of the input CTE plus a trailing 'rank' column.

        'rank' is the row number of the row within its stratum, in sampling-key order.
        """
        order_by = self._create_key_sql(self.input_cte)
        rank_col = sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
        # Get all columns from the input CTE dynamically; the rank column must come last
        select_columns = [*self.input_cte.c, rank_col]
        return sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')

    def _create_stmt(self) -> sql.Select:
        """Create the Select stmt for the sampling method described by the sample clause.

        Dispatches on whether a fraction or a row count was requested and on whether the sample is stratified.
        """
        from pixeltable.plan import SampleClause

        self._init_exec_state()

        is_stratified = len(self.stratify_exprs) > 0

        if self.sample_clause.fraction is not None:
            if is_stratified:
                return self._create_stmt_stratified_fraction(self.sample_clause.fraction)

            # Non-stratified sampling: the same key serves as the filter (key < md5 threshold for the
            # fraction) and as the ordering, so it is built only once
            s_key = self._create_key_sql(self.input_cte)
            fraction_md5 = SampleClause.fraction_to_md5_hex(self.sample_clause.fraction)
            return sql.select(*self.input_cte.c).where(s_key < fraction_md5).order_by(s_key)

        if is_stratified:
            return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)

        # No stratification, just return n samples from the input CTE
        order_by = self._create_key_sql(self.input_cte)
        return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)

    def _create_stmt_stratified_n(self, n: int | None, n_per_stratum: int | None) -> sql.Select:
        """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""

        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        row_rank_cte = self._create_row_rank_cte(sql_strata_exprs)

        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        if n_per_stratum is not None:
            return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)

        # n samples overall: order by rank first, then by the sampling key
        secondary_order = self._create_key_sql(row_rank_cte)
        return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)

    def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
        """Create a Select stmt that returns a fraction of the rows per strata"""

        # Build the strata count CTE
        # Produces a table of the form:
        #   (*stratify_exprs, s_s_size)
        # where s_s_size is the number of samples to take from each stratum
        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        per_strata_count_cte = (
            sql.select(
                *sql_strata_exprs,
                sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
            )
            .select_from(self.input_cte)
            .group_by(*sql_strata_exprs)
            .cte('per_strata_count_cte')
        )

        # Build a CTE that ranks the rows within each stratum
        row_rank_cte = self._create_row_rank_cte(sql_strata_exprs)

        # Build the join criterion dynamically to accommodate any number of stratify_by expressions
        join_c = sql.true()
        for col in per_strata_count_cte.c[:-1]:
            # IS NOT DISTINCT FROM, so that rows whose stratum value is NULL still match their stratum
            join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)

        # Join with per_strata_count_cte to limit returns to the requested fraction of rows
        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        stmt = (
            sql.select(*final_columns)
            .select_from(row_rank_cte)
            .join(per_strata_count_cte, join_c)
            .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
        )

        return stmt
|