pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +42 -8
- pixeltable/{dataframe.py → _query.py} +470 -206
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +5 -4
- pixeltable/catalog/catalog.py +1785 -432
- pixeltable/catalog/column.py +190 -113
- pixeltable/catalog/dir.py +2 -4
- pixeltable/catalog/globals.py +19 -46
- pixeltable/catalog/insertable_table.py +191 -98
- pixeltable/catalog/path.py +63 -23
- pixeltable/catalog/schema_object.py +11 -15
- pixeltable/catalog/table.py +843 -436
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +978 -657
- pixeltable/catalog/table_version_handle.py +72 -16
- pixeltable/catalog/table_version_path.py +112 -43
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +134 -90
- pixeltable/config.py +134 -22
- pixeltable/env.py +471 -157
- pixeltable/exceptions.py +6 -0
- pixeltable/exec/__init__.py +4 -1
- pixeltable/exec/aggregation_node.py +7 -8
- pixeltable/exec/cache_prefetch_node.py +83 -110
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +4 -3
- pixeltable/exec/data_row_batch.py +8 -65
- pixeltable/exec/exec_context.py +16 -4
- pixeltable/exec/exec_node.py +13 -36
- pixeltable/exec/expr_eval/evaluators.py +11 -7
- pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
- pixeltable/exec/expr_eval/globals.py +8 -5
- pixeltable/exec/expr_eval/row_buffer.py +1 -2
- pixeltable/exec/expr_eval/schedulers.py +106 -56
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +19 -19
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +16 -9
- pixeltable/exec/sql_node.py +351 -84
- pixeltable/exprs/__init__.py +1 -1
- pixeltable/exprs/arithmetic_expr.py +27 -22
- pixeltable/exprs/array_slice.py +3 -3
- pixeltable/exprs/column_property_ref.py +36 -23
- pixeltable/exprs/column_ref.py +213 -89
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +5 -4
- pixeltable/exprs/data_row.py +164 -54
- pixeltable/exprs/expr.py +70 -44
- pixeltable/exprs/expr_dict.py +3 -3
- pixeltable/exprs/expr_set.py +17 -10
- pixeltable/exprs/function_call.py +100 -40
- pixeltable/exprs/globals.py +2 -2
- pixeltable/exprs/in_predicate.py +4 -4
- pixeltable/exprs/inline_expr.py +18 -32
- pixeltable/exprs/is_null.py +7 -3
- pixeltable/exprs/json_mapper.py +8 -8
- pixeltable/exprs/json_path.py +56 -22
- pixeltable/exprs/literal.py +27 -5
- pixeltable/exprs/method_ref.py +2 -2
- pixeltable/exprs/object_ref.py +2 -2
- pixeltable/exprs/row_builder.py +167 -67
- pixeltable/exprs/rowid_ref.py +25 -10
- pixeltable/exprs/similarity_expr.py +58 -40
- pixeltable/exprs/sql_element_cache.py +4 -4
- pixeltable/exprs/string_op.py +5 -5
- pixeltable/exprs/type_cast.py +3 -5
- pixeltable/func/__init__.py +1 -0
- pixeltable/func/aggregate_function.py +8 -8
- pixeltable/func/callable_function.py +9 -9
- pixeltable/func/expr_template_function.py +17 -11
- pixeltable/func/function.py +18 -20
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/globals.py +2 -3
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +29 -27
- pixeltable/func/signature.py +46 -19
- pixeltable/func/tools.py +31 -13
- pixeltable/func/udf.py +18 -20
- pixeltable/functions/__init__.py +16 -0
- pixeltable/functions/anthropic.py +123 -77
- pixeltable/functions/audio.py +147 -10
- pixeltable/functions/bedrock.py +13 -6
- pixeltable/functions/date.py +7 -4
- pixeltable/functions/deepseek.py +35 -43
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +11 -20
- pixeltable/functions/gemini.py +195 -39
- pixeltable/functions/globals.py +142 -14
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1056 -24
- pixeltable/functions/image.py +115 -57
- pixeltable/functions/json.py +1 -1
- pixeltable/functions/llama_cpp.py +28 -13
- pixeltable/functions/math.py +67 -5
- pixeltable/functions/mistralai.py +18 -55
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +20 -13
- pixeltable/functions/openai.py +240 -226
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +4 -4
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +239 -69
- pixeltable/functions/timestamp.py +16 -16
- pixeltable/functions/together.py +24 -84
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +6 -1
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1515 -107
- pixeltable/functions/vision.py +8 -8
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +16 -8
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/{ext/functions → functions}/yolox.py +2 -4
- pixeltable/globals.py +362 -115
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +28 -22
- pixeltable/index/embedding_index.py +100 -118
- pixeltable/io/__init__.py +4 -2
- pixeltable/io/datarows.py +8 -7
- pixeltable/io/external_store.py +56 -105
- pixeltable/io/fiftyone.py +13 -13
- pixeltable/io/globals.py +31 -30
- pixeltable/io/hf_datasets.py +61 -16
- pixeltable/io/label_studio.py +74 -70
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +21 -12
- pixeltable/io/parquet.py +25 -105
- pixeltable/io/table_data_conduit.py +250 -123
- pixeltable/io/utils.py +4 -4
- pixeltable/iterators/__init__.py +2 -1
- pixeltable/iterators/audio.py +26 -25
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +112 -78
- pixeltable/iterators/image.py +12 -15
- pixeltable/iterators/string.py +11 -4
- pixeltable/iterators/video.py +523 -120
- pixeltable/metadata/__init__.py +14 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_18.py +2 -2
- pixeltable/metadata/converters/convert_19.py +2 -2
- pixeltable/metadata/converters/convert_20.py +2 -2
- pixeltable/metadata/converters/convert_21.py +2 -2
- pixeltable/metadata/converters/convert_22.py +2 -2
- pixeltable/metadata/converters/convert_24.py +2 -2
- pixeltable/metadata/converters/convert_25.py +2 -2
- pixeltable/metadata/converters/convert_26.py +2 -2
- pixeltable/metadata/converters/convert_29.py +4 -4
- pixeltable/metadata/converters/convert_30.py +34 -21
- pixeltable/metadata/converters/convert_34.py +2 -2
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +20 -31
- pixeltable/metadata/notes.py +9 -0
- pixeltable/metadata/schema.py +140 -53
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +382 -115
- pixeltable/share/__init__.py +1 -1
- pixeltable/share/packager.py +547 -83
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +257 -59
- pixeltable/store.py +311 -194
- pixeltable/type_system.py +373 -211
- pixeltable/utils/__init__.py +2 -3
- pixeltable/utils/arrow.py +131 -17
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +6 -6
- pixeltable/utils/code.py +3 -3
- pixeltable/utils/console_output.py +4 -1
- pixeltable/utils/coroutine.py +6 -23
- pixeltable/utils/dbms.py +32 -6
- pixeltable/utils/description_helper.py +4 -5
- pixeltable/utils/documents.py +7 -18
- pixeltable/utils/exception_handler.py +7 -30
- pixeltable/utils/filecache.py +6 -6
- pixeltable/utils/formatter.py +86 -48
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +2 -3
- pixeltable/utils/iceberg.py +1 -2
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +5 -6
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +26 -0
- pixeltable/utils/system.py +30 -0
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -40
- pixeltable/ext/__init__.py +0 -17
- pixeltable/ext/functions/__init__.py +0 -11
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/utils/media_store.py +0 -77
- pixeltable/utils/s3.py +0 -17
- pixeltable-0.3.14.dist-info/METADATA +0 -434
- pixeltable-0.3.14.dist-info/RECORD +0 -186
- pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
- {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/exec/sql_node.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
|
+
import datetime
|
|
1
2
|
import logging
|
|
2
3
|
import warnings
|
|
3
4
|
from decimal import Decimal
|
|
4
|
-
from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple,
|
|
5
|
+
from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Sequence
|
|
5
6
|
from uuid import UUID
|
|
6
7
|
|
|
7
8
|
import sqlalchemy as sql
|
|
@@ -14,19 +15,20 @@ from .exec_node import ExecNode
|
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
import pixeltable.plan
|
|
18
|
+
from pixeltable.plan import SampleClause
|
|
17
19
|
|
|
18
20
|
_logger = logging.getLogger('pixeltable')
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class OrderByItem(NamedTuple):
|
|
22
24
|
expr: exprs.Expr
|
|
23
|
-
asc:
|
|
25
|
+
asc: bool | None
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
OrderByClause = list[OrderByItem]
|
|
27
29
|
|
|
28
30
|
|
|
29
|
-
def combine_order_by_clauses(clauses: Iterable[OrderByClause]) ->
|
|
31
|
+
def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> OrderByClause | None:
|
|
30
32
|
"""Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
|
|
31
33
|
Two clauses are compatible if for each of their respective items c1[i] and c2[i]
|
|
32
34
|
a) the exprs are identical and
|
|
@@ -55,60 +57,90 @@ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[Order
|
|
|
55
57
|
|
|
56
58
|
def print_order_by_clause(clause: OrderByClause) -> str:
|
|
57
59
|
return ', '.join(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
for item in clause
|
|
61
|
-
]
|
|
60
|
+
f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
|
|
61
|
+
for item in clause
|
|
62
62
|
)
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
class SqlNode(ExecNode):
|
|
66
66
|
"""
|
|
67
|
-
Materializes data from the store via a
|
|
67
|
+
Materializes data from the store via a SQL statement.
|
|
68
68
|
This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
|
|
69
|
+
The pk columns are not included in the select list.
|
|
70
|
+
If set_pk is True, they are added to the end of the result set when creating the SQL statement
|
|
71
|
+
so they can always be referenced as cols[-num_pk_cols:] in the result set.
|
|
72
|
+
The pk_columns consist of the rowid columns of the target table followed by the version number.
|
|
73
|
+
|
|
74
|
+
If row_builder contains references to unstored iter columns, expands the select list to include their
|
|
75
|
+
SQL-materializable subexpressions.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
select_list: output of the query
|
|
79
|
+
set_pk: if True, sets the primary for each DataRow
|
|
69
80
|
"""
|
|
70
81
|
|
|
71
|
-
tbl:
|
|
82
|
+
tbl: catalog.TableVersionPath | None
|
|
72
83
|
select_list: exprs.ExprSet
|
|
84
|
+
columns: list[catalog.Column] # for which columns to populate DataRow.cell_vals/cell_md
|
|
85
|
+
cell_md_refs: list[exprs.ColumnPropertyRef] # of ColumnRefs which also need DataRow.slot_cellmd for evaluation
|
|
73
86
|
set_pk: bool
|
|
74
87
|
num_pk_cols: int
|
|
75
|
-
py_filter:
|
|
76
|
-
py_filter_eval_ctx:
|
|
77
|
-
cte:
|
|
88
|
+
py_filter: exprs.Expr | None # a predicate that can only be run in Python
|
|
89
|
+
py_filter_eval_ctx: exprs.RowBuilder.EvalCtx | None
|
|
90
|
+
cte: sql.CTE | None
|
|
78
91
|
sql_elements: exprs.SqlElementCache
|
|
79
92
|
|
|
93
|
+
# execution state
|
|
94
|
+
sql_select_list_exprs: exprs.ExprSet
|
|
95
|
+
cellmd_item_idxs: exprs.ExprDict[int] # cellmd expr -> idx in sql select list
|
|
96
|
+
column_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
|
|
97
|
+
column_cellmd_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
|
|
98
|
+
result_cursor: sql.engine.CursorResult | None
|
|
99
|
+
|
|
80
100
|
# where_clause/-_element: allow subclass to set one or the other (but not both)
|
|
81
|
-
where_clause:
|
|
82
|
-
where_clause_element:
|
|
101
|
+
where_clause: exprs.Expr | None
|
|
102
|
+
where_clause_element: sql.ColumnElement | None
|
|
83
103
|
|
|
84
104
|
order_by_clause: OrderByClause
|
|
85
|
-
limit:
|
|
105
|
+
limit: int | None
|
|
86
106
|
|
|
87
107
|
def __init__(
|
|
88
108
|
self,
|
|
89
|
-
tbl:
|
|
109
|
+
tbl: catalog.TableVersionPath | None,
|
|
90
110
|
row_builder: exprs.RowBuilder,
|
|
91
111
|
select_list: Iterable[exprs.Expr],
|
|
112
|
+
columns: list[catalog.Column],
|
|
92
113
|
sql_elements: exprs.SqlElementCache,
|
|
114
|
+
cell_md_col_refs: list[exprs.ColumnRef] | None = None,
|
|
93
115
|
set_pk: bool = False,
|
|
94
116
|
):
|
|
95
|
-
"""
|
|
96
|
-
If row_builder contains references to unstored iter columns, expands the select list to include their
|
|
97
|
-
SQL-materializable subexpressions.
|
|
98
|
-
|
|
99
|
-
Args:
|
|
100
|
-
select_list: output of the query
|
|
101
|
-
set_pk: if True, sets the primary for each DataRow
|
|
102
|
-
"""
|
|
103
117
|
# create Select stmt
|
|
104
118
|
self.sql_elements = sql_elements
|
|
105
119
|
self.tbl = tbl
|
|
120
|
+
self.columns = columns
|
|
121
|
+
if cell_md_col_refs is not None:
|
|
122
|
+
assert all(ref.col.stores_cellmd for ref in cell_md_col_refs)
|
|
123
|
+
self.cell_md_refs = [
|
|
124
|
+
exprs.ColumnPropertyRef(ref, exprs.ColumnPropertyRef.Property.CELLMD) for ref in cell_md_col_refs
|
|
125
|
+
]
|
|
126
|
+
else:
|
|
127
|
+
self.cell_md_refs = []
|
|
106
128
|
self.select_list = exprs.ExprSet(select_list)
|
|
107
|
-
# unstored iter columns: we also need to retrieve whatever is needed to materialize the
|
|
129
|
+
# unstored iter columns: we also need to retrieve whatever is needed to materialize the
|
|
130
|
+
# iter args and stored outputs
|
|
108
131
|
for iter_arg in row_builder.unstored_iter_args.values():
|
|
109
132
|
sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
|
|
110
|
-
|
|
111
|
-
|
|
133
|
+
self.select_list.update(sql_subexprs)
|
|
134
|
+
# We query for unstored outputs only if we're not loading a view; when we're loading a view, we are populating
|
|
135
|
+
# those columns, so we need to keep them out of the select list. This isn't a problem, because view loads never
|
|
136
|
+
# need to call set_pos().
|
|
137
|
+
# TODO: This is necessary because create_view_load_plan passes stored output columns to `RowBuilder` via the
|
|
138
|
+
# `columns` parameter (even though they don't appear in `output_exprs`). This causes them to be recorded as
|
|
139
|
+
# expressions in `RowBuilder`, which creates a conflict if we add them here. If `RowBuilder` is restructured
|
|
140
|
+
# to keep them out of `unique_exprs`, then this conditional can be removed.
|
|
141
|
+
if not row_builder.for_view_load:
|
|
142
|
+
for outputs in row_builder.unstored_iter_outputs.values():
|
|
143
|
+
self.select_list.update(outputs)
|
|
112
144
|
super().__init__(row_builder, self.select_list, [], None) # we materialize self.select_list
|
|
113
145
|
|
|
114
146
|
if tbl is not None:
|
|
@@ -122,8 +154,12 @@ class SqlNode(ExecNode):
|
|
|
122
154
|
# we also need to retrieve the pk columns
|
|
123
155
|
assert tbl is not None
|
|
124
156
|
self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
|
|
157
|
+
assert self.num_pk_cols > 1
|
|
125
158
|
|
|
126
159
|
# additional state
|
|
160
|
+
self.cellmd_item_idxs = exprs.ExprDict()
|
|
161
|
+
self.column_item_idxs = {}
|
|
162
|
+
self.column_cellmd_item_idxs = {}
|
|
127
163
|
self.result_cursor = None
|
|
128
164
|
# the filter is provided by the subclass
|
|
129
165
|
self.py_filter = None
|
|
@@ -134,14 +170,38 @@ class SqlNode(ExecNode):
|
|
|
134
170
|
self.where_clause_element = None
|
|
135
171
|
self.order_by_clause = []
|
|
136
172
|
|
|
137
|
-
|
|
138
|
-
|
|
173
|
+
if self.tbl is not None:
|
|
174
|
+
tv = self.tbl.tbl_version._tbl_version
|
|
175
|
+
if tv is not None:
|
|
176
|
+
assert tv.is_validated
|
|
139
177
|
|
|
140
|
-
|
|
141
|
-
sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
|
|
178
|
+
def _pk_col_items(self) -> list[sql.Column]:
|
|
142
179
|
if self.set_pk:
|
|
180
|
+
# we need to retrieve the pk columns
|
|
143
181
|
assert self.tbl is not None
|
|
144
|
-
|
|
182
|
+
assert self.tbl.tbl_version.get().is_validated
|
|
183
|
+
return self.tbl.tbl_version.get().store_tbl.pk_columns()
|
|
184
|
+
return []
|
|
185
|
+
|
|
186
|
+
def _init_exec_state(self) -> None:
|
|
187
|
+
assert self.sql_elements.contains_all(self.select_list)
|
|
188
|
+
self.sql_select_list_exprs = exprs.ExprSet(self.select_list)
|
|
189
|
+
self.cellmd_item_idxs = exprs.ExprDict((ref, self.sql_select_list_exprs.add(ref)) for ref in self.cell_md_refs)
|
|
190
|
+
column_refs = [exprs.ColumnRef(col) for col in self.columns]
|
|
191
|
+
self.column_item_idxs = {col_ref.col: self.sql_select_list_exprs.add(col_ref) for col_ref in column_refs}
|
|
192
|
+
column_cellmd_refs = [
|
|
193
|
+
exprs.ColumnPropertyRef(col_ref, exprs.ColumnPropertyRef.Property.CELLMD)
|
|
194
|
+
for col_ref in column_refs
|
|
195
|
+
if col_ref.col.stores_cellmd
|
|
196
|
+
]
|
|
197
|
+
self.column_cellmd_item_idxs = {
|
|
198
|
+
cellmd_ref.col_ref.col: self.sql_select_list_exprs.add(cellmd_ref) for cellmd_ref in column_cellmd_refs
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
def _create_stmt(self) -> sql.Select:
|
|
202
|
+
"""Create Select from local state"""
|
|
203
|
+
self._init_exec_state()
|
|
204
|
+
sql_select_list = [self.sql_elements.get(e) for e in self.sql_select_list_exprs] + self._pk_col_items()
|
|
145
205
|
stmt = sql.select(*sql_select_list)
|
|
146
206
|
|
|
147
207
|
where_clause_element = (
|
|
@@ -167,9 +227,10 @@ class SqlNode(ExecNode):
|
|
|
167
227
|
def _ordering_tbl_ids(self) -> set[UUID]:
|
|
168
228
|
return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
|
|
169
229
|
|
|
170
|
-
def to_cte(self) ->
|
|
230
|
+
def to_cte(self, keep_pk: bool = False) -> tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]] | None:
|
|
171
231
|
"""
|
|
172
|
-
|
|
232
|
+
Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
|
|
233
|
+
keep_pk: if True, the PK columns are included in the CTE Select statement
|
|
173
234
|
|
|
174
235
|
Returns:
|
|
175
236
|
(CTE, dict from Expr to output column)
|
|
@@ -177,11 +238,11 @@ class SqlNode(ExecNode):
|
|
|
177
238
|
if self.py_filter is not None:
|
|
178
239
|
# the filter needs to run in Python
|
|
179
240
|
return None
|
|
180
|
-
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
181
241
|
if self.cte is None:
|
|
242
|
+
if not keep_pk:
|
|
243
|
+
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
182
244
|
self.cte = self._create_stmt().cte()
|
|
183
|
-
|
|
184
|
-
return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
|
|
245
|
+
return self.cte, exprs.ExprDict(zip(list(self.select_list) + self.cell_md_refs, self.cte.c)) # skip pk cols
|
|
185
246
|
|
|
186
247
|
@classmethod
|
|
187
248
|
def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
|
|
@@ -195,8 +256,8 @@ class SqlNode(ExecNode):
|
|
|
195
256
|
cls,
|
|
196
257
|
tbl: catalog.TableVersionPath,
|
|
197
258
|
stmt: sql.Select,
|
|
198
|
-
refd_tbl_ids:
|
|
199
|
-
exact_version_only:
|
|
259
|
+
refd_tbl_ids: set[UUID] | None = None,
|
|
260
|
+
exact_version_only: set[UUID] | None = None,
|
|
200
261
|
) -> sql.Select:
|
|
201
262
|
"""Add From clause to stmt for tables/views referenced by materialized_exprs
|
|
202
263
|
Args:
|
|
@@ -220,26 +281,29 @@ class SqlNode(ExecNode):
|
|
|
220
281
|
joined_tbls.append(t)
|
|
221
282
|
|
|
222
283
|
first = True
|
|
223
|
-
|
|
284
|
+
prev_tv: catalog.TableVersion | None = None
|
|
224
285
|
for t in joined_tbls[::-1]:
|
|
286
|
+
tv = t.get()
|
|
287
|
+
# _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
|
|
225
288
|
if first:
|
|
226
|
-
stmt = stmt.select_from(
|
|
289
|
+
stmt = stmt.select_from(tv.store_tbl.sa_tbl)
|
|
227
290
|
first = False
|
|
228
291
|
else:
|
|
229
|
-
# join
|
|
230
|
-
prev_tbl_rowid_cols =
|
|
231
|
-
tbl_rowid_cols =
|
|
292
|
+
# join tv to prev_tv on prev_tv's rowid cols
|
|
293
|
+
prev_tbl_rowid_cols = prev_tv.store_tbl.rowid_columns()
|
|
294
|
+
tbl_rowid_cols = tv.store_tbl.rowid_columns()
|
|
232
295
|
rowid_clauses = [
|
|
233
296
|
c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
|
|
234
297
|
]
|
|
235
|
-
stmt = stmt.join(
|
|
298
|
+
stmt = stmt.join(tv.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
|
|
299
|
+
|
|
236
300
|
if t.id in exact_version_only:
|
|
237
|
-
stmt = stmt.where(
|
|
301
|
+
stmt = stmt.where(tv.store_tbl.v_min_col == tv.version)
|
|
238
302
|
else:
|
|
239
|
-
stmt = stmt.where(
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
303
|
+
stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_min <= tv.version)
|
|
304
|
+
stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_max > tv.version)
|
|
305
|
+
prev_tv = tv
|
|
306
|
+
|
|
243
307
|
return stmt
|
|
244
308
|
|
|
245
309
|
def set_where(self, where_clause: exprs.Expr) -> None:
|
|
@@ -284,7 +348,8 @@ class SqlNode(ExecNode):
|
|
|
284
348
|
stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
285
349
|
_logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
|
|
286
350
|
except Exception:
|
|
287
|
-
|
|
351
|
+
# log something if we can't log the compiled stmt
|
|
352
|
+
_logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
|
|
288
353
|
self._log_explain(stmt)
|
|
289
354
|
|
|
290
355
|
conn = Env.get().conn
|
|
@@ -292,28 +357,56 @@ class SqlNode(ExecNode):
|
|
|
292
357
|
for _ in w:
|
|
293
358
|
pass
|
|
294
359
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
output_row: Optional[exprs.DataRow] = None
|
|
360
|
+
output_batch = DataRowBatch(self.row_builder)
|
|
361
|
+
output_row: exprs.DataRow | None = None
|
|
298
362
|
num_rows_returned = 0
|
|
363
|
+
is_using_cockroachdb = Env.get().is_using_cockroachdb
|
|
364
|
+
tzinfo = Env.get().default_time_zone
|
|
299
365
|
|
|
300
366
|
for sql_row in result_cursor:
|
|
301
367
|
output_row = output_batch.add_row(output_row)
|
|
302
368
|
|
|
303
369
|
# populate output_row
|
|
370
|
+
|
|
304
371
|
if self.num_pk_cols > 0:
|
|
305
372
|
output_row.set_pk(tuple(sql_row[-self.num_pk_cols :]))
|
|
373
|
+
|
|
374
|
+
# column copies
|
|
375
|
+
for col, item_idx in self.column_item_idxs.items():
|
|
376
|
+
output_row.cell_vals[col.id] = sql_row[item_idx]
|
|
377
|
+
for col, item_idx in self.column_cellmd_item_idxs.items():
|
|
378
|
+
cell_md_dict = sql_row[item_idx]
|
|
379
|
+
output_row.cell_md[col.id] = exprs.CellMd(**cell_md_dict) if cell_md_dict is not None else None
|
|
380
|
+
|
|
381
|
+
# populate DataRow.slot_cellmd, where requested
|
|
382
|
+
for cellmd_ref, item_idx in self.cellmd_item_idxs.items():
|
|
383
|
+
cell_md_dict = sql_row[item_idx]
|
|
384
|
+
output_row.slot_md[cellmd_ref.col_ref.slot_idx] = (
|
|
385
|
+
exprs.CellMd.from_dict(cell_md_dict) if cell_md_dict is not None else None
|
|
386
|
+
)
|
|
387
|
+
|
|
306
388
|
# copy the output of the SQL query into the output row
|
|
307
389
|
for i, e in enumerate(self.select_list):
|
|
308
390
|
slot_idx = e.slot_idx
|
|
309
|
-
# certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
|
|
310
391
|
if isinstance(sql_row[i], Decimal):
|
|
392
|
+
# certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
|
|
311
393
|
if e.col_type.is_int_type():
|
|
312
394
|
output_row[slot_idx] = int(sql_row[i])
|
|
313
395
|
elif e.col_type.is_float_type():
|
|
314
396
|
output_row[slot_idx] = float(sql_row[i])
|
|
315
397
|
else:
|
|
316
398
|
raise RuntimeError(f'Unexpected Decimal value for {e}')
|
|
399
|
+
elif is_using_cockroachdb and isinstance(sql_row[i], datetime.datetime):
|
|
400
|
+
# Ensure that the datetime is timezone-aware and in the session time zone
|
|
401
|
+
# cockroachDB returns timestamps in the session time zone, with numeric offset,
|
|
402
|
+
# convert to the session time zone with the requested tzinfo for DST handling
|
|
403
|
+
if e.col_type.is_timestamp_type():
|
|
404
|
+
if isinstance(sql_row[i].tzinfo, datetime.timezone):
|
|
405
|
+
output_row[slot_idx] = sql_row[i].astimezone(tz=tzinfo)
|
|
406
|
+
else:
|
|
407
|
+
output_row[slot_idx] = sql_row[i]
|
|
408
|
+
else:
|
|
409
|
+
raise RuntimeError(f'Unexpected datetime value for {e}')
|
|
317
410
|
else:
|
|
318
411
|
output_row[slot_idx] = sql_row[i]
|
|
319
412
|
|
|
@@ -335,7 +428,7 @@ class SqlNode(ExecNode):
|
|
|
335
428
|
if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
|
|
336
429
|
_logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
|
|
337
430
|
yield output_batch
|
|
338
|
-
output_batch = DataRowBatch(
|
|
431
|
+
output_batch = DataRowBatch(self.row_builder)
|
|
339
432
|
|
|
340
433
|
if len(output_batch) > 0:
|
|
341
434
|
_logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
|
|
@@ -351,6 +444,11 @@ class SqlScanNode(SqlNode):
|
|
|
351
444
|
Materializes data from the store via a Select stmt.
|
|
352
445
|
|
|
353
446
|
Supports filtering and ordering.
|
|
447
|
+
|
|
448
|
+
Args:
|
|
449
|
+
select_list: output of the query
|
|
450
|
+
set_pk: if True, sets the primary for each DataRow
|
|
451
|
+
exact_version_only: tables for which we only want to see rows created at the current version
|
|
354
452
|
"""
|
|
355
453
|
|
|
356
454
|
exact_version_only: list[catalog.TableVersionHandle]
|
|
@@ -360,17 +458,21 @@ class SqlScanNode(SqlNode):
|
|
|
360
458
|
tbl: catalog.TableVersionPath,
|
|
361
459
|
row_builder: exprs.RowBuilder,
|
|
362
460
|
select_list: Iterable[exprs.Expr],
|
|
461
|
+
columns: list[catalog.Column],
|
|
462
|
+
cell_md_col_refs: list[exprs.ColumnRef] | None = None,
|
|
363
463
|
set_pk: bool = False,
|
|
364
|
-
exact_version_only:
|
|
464
|
+
exact_version_only: list[catalog.TableVersionHandle] | None = None,
|
|
365
465
|
):
|
|
366
|
-
"""
|
|
367
|
-
Args:
|
|
368
|
-
select_list: output of the query
|
|
369
|
-
set_pk: if True, sets the primary for each DataRow
|
|
370
|
-
exact_version_only: tables for which we only want to see rows created at the current version
|
|
371
|
-
"""
|
|
372
466
|
sql_elements = exprs.SqlElementCache()
|
|
373
|
-
super().__init__(
|
|
467
|
+
super().__init__(
|
|
468
|
+
tbl,
|
|
469
|
+
row_builder,
|
|
470
|
+
select_list,
|
|
471
|
+
columns=columns,
|
|
472
|
+
sql_elements=sql_elements,
|
|
473
|
+
set_pk=set_pk,
|
|
474
|
+
cell_md_col_refs=cell_md_col_refs,
|
|
475
|
+
)
|
|
374
476
|
# create Select stmt
|
|
375
477
|
if exact_version_only is None:
|
|
376
478
|
exact_version_only = []
|
|
@@ -390,6 +492,11 @@ class SqlScanNode(SqlNode):
|
|
|
390
492
|
class SqlLookupNode(SqlNode):
|
|
391
493
|
"""
|
|
392
494
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
495
|
+
|
|
496
|
+
Args:
|
|
497
|
+
select_list: output of the query
|
|
498
|
+
sa_key_cols: list of key columns in the store table
|
|
499
|
+
key_vals: list of key values to look up
|
|
393
500
|
"""
|
|
394
501
|
|
|
395
502
|
def __init__(
|
|
@@ -397,17 +504,21 @@ class SqlLookupNode(SqlNode):
|
|
|
397
504
|
tbl: catalog.TableVersionPath,
|
|
398
505
|
row_builder: exprs.RowBuilder,
|
|
399
506
|
select_list: Iterable[exprs.Expr],
|
|
507
|
+
columns: list[catalog.Column],
|
|
400
508
|
sa_key_cols: list[sql.Column],
|
|
401
509
|
key_vals: list[tuple],
|
|
510
|
+
cell_md_col_refs: list[exprs.ColumnRef] | None = None,
|
|
402
511
|
):
|
|
403
|
-
"""
|
|
404
|
-
Args:
|
|
405
|
-
select_list: output of the query
|
|
406
|
-
sa_key_cols: list of key columns in the store table
|
|
407
|
-
key_vals: list of key values to look up
|
|
408
|
-
"""
|
|
409
512
|
sql_elements = exprs.SqlElementCache()
|
|
410
|
-
super().__init__(
|
|
513
|
+
super().__init__(
|
|
514
|
+
tbl,
|
|
515
|
+
row_builder,
|
|
516
|
+
select_list,
|
|
517
|
+
columns=columns,
|
|
518
|
+
sql_elements=sql_elements,
|
|
519
|
+
set_pk=True,
|
|
520
|
+
cell_md_col_refs=cell_md_col_refs,
|
|
521
|
+
)
|
|
411
522
|
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
412
523
|
self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
413
524
|
|
|
@@ -421,29 +532,29 @@ class SqlLookupNode(SqlNode):
|
|
|
421
532
|
class SqlAggregationNode(SqlNode):
|
|
422
533
|
"""
|
|
423
534
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
535
|
+
|
|
536
|
+
Args:
|
|
537
|
+
select_list: can contain calls to AggregateFunctions
|
|
538
|
+
group_by_items: list of expressions to group by
|
|
539
|
+
limit: max number of rows to return: None = no limit
|
|
424
540
|
"""
|
|
425
541
|
|
|
426
|
-
group_by_items:
|
|
427
|
-
input_cte:
|
|
542
|
+
group_by_items: list[exprs.Expr] | None
|
|
543
|
+
input_cte: sql.CTE | None
|
|
428
544
|
|
|
429
545
|
def __init__(
|
|
430
546
|
self,
|
|
431
547
|
row_builder: exprs.RowBuilder,
|
|
432
548
|
input: SqlNode,
|
|
433
549
|
select_list: Iterable[exprs.Expr],
|
|
434
|
-
group_by_items:
|
|
435
|
-
limit:
|
|
436
|
-
exact_version_only:
|
|
550
|
+
group_by_items: list[exprs.Expr] | None = None,
|
|
551
|
+
limit: int | None = None,
|
|
552
|
+
exact_version_only: list[catalog.TableVersion] | None = None,
|
|
437
553
|
):
|
|
438
|
-
|
|
439
|
-
Args:
|
|
440
|
-
select_list: can contain calls to AggregateFunctions
|
|
441
|
-
group_by_items: list of expressions to group by
|
|
442
|
-
limit: max number of rows to return: None = no limit
|
|
443
|
-
"""
|
|
554
|
+
assert len(input.cell_md_refs) == 0 # there's no aggregation over json or arrays in SQL
|
|
444
555
|
self.input_cte, input_col_map = input.to_cte()
|
|
445
556
|
sql_elements = exprs.SqlElementCache(input_col_map)
|
|
446
|
-
super().__init__(None, row_builder, select_list, sql_elements)
|
|
557
|
+
super().__init__(None, row_builder, select_list, columns=[], sql_elements=sql_elements)
|
|
447
558
|
self.group_by_items = group_by_items
|
|
448
559
|
|
|
449
560
|
def _create_stmt(self) -> sql.Select:
|
|
@@ -479,7 +590,10 @@ class SqlJoinNode(SqlNode):
|
|
|
479
590
|
input_cte, input_col_map = input_node.to_cte()
|
|
480
591
|
self.input_ctes.append(input_cte)
|
|
481
592
|
sql_elements.extend(input_col_map)
|
|
482
|
-
|
|
593
|
+
cell_md_col_refs = [cell_md_ref.col_ref for input in inputs for cell_md_ref in input.cell_md_refs]
|
|
594
|
+
super().__init__(
|
|
595
|
+
None, row_builder, select_list, columns=[], sql_elements=sql_elements, cell_md_col_refs=cell_md_col_refs
|
|
596
|
+
)
|
|
483
597
|
|
|
484
598
|
def _create_stmt(self) -> sql.Select:
|
|
485
599
|
from pixeltable import plan
|
|
@@ -501,3 +615,156 @@ class SqlJoinNode(SqlNode):
|
|
|
501
615
|
full=join_clause == plan.JoinType.FULL_OUTER,
|
|
502
616
|
)
|
|
503
617
|
return stmt
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
class SqlSampleNode(SqlNode):
|
|
621
|
+
"""
|
|
622
|
+
Returns rows sampled from the input node.
|
|
623
|
+
|
|
624
|
+
Args:
|
|
625
|
+
input: SqlNode to sample from
|
|
626
|
+
select_list: can contain calls to AggregateFunctions
|
|
627
|
+
sample_clause: specifies the sampling method
|
|
628
|
+
stratify_exprs: Analyzer processed list of expressions to stratify by.
|
|
629
|
+
"""
|
|
630
|
+
|
|
631
|
+
input_cte: sql.CTE | None
|
|
632
|
+
pk_count: int
|
|
633
|
+
stratify_exprs: list[exprs.Expr] | None
|
|
634
|
+
sample_clause: 'SampleClause'
|
|
635
|
+
|
|
636
|
+
def __init__(
|
|
637
|
+
self,
|
|
638
|
+
row_builder: exprs.RowBuilder,
|
|
639
|
+
input: SqlNode,
|
|
640
|
+
select_list: Iterable[exprs.Expr],
|
|
641
|
+
sample_clause: 'SampleClause',
|
|
642
|
+
stratify_exprs: list[exprs.Expr],
|
|
643
|
+
):
|
|
644
|
+
assert isinstance(input, SqlNode)
|
|
645
|
+
self.input_cte, input_col_map = input.to_cte(keep_pk=True)
|
|
646
|
+
self.pk_count = input.num_pk_cols
|
|
647
|
+
assert self.pk_count > 1
|
|
648
|
+
sql_elements = exprs.SqlElementCache(input_col_map)
|
|
649
|
+
assert sql_elements.contains_all(stratify_exprs)
|
|
650
|
+
cell_md_col_refs = [cell_md_ref.col_ref for cell_md_ref in input.cell_md_refs]
|
|
651
|
+
super().__init__(
|
|
652
|
+
input.tbl,
|
|
653
|
+
row_builder,
|
|
654
|
+
select_list,
|
|
655
|
+
columns=[],
|
|
656
|
+
sql_elements=sql_elements,
|
|
657
|
+
cell_md_col_refs=cell_md_col_refs,
|
|
658
|
+
set_pk=True,
|
|
659
|
+
)
|
|
660
|
+
self.stratify_exprs = stratify_exprs
|
|
661
|
+
self.sample_clause = sample_clause
|
|
662
|
+
|
|
663
|
+
@classmethod
|
|
664
|
+
def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
|
|
665
|
+
"""Construct expression which is the ordering key for rows to be sampled
|
|
666
|
+
General SQL form is:
|
|
667
|
+
- MD5(<seed::text> [ + '___' + <rowid_col_val>::text]+
|
|
668
|
+
"""
|
|
669
|
+
sql_expr: sql.ColumnElement = seed.cast(sql.String)
|
|
670
|
+
for e in sql_cols:
|
|
671
|
+
# Quotes are required below to guarantee that the string is properly presented in SQL
|
|
672
|
+
sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + e.cast(sql.String)
|
|
673
|
+
sql_expr = sql.func.md5(sql_expr)
|
|
674
|
+
return sql_expr
|
|
675
|
+
|
|
676
|
+
def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
|
|
677
|
+
"""Create an expression for randomly ordering rows with a given seed"""
|
|
678
|
+
rowid_cols = [*cte.c[-self.pk_count : -1]] # exclude the version column
|
|
679
|
+
assert len(rowid_cols) > 0
|
|
680
|
+
# If seed is not set in the sample clause, use the random seed given by the execution context
|
|
681
|
+
seed = self.sample_clause.seed if self.sample_clause.seed is not None else self.ctx.random_seed
|
|
682
|
+
return self.key_sql_expr(sql.literal_column(str(seed)), rowid_cols)
|
|
683
|
+
|
|
684
|
+
def _create_stmt(self) -> sql.Select:
|
|
685
|
+
from pixeltable.plan import SampleClause
|
|
686
|
+
|
|
687
|
+
self._init_exec_state()
|
|
688
|
+
|
|
689
|
+
if self.sample_clause.fraction is not None:
|
|
690
|
+
if len(self.stratify_exprs) == 0:
|
|
691
|
+
# If non-stratified sampling, construct a where clause, order_by, and limit clauses
|
|
692
|
+
s_key = self._create_key_sql(self.input_cte)
|
|
693
|
+
|
|
694
|
+
# Construct a suitable where clause
|
|
695
|
+
fraction_md5 = SampleClause.fraction_to_md5_hex(self.sample_clause.fraction)
|
|
696
|
+
order_by = self._create_key_sql(self.input_cte)
|
|
697
|
+
return sql.select(*self.input_cte.c).where(s_key < fraction_md5).order_by(order_by)
|
|
698
|
+
|
|
699
|
+
return self._create_stmt_stratified_fraction(self.sample_clause.fraction)
|
|
700
|
+
else:
|
|
701
|
+
if len(self.stratify_exprs) == 0:
|
|
702
|
+
# No stratification, just return n samples from the input CTE
|
|
703
|
+
order_by = self._create_key_sql(self.input_cte)
|
|
704
|
+
return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)
|
|
705
|
+
|
|
706
|
+
return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)
|
|
707
|
+
|
|
708
|
+
def _create_stmt_stratified_n(self, n: int | None, n_per_stratum: int | None) -> sql.Select:
|
|
709
|
+
"""Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""
|
|
710
|
+
|
|
711
|
+
sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
|
|
712
|
+
order_by = self._create_key_sql(self.input_cte)
|
|
713
|
+
|
|
714
|
+
# Create a list of all columns plus the rank
|
|
715
|
+
# Get all columns from the input CTE dynamically
|
|
716
|
+
select_columns = [*self.input_cte.c]
|
|
717
|
+
select_columns.append(
|
|
718
|
+
sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
|
|
719
|
+
)
|
|
720
|
+
row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
|
|
721
|
+
|
|
722
|
+
final_columns = [*row_rank_cte.c[:-1]] # exclude the rank column
|
|
723
|
+
if n_per_stratum is not None:
|
|
724
|
+
return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
|
|
725
|
+
else:
|
|
726
|
+
secondary_order = self._create_key_sql(row_rank_cte)
|
|
727
|
+
return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)
|
|
728
|
+
|
|
729
|
+
def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
|
|
730
|
+
"""Create a Select stmt that returns a fraction of the rows per strata"""
|
|
731
|
+
|
|
732
|
+
# Build the strata count CTE
|
|
733
|
+
# Produces a table of the form:
|
|
734
|
+
# (*stratify_exprs, s_s_size)
|
|
735
|
+
# where s_s_size is the number of samples to take from each stratum
|
|
736
|
+
sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
|
|
737
|
+
per_strata_count_cte = (
|
|
738
|
+
sql.select(
|
|
739
|
+
*sql_strata_exprs,
|
|
740
|
+
sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
|
|
741
|
+
)
|
|
742
|
+
.select_from(self.input_cte)
|
|
743
|
+
.group_by(*sql_strata_exprs)
|
|
744
|
+
.cte('per_strata_count_cte')
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
# Build a CTE that ranks the rows within each stratum
|
|
748
|
+
# Include all columns from the input CTE dynamically
|
|
749
|
+
order_by = self._create_key_sql(self.input_cte)
|
|
750
|
+
select_columns = [*self.input_cte.c]
|
|
751
|
+
select_columns.append(
|
|
752
|
+
sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
|
|
753
|
+
)
|
|
754
|
+
row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
|
|
755
|
+
|
|
756
|
+
# Build the join criterion dynamically to accommodate any number of stratify_by expressions
|
|
757
|
+
join_c = sql.true()
|
|
758
|
+
for col in per_strata_count_cte.c[:-1]:
|
|
759
|
+
join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)
|
|
760
|
+
|
|
761
|
+
# Join with per_strata_count_cte to limit returns to the requested fraction of rows
|
|
762
|
+
final_columns = [*row_rank_cte.c[:-1]] # exclude the rank column
|
|
763
|
+
stmt = (
|
|
764
|
+
sql.select(*final_columns)
|
|
765
|
+
.select_from(row_rank_cte)
|
|
766
|
+
.join(per_strata_count_cte, join_c)
|
|
767
|
+
.where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
return stmt
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -6,7 +6,7 @@ from .column_property_ref import ColumnPropertyRef
|
|
|
6
6
|
from .column_ref import ColumnRef
|
|
7
7
|
from .comparison import Comparison
|
|
8
8
|
from .compound_predicate import CompoundPredicate
|
|
9
|
-
from .data_row import DataRow
|
|
9
|
+
from .data_row import ArrayMd, BinaryMd, CellMd, DataRow
|
|
10
10
|
from .expr import Expr
|
|
11
11
|
from .expr_dict import ExprDict
|
|
12
12
|
from .expr_set import ExprSet
|