pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +42 -8
- pixeltable/{dataframe.py → _query.py} +470 -206
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +5 -4
- pixeltable/catalog/catalog.py +1785 -432
- pixeltable/catalog/column.py +190 -113
- pixeltable/catalog/dir.py +2 -4
- pixeltable/catalog/globals.py +19 -46
- pixeltable/catalog/insertable_table.py +191 -98
- pixeltable/catalog/path.py +63 -23
- pixeltable/catalog/schema_object.py +11 -15
- pixeltable/catalog/table.py +843 -436
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +978 -657
- pixeltable/catalog/table_version_handle.py +72 -16
- pixeltable/catalog/table_version_path.py +112 -43
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +134 -90
- pixeltable/config.py +134 -22
- pixeltable/env.py +471 -157
- pixeltable/exceptions.py +6 -0
- pixeltable/exec/__init__.py +4 -1
- pixeltable/exec/aggregation_node.py +7 -8
- pixeltable/exec/cache_prefetch_node.py +83 -110
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +4 -3
- pixeltable/exec/data_row_batch.py +8 -65
- pixeltable/exec/exec_context.py +16 -4
- pixeltable/exec/exec_node.py +13 -36
- pixeltable/exec/expr_eval/evaluators.py +11 -7
- pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
- pixeltable/exec/expr_eval/globals.py +8 -5
- pixeltable/exec/expr_eval/row_buffer.py +1 -2
- pixeltable/exec/expr_eval/schedulers.py +106 -56
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +19 -19
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +16 -9
- pixeltable/exec/sql_node.py +351 -84
- pixeltable/exprs/__init__.py +1 -1
- pixeltable/exprs/arithmetic_expr.py +27 -22
- pixeltable/exprs/array_slice.py +3 -3
- pixeltable/exprs/column_property_ref.py +36 -23
- pixeltable/exprs/column_ref.py +213 -89
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +5 -4
- pixeltable/exprs/data_row.py +164 -54
- pixeltable/exprs/expr.py +70 -44
- pixeltable/exprs/expr_dict.py +3 -3
- pixeltable/exprs/expr_set.py +17 -10
- pixeltable/exprs/function_call.py +100 -40
- pixeltable/exprs/globals.py +2 -2
- pixeltable/exprs/in_predicate.py +4 -4
- pixeltable/exprs/inline_expr.py +18 -32
- pixeltable/exprs/is_null.py +7 -3
- pixeltable/exprs/json_mapper.py +8 -8
- pixeltable/exprs/json_path.py +56 -22
- pixeltable/exprs/literal.py +27 -5
- pixeltable/exprs/method_ref.py +2 -2
- pixeltable/exprs/object_ref.py +2 -2
- pixeltable/exprs/row_builder.py +167 -67
- pixeltable/exprs/rowid_ref.py +25 -10
- pixeltable/exprs/similarity_expr.py +58 -40
- pixeltable/exprs/sql_element_cache.py +4 -4
- pixeltable/exprs/string_op.py +5 -5
- pixeltable/exprs/type_cast.py +3 -5
- pixeltable/func/__init__.py +1 -0
- pixeltable/func/aggregate_function.py +8 -8
- pixeltable/func/callable_function.py +9 -9
- pixeltable/func/expr_template_function.py +17 -11
- pixeltable/func/function.py +18 -20
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/globals.py +2 -3
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +29 -27
- pixeltable/func/signature.py +46 -19
- pixeltable/func/tools.py +31 -13
- pixeltable/func/udf.py +18 -20
- pixeltable/functions/__init__.py +16 -0
- pixeltable/functions/anthropic.py +123 -77
- pixeltable/functions/audio.py +147 -10
- pixeltable/functions/bedrock.py +13 -6
- pixeltable/functions/date.py +7 -4
- pixeltable/functions/deepseek.py +35 -43
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +11 -20
- pixeltable/functions/gemini.py +195 -39
- pixeltable/functions/globals.py +142 -14
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1056 -24
- pixeltable/functions/image.py +115 -57
- pixeltable/functions/json.py +1 -1
- pixeltable/functions/llama_cpp.py +28 -13
- pixeltable/functions/math.py +67 -5
- pixeltable/functions/mistralai.py +18 -55
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +20 -13
- pixeltable/functions/openai.py +240 -226
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +4 -4
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +239 -69
- pixeltable/functions/timestamp.py +16 -16
- pixeltable/functions/together.py +24 -84
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +6 -1
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1515 -107
- pixeltable/functions/vision.py +8 -8
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +16 -8
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/{ext/functions → functions}/yolox.py +2 -4
- pixeltable/globals.py +362 -115
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +28 -22
- pixeltable/index/embedding_index.py +100 -118
- pixeltable/io/__init__.py +4 -2
- pixeltable/io/datarows.py +8 -7
- pixeltable/io/external_store.py +56 -105
- pixeltable/io/fiftyone.py +13 -13
- pixeltable/io/globals.py +31 -30
- pixeltable/io/hf_datasets.py +61 -16
- pixeltable/io/label_studio.py +74 -70
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +21 -12
- pixeltable/io/parquet.py +25 -105
- pixeltable/io/table_data_conduit.py +250 -123
- pixeltable/io/utils.py +4 -4
- pixeltable/iterators/__init__.py +2 -1
- pixeltable/iterators/audio.py +26 -25
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +112 -78
- pixeltable/iterators/image.py +12 -15
- pixeltable/iterators/string.py +11 -4
- pixeltable/iterators/video.py +523 -120
- pixeltable/metadata/__init__.py +14 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_18.py +2 -2
- pixeltable/metadata/converters/convert_19.py +2 -2
- pixeltable/metadata/converters/convert_20.py +2 -2
- pixeltable/metadata/converters/convert_21.py +2 -2
- pixeltable/metadata/converters/convert_22.py +2 -2
- pixeltable/metadata/converters/convert_24.py +2 -2
- pixeltable/metadata/converters/convert_25.py +2 -2
- pixeltable/metadata/converters/convert_26.py +2 -2
- pixeltable/metadata/converters/convert_29.py +4 -4
- pixeltable/metadata/converters/convert_30.py +34 -21
- pixeltable/metadata/converters/convert_34.py +2 -2
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +20 -31
- pixeltable/metadata/notes.py +9 -0
- pixeltable/metadata/schema.py +140 -53
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +382 -115
- pixeltable/share/__init__.py +1 -1
- pixeltable/share/packager.py +547 -83
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +257 -59
- pixeltable/store.py +311 -194
- pixeltable/type_system.py +373 -211
- pixeltable/utils/__init__.py +2 -3
- pixeltable/utils/arrow.py +131 -17
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +6 -6
- pixeltable/utils/code.py +3 -3
- pixeltable/utils/console_output.py +4 -1
- pixeltable/utils/coroutine.py +6 -23
- pixeltable/utils/dbms.py +32 -6
- pixeltable/utils/description_helper.py +4 -5
- pixeltable/utils/documents.py +7 -18
- pixeltable/utils/exception_handler.py +7 -30
- pixeltable/utils/filecache.py +6 -6
- pixeltable/utils/formatter.py +86 -48
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +2 -3
- pixeltable/utils/iceberg.py +1 -2
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +5 -6
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +26 -0
- pixeltable/utils/system.py +30 -0
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -40
- pixeltable/ext/__init__.py +0 -17
- pixeltable/ext/functions/__init__.py +0 -11
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/utils/media_store.py +0 -77
- pixeltable/utils/s3.py +0 -17
- pixeltable-0.3.14.dist-info/METADATA +0 -434
- pixeltable-0.3.14.dist-info/RECORD +0 -186
- pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
- {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
|
@@ -8,15 +8,17 @@ import json
|
|
|
8
8
|
import logging
|
|
9
9
|
import traceback
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Hashable, Iterator, NoReturn,
|
|
11
|
+
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Hashable, Iterator, NoReturn, Sequence, TypeVar
|
|
12
12
|
|
|
13
13
|
import pandas as pd
|
|
14
|
-
import
|
|
14
|
+
import pydantic
|
|
15
|
+
import sqlalchemy.exc as sql_exc
|
|
15
16
|
|
|
16
17
|
from pixeltable import catalog, exceptions as excs, exec, exprs, plan, type_system as ts
|
|
17
|
-
from pixeltable.catalog import is_valid_identifier
|
|
18
|
-
from pixeltable.catalog.
|
|
18
|
+
from pixeltable.catalog import Catalog, is_valid_identifier
|
|
19
|
+
from pixeltable.catalog.update_status import UpdateStatus
|
|
19
20
|
from pixeltable.env import Env
|
|
21
|
+
from pixeltable.plan import Planner, SampleClause
|
|
20
22
|
from pixeltable.type_system import ColumnType
|
|
21
23
|
from pixeltable.utils.description_helper import DescriptionHelper
|
|
22
24
|
from pixeltable.utils.formatter import Formatter
|
|
@@ -25,12 +27,17 @@ if TYPE_CHECKING:
|
|
|
25
27
|
import torch
|
|
26
28
|
import torch.utils.data
|
|
27
29
|
|
|
28
|
-
__all__ = ['
|
|
30
|
+
__all__ = ['Query']
|
|
29
31
|
|
|
30
32
|
_logger = logging.getLogger('pixeltable')
|
|
31
33
|
|
|
32
34
|
|
|
33
|
-
class
|
|
35
|
+
class ResultSet:
|
|
36
|
+
_rows: list[list[Any]]
|
|
37
|
+
_col_names: list[str]
|
|
38
|
+
__schema: dict[str, ColumnType]
|
|
39
|
+
__formatter: Formatter
|
|
40
|
+
|
|
34
41
|
def __init__(self, rows: list[list[Any]], schema: dict[str, ColumnType]):
|
|
35
42
|
self._rows = rows
|
|
36
43
|
self._col_names = list(schema.keys())
|
|
@@ -65,6 +72,44 @@ class DataFrameResultSet:
|
|
|
65
72
|
def to_pandas(self) -> pd.DataFrame:
|
|
66
73
|
return pd.DataFrame.from_records(self._rows, columns=self._col_names)
|
|
67
74
|
|
|
75
|
+
BaseModelT = TypeVar('BaseModelT', bound=pydantic.BaseModel)
|
|
76
|
+
|
|
77
|
+
def to_pydantic(self, model: type[BaseModelT]) -> Iterator[BaseModelT]:
|
|
78
|
+
"""
|
|
79
|
+
Convert the ResultSet to a list of Pydantic model instances.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
model: A Pydantic model class.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
An iterator over Pydantic model instances, one for each row in the result set.
|
|
86
|
+
|
|
87
|
+
Raises:
|
|
88
|
+
Error: If the row data doesn't match the model schema.
|
|
89
|
+
"""
|
|
90
|
+
model_fields = model.model_fields
|
|
91
|
+
model_config = getattr(model, 'model_config', {})
|
|
92
|
+
forbid_extra_fields = model_config.get('extra') == 'forbid'
|
|
93
|
+
|
|
94
|
+
# schema validation
|
|
95
|
+
required_fields = {name for name, field in model_fields.items() if field.is_required()}
|
|
96
|
+
col_names = set(self._col_names)
|
|
97
|
+
missing_fields = required_fields - col_names
|
|
98
|
+
if len(missing_fields) > 0:
|
|
99
|
+
raise excs.Error(
|
|
100
|
+
f'Required model fields {missing_fields} are missing from result set columns {self._col_names}'
|
|
101
|
+
)
|
|
102
|
+
if forbid_extra_fields:
|
|
103
|
+
extra_fields = col_names - set(model_fields.keys())
|
|
104
|
+
if len(extra_fields) > 0:
|
|
105
|
+
raise excs.Error(f"Extra fields {extra_fields} are not allowed in model with extra='forbid'")
|
|
106
|
+
|
|
107
|
+
for row in self:
|
|
108
|
+
try:
|
|
109
|
+
yield model(**row)
|
|
110
|
+
except pydantic.ValidationError as e:
|
|
111
|
+
raise excs.Error(str(e)) from e
|
|
112
|
+
|
|
68
113
|
def _row_to_dict(self, row_idx: int) -> dict[str, Any]:
|
|
69
114
|
return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
|
|
70
115
|
|
|
@@ -89,7 +134,7 @@ class DataFrameResultSet:
|
|
|
89
134
|
return (self._row_to_dict(i) for i in range(len(self)))
|
|
90
135
|
|
|
91
136
|
def __eq__(self, other: object) -> bool:
|
|
92
|
-
if not isinstance(other,
|
|
137
|
+
if not isinstance(other, ResultSet):
|
|
93
138
|
return False
|
|
94
139
|
return self.to_pandas().equals(other.to_pandas())
|
|
95
140
|
|
|
@@ -106,14 +151,14 @@ class DataFrameResultSet:
|
|
|
106
151
|
# # output of the agg stage
|
|
107
152
|
# self.agg_output_exprs: list[exprs.Expr] = []
|
|
108
153
|
# # Where clause of the Select stmt of the SQL scan stage
|
|
109
|
-
# self.sql_where_clause:
|
|
154
|
+
# self.sql_where_clause: sql.ClauseElement | None = None
|
|
110
155
|
# # filter predicate applied to input rows of the SQL scan stage
|
|
111
|
-
# self.filter:
|
|
112
|
-
# self.similarity_clause:
|
|
156
|
+
# self.filter: exprs.Predicate | None = None
|
|
157
|
+
# self.similarity_clause: exprs.ImageSimilarityPredicate | None = None
|
|
113
158
|
# self.agg_fn_calls: list[exprs.FunctionCall] = [] # derived from unique_exprs
|
|
114
159
|
# self.has_frame_col: bool = False # True if we're referencing the frame col
|
|
115
160
|
#
|
|
116
|
-
# self.evaluator:
|
|
161
|
+
# self.evaluator: exprs.Evaluator | None = None
|
|
117
162
|
# self.sql_scan_eval_ctx: list[exprs.Expr] = [] # needed to materialize output of SQL scan stage
|
|
118
163
|
# self.agg_eval_ctx: list[exprs.Expr] = [] # needed to materialize output of agg stage
|
|
119
164
|
# self.filter_eval_ctx: list[exprs.Expr] = []
|
|
@@ -129,32 +174,36 @@ class DataFrameResultSet:
|
|
|
129
174
|
# self.filter.release()
|
|
130
175
|
|
|
131
176
|
|
|
132
|
-
class
|
|
177
|
+
class Query:
|
|
178
|
+
"""Represents a query for retrieving and transforming data from Pixeltable tables."""
|
|
179
|
+
|
|
133
180
|
_from_clause: plan.FromClause
|
|
134
181
|
_select_list_exprs: list[exprs.Expr]
|
|
135
182
|
_schema: dict[str, ts.ColumnType]
|
|
136
|
-
select_list:
|
|
137
|
-
where_clause:
|
|
138
|
-
group_by_clause:
|
|
139
|
-
grouping_tbl:
|
|
140
|
-
order_by_clause:
|
|
141
|
-
limit_val:
|
|
183
|
+
select_list: list[tuple[exprs.Expr, str | None]] | None
|
|
184
|
+
where_clause: exprs.Expr | None
|
|
185
|
+
group_by_clause: list[exprs.Expr] | None
|
|
186
|
+
grouping_tbl: catalog.TableVersion | None
|
|
187
|
+
order_by_clause: list[tuple[exprs.Expr, bool]] | None
|
|
188
|
+
limit_val: exprs.Expr | None
|
|
189
|
+
sample_clause: SampleClause | None
|
|
142
190
|
|
|
143
191
|
def __init__(
|
|
144
192
|
self,
|
|
145
|
-
from_clause:
|
|
146
|
-
select_list:
|
|
147
|
-
where_clause:
|
|
148
|
-
group_by_clause:
|
|
149
|
-
grouping_tbl:
|
|
150
|
-
order_by_clause:
|
|
151
|
-
limit:
|
|
193
|
+
from_clause: plan.FromClause | None = None,
|
|
194
|
+
select_list: list[tuple[exprs.Expr, str | None]] | None = None,
|
|
195
|
+
where_clause: exprs.Expr | None = None,
|
|
196
|
+
group_by_clause: list[exprs.Expr] | None = None,
|
|
197
|
+
grouping_tbl: catalog.TableVersion | None = None,
|
|
198
|
+
order_by_clause: list[tuple[exprs.Expr, bool]] | None = None, # list[(expr, asc)]
|
|
199
|
+
limit: exprs.Expr | None = None,
|
|
200
|
+
sample_clause: SampleClause | None = None,
|
|
152
201
|
):
|
|
153
202
|
self._from_clause = from_clause
|
|
154
203
|
|
|
155
204
|
# exprs contain execution state and therefore cannot be shared
|
|
156
205
|
select_list = copy.deepcopy(select_list)
|
|
157
|
-
select_list_exprs, column_names =
|
|
206
|
+
select_list_exprs, column_names = Query._normalize_select_list(self._from_clause.tbls, select_list)
|
|
158
207
|
# check select list after expansion to catch early
|
|
159
208
|
# the following two lists are always non empty, even if select list is None.
|
|
160
209
|
assert len(column_names) == len(select_list_exprs)
|
|
@@ -168,10 +217,11 @@ class DataFrame:
|
|
|
168
217
|
self.grouping_tbl = grouping_tbl
|
|
169
218
|
self.order_by_clause = copy.deepcopy(order_by_clause)
|
|
170
219
|
self.limit_val = limit
|
|
220
|
+
self.sample_clause = sample_clause
|
|
171
221
|
|
|
172
222
|
@classmethod
|
|
173
223
|
def _normalize_select_list(
|
|
174
|
-
cls, tbls: list[catalog.TableVersionPath], select_list:
|
|
224
|
+
cls, tbls: list[catalog.TableVersionPath], select_list: list[tuple[exprs.Expr, str | None]] | None
|
|
175
225
|
) -> tuple[list[exprs.Expr], list[str]]:
|
|
176
226
|
"""
|
|
177
227
|
Expand select list information with all columns and their names
|
|
@@ -210,12 +260,11 @@ class DataFrame:
|
|
|
210
260
|
|
|
211
261
|
@property
|
|
212
262
|
def _first_tbl(self) -> catalog.TableVersionPath:
|
|
213
|
-
|
|
214
|
-
return self._from_clause.tbls[0]
|
|
263
|
+
return self._from_clause._first_tbl
|
|
215
264
|
|
|
216
265
|
def _vars(self) -> dict[str, exprs.Variable]:
|
|
217
266
|
"""
|
|
218
|
-
Return a dict mapping variable name to Variable for all Variables contained in any component of the
|
|
267
|
+
Return a dict mapping variable name to Variable for all Variables contained in any component of the Query
|
|
219
268
|
"""
|
|
220
269
|
all_exprs: list[exprs.Expr] = []
|
|
221
270
|
all_exprs.extend(self._select_list_exprs)
|
|
@@ -233,19 +282,49 @@ class DataFrame:
|
|
|
233
282
|
if var.name not in unique_vars:
|
|
234
283
|
unique_vars[var.name] = var
|
|
235
284
|
elif unique_vars[var.name].col_type != var.col_type:
|
|
236
|
-
raise excs.Error(f'Multiple definitions of parameter {var.name}')
|
|
285
|
+
raise excs.Error(f'Multiple definitions of parameter {var.name!r}')
|
|
237
286
|
return unique_vars
|
|
238
287
|
|
|
288
|
+
@classmethod
|
|
289
|
+
def _convert_param_to_typed_expr(
|
|
290
|
+
cls, v: Any, required_type: ts.ColumnType, required: bool, name: str, range: tuple[Any, Any] | None = None
|
|
291
|
+
) -> exprs.Expr | None:
|
|
292
|
+
if v is None:
|
|
293
|
+
if required:
|
|
294
|
+
raise excs.Error(f'{name!r} parameter must be present')
|
|
295
|
+
return v
|
|
296
|
+
v_expr = exprs.Expr.from_object(v)
|
|
297
|
+
if not v_expr.col_type.matches(required_type):
|
|
298
|
+
raise excs.Error(f'{name!r} parameter must be of type `{required_type}`; got `{v_expr.col_type}`')
|
|
299
|
+
if range is not None:
|
|
300
|
+
if not isinstance(v_expr, exprs.Literal):
|
|
301
|
+
raise excs.Error(f'{name!r} parameter must be a constant; got: {v_expr}')
|
|
302
|
+
if range[0] is not None and not (v_expr.val >= range[0]):
|
|
303
|
+
raise excs.Error(f'{name!r} parameter must be >= {range[0]}')
|
|
304
|
+
if range[1] is not None and not (v_expr.val <= range[1]):
|
|
305
|
+
raise excs.Error(f'{name!r} parameter must be <= {range[1]}')
|
|
306
|
+
return v_expr
|
|
307
|
+
|
|
308
|
+
@classmethod
|
|
309
|
+
def validate_constant_type_range(
|
|
310
|
+
cls, v: Any, required_type: ts.ColumnType, required: bool, name: str, range: tuple[Any, Any] | None = None
|
|
311
|
+
) -> Any:
|
|
312
|
+
"""Validate that the given named parameter is a constant of the required type and within the specified range."""
|
|
313
|
+
v_expr = cls._convert_param_to_typed_expr(v, required_type, required, name, range)
|
|
314
|
+
if v_expr is None:
|
|
315
|
+
return None
|
|
316
|
+
return v_expr.val
|
|
317
|
+
|
|
239
318
|
def parameters(self) -> dict[str, ColumnType]:
|
|
240
319
|
"""Return a dict mapping parameter name to parameter type.
|
|
241
320
|
|
|
242
|
-
Parameters are Variables contained in any component of the
|
|
321
|
+
Parameters are Variables contained in any component of the Query.
|
|
243
322
|
"""
|
|
244
323
|
return {name: var.col_type for name, var in self._vars().items()}
|
|
245
324
|
|
|
246
325
|
def _exec(self) -> Iterator[exprs.DataRow]:
|
|
247
326
|
"""Run the query and return rows as a generator.
|
|
248
|
-
This function must not modify the state of the
|
|
327
|
+
This function must not modify the state of the Query, otherwise it breaks dataset caching.
|
|
249
328
|
"""
|
|
250
329
|
plan = self._create_query_plan()
|
|
251
330
|
|
|
@@ -261,7 +340,7 @@ class DataFrame:
|
|
|
261
340
|
|
|
262
341
|
async def _aexec(self) -> AsyncIterator[exprs.DataRow]:
|
|
263
342
|
"""Run the query and return rows as a generator.
|
|
264
|
-
This function must not modify the state of the
|
|
343
|
+
This function must not modify the state of the Query, otherwise it breaks dataset caching.
|
|
265
344
|
"""
|
|
266
345
|
plan = self._create_query_plan()
|
|
267
346
|
plan.open()
|
|
@@ -274,37 +353,44 @@ class DataFrame:
|
|
|
274
353
|
|
|
275
354
|
def _create_query_plan(self) -> exec.ExecNode:
|
|
276
355
|
# construct a group-by clause if we're grouping by a table
|
|
277
|
-
group_by_clause:
|
|
356
|
+
group_by_clause: list[exprs.Expr] | None = None
|
|
278
357
|
if self.grouping_tbl is not None:
|
|
279
358
|
assert self.group_by_clause is None
|
|
280
359
|
num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
|
|
281
360
|
# the grouping table must be a base of self.tbl
|
|
282
361
|
assert num_rowid_cols <= len(self._first_tbl.tbl_version.get().store_tbl.rowid_columns())
|
|
283
|
-
group_by_clause =
|
|
362
|
+
group_by_clause = self.__rowid_columns(num_rowid_cols)
|
|
284
363
|
elif self.group_by_clause is not None:
|
|
285
364
|
group_by_clause = self.group_by_clause
|
|
286
365
|
|
|
287
366
|
for item in self._select_list_exprs:
|
|
288
367
|
item.bind_rel_paths()
|
|
289
368
|
|
|
290
|
-
return
|
|
369
|
+
return Planner.create_query_plan(
|
|
291
370
|
self._from_clause,
|
|
292
371
|
self._select_list_exprs,
|
|
293
372
|
where_clause=self.where_clause,
|
|
294
373
|
group_by_clause=group_by_clause,
|
|
295
|
-
order_by_clause=self.order_by_clause
|
|
374
|
+
order_by_clause=self.order_by_clause,
|
|
296
375
|
limit=self.limit_val,
|
|
376
|
+
sample_clause=self.sample_clause,
|
|
297
377
|
)
|
|
298
378
|
|
|
379
|
+
def __rowid_columns(self, num_rowid_cols: int | None = None) -> list[exprs.Expr]:
|
|
380
|
+
"""Return list of RowidRef for the given number of associated rowids"""
|
|
381
|
+
return Planner.rowid_columns(self._first_tbl.tbl_version, num_rowid_cols)
|
|
382
|
+
|
|
299
383
|
def _has_joins(self) -> bool:
|
|
300
384
|
return len(self._from_clause.join_clauses) > 0
|
|
301
385
|
|
|
302
|
-
def show(self, n: int = 20) ->
|
|
386
|
+
def show(self, n: int = 20) -> ResultSet:
|
|
387
|
+
if self.sample_clause is not None:
|
|
388
|
+
raise excs.Error('show() cannot be used with sample()')
|
|
303
389
|
assert n is not None
|
|
304
390
|
return self.limit(n).collect()
|
|
305
391
|
|
|
306
|
-
def head(self, n: int = 10) ->
|
|
307
|
-
"""Return the first n rows of the
|
|
392
|
+
def head(self, n: int = 10) -> ResultSet:
|
|
393
|
+
"""Return the first n rows of the Query, in insertion order of the underlying Table.
|
|
308
394
|
|
|
309
395
|
head() is not supported for joins.
|
|
310
396
|
|
|
@@ -312,24 +398,26 @@ class DataFrame:
|
|
|
312
398
|
n: Number of rows to select. Default is 10.
|
|
313
399
|
|
|
314
400
|
Returns:
|
|
315
|
-
A
|
|
401
|
+
A ResultSet with the first n rows of the Query.
|
|
316
402
|
|
|
317
403
|
Raises:
|
|
318
|
-
Error: If the
|
|
319
|
-
if the
|
|
404
|
+
Error: If the Query is the result of a join or
|
|
405
|
+
if the Query has an order_by clause.
|
|
320
406
|
"""
|
|
321
407
|
if self.order_by_clause is not None:
|
|
322
408
|
raise excs.Error('head() cannot be used with order_by()')
|
|
323
409
|
if self._has_joins():
|
|
324
410
|
raise excs.Error('head() not supported for joins')
|
|
411
|
+
if self.sample_clause is not None:
|
|
412
|
+
raise excs.Error('head() cannot be used with sample()')
|
|
325
413
|
if self.group_by_clause is not None:
|
|
326
414
|
raise excs.Error('head() cannot be used with group_by()')
|
|
327
415
|
num_rowid_cols = len(self._first_tbl.tbl_version.get().store_tbl.rowid_columns())
|
|
328
416
|
order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
|
|
329
417
|
return self.order_by(*order_by_clause, asc=True).limit(n).collect()
|
|
330
418
|
|
|
331
|
-
def tail(self, n: int = 10) ->
|
|
332
|
-
"""Return the last n rows of the
|
|
419
|
+
def tail(self, n: int = 10) -> ResultSet:
|
|
420
|
+
"""Return the last n rows of the Query, in insertion order of the underlying Table.
|
|
333
421
|
|
|
334
422
|
tail() is not supported for joins.
|
|
335
423
|
|
|
@@ -337,16 +425,18 @@ class DataFrame:
|
|
|
337
425
|
n: Number of rows to select. Default is 10.
|
|
338
426
|
|
|
339
427
|
Returns:
|
|
340
|
-
A
|
|
428
|
+
A ResultSet with the last n rows of the Query.
|
|
341
429
|
|
|
342
430
|
Raises:
|
|
343
|
-
Error: If the
|
|
344
|
-
if the
|
|
431
|
+
Error: If the Query is the result of a join or
|
|
432
|
+
if the Query has an order_by clause.
|
|
345
433
|
"""
|
|
346
434
|
if self.order_by_clause is not None:
|
|
347
435
|
raise excs.Error('tail() cannot be used with order_by()')
|
|
348
436
|
if self._has_joins():
|
|
349
437
|
raise excs.Error('tail() not supported for joins')
|
|
438
|
+
if self.sample_clause is not None:
|
|
439
|
+
raise excs.Error('tail() cannot be used with sample()')
|
|
350
440
|
if self.group_by_clause is not None:
|
|
351
441
|
raise excs.Error('tail() cannot be used with group_by()')
|
|
352
442
|
num_rowid_cols = len(self._first_tbl.tbl_version.get().store_tbl.rowid_columns())
|
|
@@ -357,10 +447,11 @@ class DataFrame:
|
|
|
357
447
|
|
|
358
448
|
@property
|
|
359
449
|
def schema(self) -> dict[str, ColumnType]:
|
|
450
|
+
"""Column names and types in this Query."""
|
|
360
451
|
return self._schema
|
|
361
452
|
|
|
362
|
-
def bind(self, args: dict[str, Any]) ->
|
|
363
|
-
"""Bind arguments to parameters and return a new
|
|
453
|
+
def bind(self, args: dict[str, Any]) -> Query:
|
|
454
|
+
"""Bind arguments to parameters and return a new Query."""
|
|
364
455
|
# substitute Variables with the corresponding values according to 'args', converted to Literals
|
|
365
456
|
select_list_exprs = copy.deepcopy(self._select_list_exprs)
|
|
366
457
|
where_clause = copy.deepcopy(self.where_clause)
|
|
@@ -381,7 +472,7 @@ class DataFrame:
|
|
|
381
472
|
var_expr = vars[arg_name]
|
|
382
473
|
arg_expr = exprs.Expr.from_object(arg_val)
|
|
383
474
|
if arg_expr is None:
|
|
384
|
-
raise excs.Error(f'
|
|
475
|
+
raise excs.Error(f'That argument cannot be converted to a Pixeltable expression: {arg_val}')
|
|
385
476
|
var_exprs[var_expr] = arg_expr
|
|
386
477
|
|
|
387
478
|
exprs.Expr.list_substitute(select_list_exprs, var_exprs)
|
|
@@ -393,7 +484,7 @@ class DataFrame:
|
|
|
393
484
|
exprs.Expr.list_substitute(order_by_exprs, var_exprs)
|
|
394
485
|
|
|
395
486
|
select_list = list(zip(select_list_exprs, self.schema.keys()))
|
|
396
|
-
order_by_clause:
|
|
487
|
+
order_by_clause: list[tuple[exprs.Expr, bool]] | None = None
|
|
397
488
|
if order_by_exprs is not None:
|
|
398
489
|
order_by_clause = [
|
|
399
490
|
(expr, asc) for expr, asc in zip(order_by_exprs, [asc for _, asc in self.order_by_clause])
|
|
@@ -401,9 +492,9 @@ class DataFrame:
|
|
|
401
492
|
if limit_val is not None:
|
|
402
493
|
limit_val = limit_val.substitute(var_exprs)
|
|
403
494
|
if limit_val is not None and not isinstance(limit_val, exprs.Literal):
|
|
404
|
-
raise excs.Error(f'limit(): parameter must be a constant
|
|
495
|
+
raise excs.Error(f'limit(): parameter must be a constant; got: {limit_val}')
|
|
405
496
|
|
|
406
|
-
return
|
|
497
|
+
return Query(
|
|
407
498
|
from_clause=self._from_clause,
|
|
408
499
|
select_list=select_list,
|
|
409
500
|
where_clause=where_clause,
|
|
@@ -431,41 +522,41 @@ class DataFrame:
|
|
|
431
522
|
raise excs.Error(msg) from e
|
|
432
523
|
|
|
433
524
|
def _output_row_iterator(self) -> Iterator[list]:
|
|
434
|
-
|
|
525
|
+
# TODO: extend begin_xact() to accept multiple TVPs for joins
|
|
526
|
+
single_tbl = self._first_tbl if len(self._from_clause.tbls) == 1 else None
|
|
527
|
+
with Catalog.get().begin_xact(tbl=single_tbl, for_write=False):
|
|
435
528
|
try:
|
|
436
529
|
for data_row in self._exec():
|
|
437
530
|
yield [data_row[e.slot_idx] for e in self._select_list_exprs]
|
|
438
531
|
except excs.ExprEvalError as e:
|
|
439
532
|
self._raise_expr_eval_err(e)
|
|
440
|
-
except
|
|
441
|
-
|
|
533
|
+
except (sql_exc.DBAPIError, sql_exc.OperationalError, sql_exc.InternalError) as e:
|
|
534
|
+
Catalog.get().convert_sql_exc(e, tbl=(single_tbl.tbl_version if single_tbl is not None else None))
|
|
535
|
+
raise # just re-raise if not converted to a Pixeltable error
|
|
442
536
|
|
|
443
|
-
def collect(self) ->
|
|
444
|
-
return
|
|
537
|
+
def collect(self) -> ResultSet:
|
|
538
|
+
return ResultSet(list(self._output_row_iterator()), self.schema)
|
|
445
539
|
|
|
446
|
-
async def _acollect(self) ->
|
|
540
|
+
async def _acollect(self) -> ResultSet:
|
|
541
|
+
single_tbl = self._first_tbl if len(self._from_clause.tbls) == 1 else None
|
|
447
542
|
try:
|
|
448
543
|
result = [[row[e.slot_idx] for e in self._select_list_exprs] async for row in self._aexec()]
|
|
449
|
-
return
|
|
544
|
+
return ResultSet(result, self.schema)
|
|
450
545
|
except excs.ExprEvalError as e:
|
|
451
546
|
self._raise_expr_eval_err(e)
|
|
452
|
-
except
|
|
453
|
-
|
|
547
|
+
except (sql_exc.DBAPIError, sql_exc.OperationalError, sql_exc.InternalError) as e:
|
|
548
|
+
Catalog.get().convert_sql_exc(e, tbl=(single_tbl.tbl_version if single_tbl is not None else None))
|
|
549
|
+
raise # just re-raise if not converted to a Pixeltable error
|
|
454
550
|
|
|
455
551
|
def count(self) -> int:
|
|
456
|
-
"""Return the number of rows in the
|
|
552
|
+
"""Return the number of rows in the Query.
|
|
457
553
|
|
|
458
554
|
Returns:
|
|
459
|
-
The number of rows in the
|
|
555
|
+
The number of rows in the Query.
|
|
460
556
|
"""
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
from pixeltable.plan import Planner
|
|
465
|
-
|
|
466
|
-
stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
|
|
467
|
-
with Env.get().begin_xact() as conn:
|
|
468
|
-
result: int = conn.execute(stmt).scalar_one()
|
|
557
|
+
with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False) as conn:
|
|
558
|
+
count_stmt = Planner.create_count_stmt(self)
|
|
559
|
+
result: int = conn.execute(count_stmt).scalar_one()
|
|
469
560
|
assert isinstance(result, int)
|
|
470
561
|
return result
|
|
471
562
|
|
|
@@ -510,12 +601,15 @@ class DataFrame:
|
|
|
510
601
|
if self.limit_val is not None:
|
|
511
602
|
heading_vals.append('Limit')
|
|
512
603
|
info_vals.append(self.limit_val.display_str(inline=False))
|
|
604
|
+
if self.sample_clause is not None:
|
|
605
|
+
heading_vals.append('Sample')
|
|
606
|
+
info_vals.append(self.sample_clause.display_str(inline=False))
|
|
513
607
|
assert len(heading_vals) == len(info_vals)
|
|
514
608
|
return pd.DataFrame(info_vals, index=heading_vals)
|
|
515
609
|
|
|
516
610
|
def describe(self) -> None:
|
|
517
611
|
"""
|
|
518
|
-
Prints a tabular description of this
|
|
612
|
+
Prints a tabular description of this Query.
|
|
519
613
|
The description has two columns, heading and info, which list the contents of each 'component'
|
|
520
614
|
(select list, where clause, ...) vertically.
|
|
521
615
|
"""
|
|
@@ -532,35 +626,35 @@ class DataFrame:
|
|
|
532
626
|
def _repr_html_(self) -> str:
|
|
533
627
|
return self._descriptors().to_html()
|
|
534
628
|
|
|
535
|
-
def select(self, *items: Any, **named_items: Any) ->
|
|
536
|
-
"""Select columns or expressions from the
|
|
629
|
+
def select(self, *items: Any, **named_items: Any) -> Query:
|
|
630
|
+
"""Select columns or expressions from the Query.
|
|
537
631
|
|
|
538
632
|
Args:
|
|
539
633
|
items: expressions to be selected
|
|
540
634
|
named_items: named expressions to be selected
|
|
541
635
|
|
|
542
636
|
Returns:
|
|
543
|
-
A new
|
|
637
|
+
A new Query with the specified select list.
|
|
544
638
|
|
|
545
639
|
Raises:
|
|
546
640
|
Error: If the select list is already specified,
|
|
547
641
|
or if any of the specified expressions are invalid,
|
|
548
|
-
or refer to tables not in the
|
|
642
|
+
or refer to tables not in the Query.
|
|
549
643
|
|
|
550
644
|
Examples:
|
|
551
|
-
Given the
|
|
645
|
+
Given the Query person from a table t with all its columns and rows:
|
|
552
646
|
|
|
553
647
|
>>> person = t.select()
|
|
554
648
|
|
|
555
|
-
Select the columns 'name' and 'age' (referenced in table t) from the
|
|
649
|
+
Select the columns 'name' and 'age' (referenced in table t) from the Query person:
|
|
556
650
|
|
|
557
|
-
>>>
|
|
651
|
+
>>> query = person.select(t.name, t.age)
|
|
558
652
|
|
|
559
|
-
Select the columns 'name' (referenced in table t) from the
|
|
653
|
+
Select the columns 'name' (referenced in table t) from the Query person,
|
|
560
654
|
and a named column 'is_adult' from the expression `age >= 18` where 'age' is
|
|
561
655
|
another column in table t:
|
|
562
656
|
|
|
563
|
-
>>>
|
|
657
|
+
>>> query = person.select(t.name, is_adult=(t.age >= 18))
|
|
564
658
|
|
|
565
659
|
"""
|
|
566
660
|
if self.select_list is not None:
|
|
@@ -573,7 +667,7 @@ class DataFrame:
|
|
|
573
667
|
return self
|
|
574
668
|
|
|
575
669
|
# analyze select list; wrap literals with the corresponding expressions
|
|
576
|
-
select_list: list[tuple[exprs.Expr,
|
|
670
|
+
select_list: list[tuple[exprs.Expr, str | None]] = []
|
|
577
671
|
for raw_expr, name in base_list:
|
|
578
672
|
expr = exprs.Expr.from_object(raw_expr)
|
|
579
673
|
if expr is None:
|
|
@@ -593,22 +687,22 @@ class DataFrame:
|
|
|
593
687
|
pass
|
|
594
688
|
if not expr.is_bound_by(self._from_clause.tbls):
|
|
595
689
|
raise excs.Error(
|
|
596
|
-
f"
|
|
597
|
-
f'({",".join(tbl.tbl_version.get().versioned_name for tbl in self._from_clause.tbls)})'
|
|
690
|
+
f"That expression cannot be evaluated in the context of this query's tables "
|
|
691
|
+
f'({",".join(tbl.tbl_version.get().versioned_name for tbl in self._from_clause.tbls)}): {expr}'
|
|
598
692
|
)
|
|
599
693
|
select_list.append((expr, name))
|
|
600
694
|
|
|
601
695
|
# check user provided names do not conflict among themselves or with auto-generated ones
|
|
602
696
|
seen: set[str] = set()
|
|
603
|
-
_, names =
|
|
697
|
+
_, names = Query._normalize_select_list(self._from_clause.tbls, select_list)
|
|
604
698
|
for name in names:
|
|
605
699
|
if name in seen:
|
|
606
700
|
repeated_names = [j for j, x in enumerate(names) if x == name]
|
|
607
701
|
pretty = ', '.join(map(str, repeated_names))
|
|
608
|
-
raise excs.Error(f'Repeated column name
|
|
702
|
+
raise excs.Error(f'Repeated column name {name!r} in select() at positions: {pretty}')
|
|
609
703
|
seen.add(name)
|
|
610
704
|
|
|
611
|
-
return
|
|
705
|
+
return Query(
|
|
612
706
|
from_clause=self._from_clause,
|
|
613
707
|
select_list=select_list,
|
|
614
708
|
where_clause=self.where_clause,
|
|
@@ -618,37 +712,39 @@ class DataFrame:
|
|
|
618
712
|
limit=self.limit_val,
|
|
619
713
|
)
|
|
620
714
|
|
|
621
|
-
def where(self, pred: exprs.Expr) ->
|
|
715
|
+
def where(self, pred: exprs.Expr) -> Query:
|
|
622
716
|
"""Filter rows based on a predicate.
|
|
623
717
|
|
|
624
718
|
Args:
|
|
625
719
|
pred: the predicate to filter rows
|
|
626
720
|
|
|
627
721
|
Returns:
|
|
628
|
-
A new
|
|
722
|
+
A new Query with the specified predicates replacing the where-clause.
|
|
629
723
|
|
|
630
724
|
Raises:
|
|
631
725
|
Error: If the predicate is not a Pixeltable expression,
|
|
632
726
|
or if it does not return a boolean value,
|
|
633
|
-
or refers to tables not in the
|
|
727
|
+
or refers to tables not in the Query.
|
|
634
728
|
|
|
635
729
|
Examples:
|
|
636
|
-
Given the
|
|
730
|
+
Given the Query person from a table t with all its columns and rows:
|
|
637
731
|
|
|
638
732
|
>>> person = t.select()
|
|
639
733
|
|
|
640
|
-
Filter the above
|
|
734
|
+
Filter the above Query person to only include rows where the column 'age'
|
|
641
735
|
(referenced in table t) is greater than 30:
|
|
642
736
|
|
|
643
|
-
>>>
|
|
737
|
+
>>> query = person.where(t.age > 30)
|
|
644
738
|
"""
|
|
645
739
|
if self.where_clause is not None:
|
|
646
|
-
raise excs.Error('
|
|
740
|
+
raise excs.Error('where() clause already specified')
|
|
741
|
+
if self.sample_clause is not None:
|
|
742
|
+
raise excs.Error('where() cannot be used after sample()')
|
|
647
743
|
if not isinstance(pred, exprs.Expr):
|
|
648
|
-
raise excs.Error(f'
|
|
744
|
+
raise excs.Error(f'where() expects a Pixeltable expression; got: {pred}')
|
|
649
745
|
if not pred.col_type.is_bool_type():
|
|
650
|
-
raise excs.Error(f'
|
|
651
|
-
return
|
|
746
|
+
raise excs.Error(f'where() expression needs to return `Bool`, but instead returns `{pred.col_type}`')
|
|
747
|
+
return Query(
|
|
652
748
|
from_clause=self._from_clause,
|
|
653
749
|
select_list=self.select_list,
|
|
654
750
|
where_clause=pred,
|
|
@@ -659,7 +755,7 @@ class DataFrame:
|
|
|
659
755
|
)
|
|
660
756
|
|
|
661
757
|
def _create_join_predicate(
|
|
662
|
-
self, other: catalog.TableVersionPath, on:
|
|
758
|
+
self, other: catalog.TableVersionPath, on: exprs.Expr | Sequence[exprs.ColumnRef]
|
|
663
759
|
) -> exprs.Expr:
|
|
664
760
|
"""Verifies user-specified 'on' argument and converts it into a join predicate."""
|
|
665
761
|
col_refs: list[exprs.ColumnRef] = []
|
|
@@ -669,19 +765,21 @@ class DataFrame:
|
|
|
669
765
|
on = [on]
|
|
670
766
|
elif isinstance(on, exprs.Expr):
|
|
671
767
|
if not on.is_bound_by(joined_tbls):
|
|
672
|
-
raise excs.Error(f
|
|
768
|
+
raise excs.Error(f'`on` expression cannot be evaluated in the context of the joined tables: {on}')
|
|
673
769
|
if not on.col_type.is_bool_type():
|
|
674
|
-
raise excs.Error(
|
|
770
|
+
raise excs.Error(
|
|
771
|
+
f'`on` expects an expression of type `Bool`, but got one of type `{on.col_type}`: {on}'
|
|
772
|
+
)
|
|
675
773
|
return on
|
|
676
774
|
elif not isinstance(on, Sequence) or len(on) == 0:
|
|
677
|
-
raise excs.Error(
|
|
775
|
+
raise excs.Error('`on` must be a sequence of column references or a boolean expression')
|
|
678
776
|
|
|
679
777
|
assert isinstance(on, Sequence)
|
|
680
778
|
for col_ref in on:
|
|
681
779
|
if not isinstance(col_ref, exprs.ColumnRef):
|
|
682
|
-
raise excs.Error(
|
|
780
|
+
raise excs.Error('`on` must be a sequence of column references or a boolean expression')
|
|
683
781
|
if not col_ref.is_bound_by(joined_tbls):
|
|
684
|
-
raise excs.Error(f
|
|
782
|
+
raise excs.Error(f'`on` expression cannot be evaluated in the context of the joined tables: {col_ref}')
|
|
685
783
|
col_refs.append(col_ref)
|
|
686
784
|
|
|
687
785
|
predicates: list[exprs.Expr] = []
|
|
@@ -689,27 +787,27 @@ class DataFrame:
|
|
|
689
787
|
assert len(col_refs) > 0 and len(joined_tbls) >= 2
|
|
690
788
|
for col_ref in col_refs:
|
|
691
789
|
# identify the referenced column by name in 'other'
|
|
692
|
-
rhs_col = other.get_column(col_ref.col.name
|
|
790
|
+
rhs_col = other.get_column(col_ref.col.name)
|
|
693
791
|
if rhs_col is None:
|
|
694
|
-
raise excs.Error(f
|
|
792
|
+
raise excs.Error(f'`on` column {col_ref.col.name!r} not found in joined table')
|
|
695
793
|
rhs_col_ref = exprs.ColumnRef(rhs_col)
|
|
696
794
|
|
|
697
|
-
lhs_col_ref:
|
|
698
|
-
if any(tbl.has_column(col_ref.col
|
|
795
|
+
lhs_col_ref: exprs.ColumnRef | None = None
|
|
796
|
+
if any(tbl.has_column(col_ref.col) for tbl in self._from_clause.tbls):
|
|
699
797
|
# col_ref comes from the existing from_clause, we use that directly
|
|
700
798
|
lhs_col_ref = col_ref
|
|
701
799
|
else:
|
|
702
800
|
# col_ref comes from other, we need to look for a match in the existing from_clause by name
|
|
703
801
|
for tbl in self._from_clause.tbls:
|
|
704
|
-
col = tbl.get_column(col_ref.col.name
|
|
802
|
+
col = tbl.get_column(col_ref.col.name)
|
|
705
803
|
if col is None:
|
|
706
804
|
continue
|
|
707
805
|
if lhs_col_ref is not None:
|
|
708
|
-
raise excs.Error(f
|
|
806
|
+
raise excs.Error(f'`on`: ambiguous column reference: {col_ref.col.name}')
|
|
709
807
|
lhs_col_ref = exprs.ColumnRef(col)
|
|
710
808
|
if lhs_col_ref is None:
|
|
711
809
|
tbl_names = [tbl.tbl_name() for tbl in self._from_clause.tbls]
|
|
712
|
-
raise excs.Error(f
|
|
810
|
+
raise excs.Error(f'`on`: column {col_ref.col.name!r} not found in any of: {" ".join(tbl_names)}')
|
|
713
811
|
pred = exprs.Comparison(exprs.ComparisonOperator.EQ, lhs_col_ref, rhs_col_ref)
|
|
714
812
|
predicates.append(pred)
|
|
715
813
|
|
|
@@ -722,11 +820,11 @@ class DataFrame:
|
|
|
722
820
|
def join(
|
|
723
821
|
self,
|
|
724
822
|
other: catalog.Table,
|
|
725
|
-
on:
|
|
823
|
+
on: exprs.Expr | Sequence[exprs.ColumnRef] | None = None,
|
|
726
824
|
how: plan.JoinType.LiteralType = 'inner',
|
|
727
|
-
) ->
|
|
825
|
+
) -> Query:
|
|
728
826
|
"""
|
|
729
|
-
Join this
|
|
827
|
+
Join this Query with a table.
|
|
730
828
|
|
|
731
829
|
Args:
|
|
732
830
|
other: the table to join with
|
|
@@ -734,23 +832,23 @@ class DataFrame:
|
|
|
734
832
|
expression.
|
|
735
833
|
|
|
736
834
|
- column references: implies an equality predicate that matches columns in both this
|
|
737
|
-
|
|
835
|
+
Query and `other` by name.
|
|
738
836
|
|
|
739
|
-
- column in `other`: A column with that same name must be present in this
|
|
837
|
+
- column in `other`: A column with that same name must be present in this Query, and **it must
|
|
740
838
|
be unique** (otherwise the join is ambiguous).
|
|
741
|
-
- column in this
|
|
839
|
+
- column in this Query: A column with that same name must be present in `other`.
|
|
742
840
|
|
|
743
841
|
- boolean expression: The expressions must be valid in the context of the joined tables.
|
|
744
842
|
how: the type of join to perform.
|
|
745
843
|
|
|
746
844
|
- `'inner'`: only keep rows that have a match in both
|
|
747
|
-
- `'left'`: keep all rows from this
|
|
748
|
-
- `'right'`: keep all rows from the other table and only matching rows from this
|
|
749
|
-
- `'full_outer'`: keep all rows from both this
|
|
845
|
+
- `'left'`: keep all rows from this Query and only matching rows from the other table
|
|
846
|
+
- `'right'`: keep all rows from the other table and only matching rows from this Query
|
|
847
|
+
- `'full_outer'`: keep all rows from both this Query and the other table
|
|
750
848
|
- `'cross'`: Cartesian product; no `on` condition allowed
|
|
751
849
|
|
|
752
850
|
Returns:
|
|
753
|
-
A new
|
|
851
|
+
A new Query.
|
|
754
852
|
|
|
755
853
|
Examples:
|
|
756
854
|
Perform an inner join between t1 and t2 on the column id:
|
|
@@ -769,23 +867,25 @@ class DataFrame:
|
|
|
769
867
|
Join t with d, which has a composite primary key (columns pk1 and pk2, with corresponding foreign
|
|
770
868
|
key columns d1 and d2 in t):
|
|
771
869
|
|
|
772
|
-
>>>
|
|
870
|
+
>>> query = t.join(d, on=(t.d1 == d.pk1) & (t.d2 == d.pk2), how='left')
|
|
773
871
|
"""
|
|
774
|
-
|
|
872
|
+
if self.sample_clause is not None:
|
|
873
|
+
raise excs.Error('join() cannot be used with sample()')
|
|
874
|
+
join_pred: exprs.Expr | None
|
|
775
875
|
if how == 'cross':
|
|
776
876
|
if on is not None:
|
|
777
|
-
raise excs.Error(
|
|
877
|
+
raise excs.Error('`on` not allowed for cross join')
|
|
778
878
|
join_pred = None
|
|
779
879
|
else:
|
|
780
880
|
if on is None:
|
|
781
|
-
raise excs.Error(f
|
|
881
|
+
raise excs.Error(f'`how={how!r}` requires `on` to be present')
|
|
782
882
|
join_pred = self._create_join_predicate(other._tbl_version_path, on)
|
|
783
|
-
join_clause = plan.JoinClause(join_type=plan.JoinType.validated(how,
|
|
883
|
+
join_clause = plan.JoinClause(join_type=plan.JoinType.validated(how, '`how`'), join_predicate=join_pred)
|
|
784
884
|
from_clause = plan.FromClause(
|
|
785
885
|
tbls=[*self._from_clause.tbls, other._tbl_version_path],
|
|
786
886
|
join_clauses=[*self._from_clause.join_clauses, join_clause],
|
|
787
887
|
)
|
|
788
|
-
return
|
|
888
|
+
return Query(
|
|
789
889
|
from_clause=from_clause,
|
|
790
890
|
select_list=self.select_list,
|
|
791
891
|
where_clause=self.where_clause,
|
|
@@ -795,70 +895,73 @@ class DataFrame:
|
|
|
795
895
|
limit=self.limit_val,
|
|
796
896
|
)
|
|
797
897
|
|
|
798
|
-
def group_by(self, *grouping_items: Any) ->
|
|
799
|
-
"""Add a group-by clause to this
|
|
898
|
+
def group_by(self, *grouping_items: Any) -> Query:
|
|
899
|
+
"""Add a group-by clause to this Query.
|
|
800
900
|
|
|
801
901
|
Variants:
|
|
802
|
-
- group_by(
|
|
803
|
-
- group_by(
|
|
902
|
+
- group_by(base_tbl): group a component view by their respective base table rows
|
|
903
|
+
- group_by(expr1, expr2, expr3): group by the given expressions
|
|
804
904
|
|
|
805
|
-
Note
|
|
905
|
+
Note that grouping will be applied to the rows and take effect when
|
|
806
906
|
used with an aggregation function like sum(), count() etc.
|
|
807
907
|
|
|
808
908
|
Args:
|
|
809
909
|
grouping_items: expressions to group by
|
|
810
910
|
|
|
811
911
|
Returns:
|
|
812
|
-
A new
|
|
912
|
+
A new Query with the specified group-by clause.
|
|
813
913
|
|
|
814
914
|
Raises:
|
|
815
915
|
Error: If the group-by clause is already specified,
|
|
816
916
|
or if the specified expression is invalid,
|
|
817
|
-
or refer to tables not in the
|
|
818
|
-
or if the
|
|
917
|
+
or refer to tables not in the Query,
|
|
918
|
+
or if the Query is a result of a join.
|
|
819
919
|
|
|
820
920
|
Examples:
|
|
821
|
-
Given the
|
|
921
|
+
Given the Query book from a table t with all its columns and rows:
|
|
822
922
|
|
|
823
923
|
>>> book = t.select()
|
|
824
924
|
|
|
825
|
-
Group the above
|
|
925
|
+
Group the above Query book by the 'genre' column (referenced in table t):
|
|
826
926
|
|
|
827
|
-
>>>
|
|
927
|
+
>>> query = book.group_by(t.genre)
|
|
828
928
|
|
|
829
|
-
Use the above
|
|
929
|
+
Use the above Query grouped by genre to count the number of
|
|
830
930
|
books for each 'genre':
|
|
831
931
|
|
|
832
|
-
>>>
|
|
932
|
+
>>> query = book.group_by(t.genre).select(t.genre, count=count(t.genre)).show()
|
|
833
933
|
|
|
834
|
-
Use the above
|
|
934
|
+
Use the above Query grouped by genre to the total price of
|
|
835
935
|
books for each 'genre':
|
|
836
936
|
|
|
837
|
-
>>>
|
|
937
|
+
>>> query = book.group_by(t.genre).select(t.genre, total=sum(t.price)).show()
|
|
838
938
|
"""
|
|
839
939
|
if self.group_by_clause is not None:
|
|
840
|
-
raise excs.Error('
|
|
841
|
-
|
|
842
|
-
|
|
940
|
+
raise excs.Error('group_by() already specified')
|
|
941
|
+
if self.sample_clause is not None:
|
|
942
|
+
raise excs.Error('group_by() cannot be used with sample()')
|
|
943
|
+
|
|
944
|
+
grouping_tbl: catalog.TableVersion | None = None
|
|
945
|
+
group_by_clause: list[exprs.Expr] | None = None
|
|
843
946
|
for item in grouping_items:
|
|
844
947
|
if isinstance(item, (catalog.Table, catalog.TableVersion)):
|
|
845
948
|
if len(grouping_items) > 1:
|
|
846
|
-
raise excs.Error('group_by(): only one
|
|
949
|
+
raise excs.Error('group_by(): only one Table can be specified')
|
|
847
950
|
if len(self._from_clause.tbls) > 1:
|
|
848
951
|
raise excs.Error('group_by() with Table not supported for joins')
|
|
849
952
|
grouping_tbl = item if isinstance(item, catalog.TableVersion) else item._tbl_version.get()
|
|
850
953
|
# we need to make sure that the grouping table is a base of self.tbl
|
|
851
954
|
base = self._first_tbl.find_tbl_version(grouping_tbl.id)
|
|
852
|
-
if base is None or base.id == self._first_tbl.tbl_id
|
|
955
|
+
if base is None or base.id == self._first_tbl.tbl_id:
|
|
853
956
|
raise excs.Error(
|
|
854
|
-
f'group_by(): {grouping_tbl.name} is not a base table of {self._first_tbl.tbl_name()}'
|
|
957
|
+
f'group_by(): {grouping_tbl.name!r} is not a base table of {self._first_tbl.tbl_name()!r}'
|
|
855
958
|
)
|
|
856
959
|
break
|
|
857
960
|
if not isinstance(item, exprs.Expr):
|
|
858
961
|
raise excs.Error(f'Invalid expression in group_by(): {item}')
|
|
859
962
|
if grouping_tbl is None:
|
|
860
963
|
group_by_clause = list(grouping_items)
|
|
861
|
-
return
|
|
964
|
+
return Query(
|
|
862
965
|
from_clause=self._from_clause,
|
|
863
966
|
select_list=self.select_list,
|
|
864
967
|
where_clause=self.where_clause,
|
|
@@ -868,11 +971,11 @@ class DataFrame:
|
|
|
868
971
|
limit=self.limit_val,
|
|
869
972
|
)
|
|
870
973
|
|
|
871
|
-
def distinct(self) ->
|
|
974
|
+
def distinct(self) -> Query:
|
|
872
975
|
"""
|
|
873
|
-
Remove duplicate rows from this
|
|
976
|
+
Remove duplicate rows from this Query.
|
|
874
977
|
|
|
875
|
-
Note that grouping will be applied to the rows based on the select clause of this
|
|
978
|
+
Note that grouping will be applied to the rows based on the select clause of this Query.
|
|
876
979
|
In the absence of a select clause, by default, all columns are selected in the grouping.
|
|
877
980
|
|
|
878
981
|
Examples:
|
|
@@ -891,8 +994,8 @@ class DataFrame:
|
|
|
891
994
|
exps, _ = self._normalize_select_list(self._from_clause.tbls, self.select_list)
|
|
892
995
|
return self.group_by(*exps)
|
|
893
996
|
|
|
894
|
-
def order_by(self, *expr_list: exprs.Expr, asc: bool = True) ->
|
|
895
|
-
"""Add an order-by clause to this
|
|
997
|
+
def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> Query:
|
|
998
|
+
"""Add an order-by clause to this Query.
|
|
896
999
|
|
|
897
1000
|
Args:
|
|
898
1001
|
expr_list: expressions to order by
|
|
@@ -900,33 +1003,35 @@ class DataFrame:
|
|
|
900
1003
|
Default is True.
|
|
901
1004
|
|
|
902
1005
|
Returns:
|
|
903
|
-
A new
|
|
1006
|
+
A new Query with the specified order-by clause.
|
|
904
1007
|
|
|
905
1008
|
Raises:
|
|
906
1009
|
Error: If the order-by clause is already specified,
|
|
907
1010
|
or if the specified expression is invalid,
|
|
908
|
-
or refer to tables not in the
|
|
1011
|
+
or refer to tables not in the Query.
|
|
909
1012
|
|
|
910
1013
|
Examples:
|
|
911
|
-
Given the
|
|
1014
|
+
Given the Query book from a table t with all its columns and rows:
|
|
912
1015
|
|
|
913
1016
|
>>> book = t.select()
|
|
914
1017
|
|
|
915
|
-
Order the above
|
|
1018
|
+
Order the above Query book by two columns (price, pages) in descending order:
|
|
916
1019
|
|
|
917
|
-
>>>
|
|
1020
|
+
>>> query = book.order_by(t.price, t.pages, asc=False)
|
|
918
1021
|
|
|
919
|
-
Order the above
|
|
1022
|
+
Order the above Query book by price in descending order, but order the pages
|
|
920
1023
|
in ascending order:
|
|
921
1024
|
|
|
922
|
-
>>>
|
|
1025
|
+
>>> query = book.order_by(t.price, asc=False).order_by(t.pages)
|
|
923
1026
|
"""
|
|
1027
|
+
if self.sample_clause is not None:
|
|
1028
|
+
raise excs.Error('order_by() cannot be used with sample()')
|
|
924
1029
|
for e in expr_list:
|
|
925
1030
|
if not isinstance(e, exprs.Expr):
|
|
926
1031
|
raise excs.Error(f'Invalid expression in order_by(): {e}')
|
|
927
1032
|
order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
|
|
928
1033
|
order_by_clause.extend([(e.copy(), asc) for e in expr_list])
|
|
929
|
-
return
|
|
1034
|
+
return Query(
|
|
930
1035
|
from_clause=self._from_clause,
|
|
931
1036
|
select_list=self.select_list,
|
|
932
1037
|
where_clause=self.where_clause,
|
|
@@ -936,31 +1041,148 @@ class DataFrame:
|
|
|
936
1041
|
limit=self.limit_val,
|
|
937
1042
|
)
|
|
938
1043
|
|
|
939
|
-
def limit(self, n: int) ->
|
|
940
|
-
"""Limit the number of rows in the
|
|
1044
|
+
def limit(self, n: int) -> Query:
|
|
1045
|
+
"""Limit the number of rows in the Query.
|
|
941
1046
|
|
|
942
1047
|
Args:
|
|
943
1048
|
n: Number of rows to select.
|
|
944
1049
|
|
|
945
1050
|
Returns:
|
|
946
|
-
A new
|
|
1051
|
+
A new Query with the specified limited rows.
|
|
947
1052
|
"""
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
return
|
|
1053
|
+
if self.sample_clause is not None:
|
|
1054
|
+
raise excs.Error('limit() cannot be used with sample()')
|
|
1055
|
+
|
|
1056
|
+
limit_expr = self._convert_param_to_typed_expr(n, ts.IntType(nullable=False), True, 'limit()')
|
|
1057
|
+
return Query(
|
|
1058
|
+
from_clause=self._from_clause,
|
|
1059
|
+
select_list=self.select_list,
|
|
1060
|
+
where_clause=self.where_clause,
|
|
1061
|
+
group_by_clause=self.group_by_clause,
|
|
1062
|
+
grouping_tbl=self.grouping_tbl,
|
|
1063
|
+
order_by_clause=self.order_by_clause,
|
|
1064
|
+
limit=limit_expr,
|
|
1065
|
+
)
|
|
1066
|
+
|
|
1067
|
+
def sample(
|
|
1068
|
+
self,
|
|
1069
|
+
n: int | None = None,
|
|
1070
|
+
n_per_stratum: int | None = None,
|
|
1071
|
+
fraction: float | None = None,
|
|
1072
|
+
seed: int | None = None,
|
|
1073
|
+
stratify_by: Any = None,
|
|
1074
|
+
) -> Query:
|
|
1075
|
+
"""
|
|
1076
|
+
Return a new Query specifying a sample of rows from the Query, considered in a shuffled order.
|
|
1077
|
+
|
|
1078
|
+
The size of the sample can be specified in three ways:
|
|
1079
|
+
|
|
1080
|
+
- `n`: the total number of rows to produce as a sample
|
|
1081
|
+
- `n_per_stratum`: the number of rows to produce per stratum as a sample
|
|
1082
|
+
- `fraction`: the fraction of available rows to produce as a sample
|
|
1083
|
+
|
|
1084
|
+
The sample can be stratified by one or more columns, which means that the sample will
|
|
1085
|
+
be selected from each stratum separately.
|
|
1086
|
+
|
|
1087
|
+
The data is shuffled before creating the sample.
|
|
1088
|
+
|
|
1089
|
+
Args:
|
|
1090
|
+
n: Total number of rows to produce as a sample.
|
|
1091
|
+
n_per_stratum: Number of rows to produce per stratum as a sample. This parameter is only valid if
|
|
1092
|
+
`stratify_by` is specified. Only one of `n` or `n_per_stratum` can be specified.
|
|
1093
|
+
fraction: Fraction of available rows to produce as a sample. This parameter is not usable with `n` or
|
|
1094
|
+
`n_per_stratum`. The fraction must be between 0.0 and 1.0.
|
|
1095
|
+
seed: Random seed for reproducible shuffling
|
|
1096
|
+
stratify_by: If specified, the sample will be stratified by these values.
|
|
1097
|
+
|
|
1098
|
+
Returns:
|
|
1099
|
+
A new Query which specifies the sampled rows
|
|
1100
|
+
|
|
1101
|
+
Examples:
|
|
1102
|
+
Given the Table `person` containing the field 'age', we can create samples of the table in various ways:
|
|
1103
|
+
|
|
1104
|
+
Sample 100 rows from the above Table:
|
|
1105
|
+
|
|
1106
|
+
>>> query = person.sample(n=100)
|
|
1107
|
+
|
|
1108
|
+
Sample 10% of the rows from the above Table:
|
|
1109
|
+
|
|
1110
|
+
>>> query = person.sample(fraction=0.1)
|
|
1111
|
+
|
|
1112
|
+
Sample 10% of the rows from the above Table, stratified by the column 'age':
|
|
1113
|
+
|
|
1114
|
+
>>> query = person.sample(fraction=0.1, stratify_by=t.age)
|
|
1115
|
+
|
|
1116
|
+
Equal allocation sampling: Sample 2 rows from each age present in the above Table:
|
|
1117
|
+
|
|
1118
|
+
>>> query = person.sample(n_per_stratum=2, stratify_by=t.age)
|
|
1119
|
+
|
|
1120
|
+
Sampling is compatible with the where clause, so we can also sample from a filtered Query:
|
|
1121
|
+
|
|
1122
|
+
>>> query = person.where(t.age > 30).sample(n=100)
|
|
1123
|
+
"""
|
|
1124
|
+
# Check context of usage
|
|
1125
|
+
if self.sample_clause is not None:
|
|
1126
|
+
raise excs.Error('Multiple sample() clauses not allowed')
|
|
1127
|
+
if self.group_by_clause is not None:
|
|
1128
|
+
raise excs.Error('sample() cannot be used with group_by()')
|
|
1129
|
+
if self.order_by_clause is not None:
|
|
1130
|
+
raise excs.Error('sample() cannot be used with order_by()')
|
|
1131
|
+
if self.limit_val is not None:
|
|
1132
|
+
raise excs.Error('sample() cannot be used with limit()')
|
|
1133
|
+
if self._has_joins():
|
|
1134
|
+
raise excs.Error('sample() cannot be used with join()')
|
|
1135
|
+
|
|
1136
|
+
# Check paramter combinations
|
|
1137
|
+
if (n is not None) + (n_per_stratum is not None) + (fraction is not None) != 1:
|
|
1138
|
+
raise excs.Error('Exactly one of `n`, `n_per_stratum`, or `fraction` must be specified.')
|
|
1139
|
+
if n_per_stratum is not None and stratify_by is None:
|
|
1140
|
+
raise excs.Error('Must specify `stratify_by` to use `n_per_stratum`')
|
|
1141
|
+
|
|
1142
|
+
# Check parameter types and values
|
|
1143
|
+
n = self.validate_constant_type_range(n, ts.IntType(nullable=False), False, 'n', (1, None))
|
|
1144
|
+
n_per_stratum = self.validate_constant_type_range(
|
|
1145
|
+
n_per_stratum, ts.IntType(nullable=False), False, 'n_per_stratum', (1, None)
|
|
1146
|
+
)
|
|
1147
|
+
fraction = self.validate_constant_type_range(
|
|
1148
|
+
fraction, ts.FloatType(nullable=False), False, 'fraction', (0.0, 1.0)
|
|
1149
|
+
)
|
|
1150
|
+
seed = self.validate_constant_type_range(seed, ts.IntType(nullable=False), False, 'seed')
|
|
1151
|
+
|
|
1152
|
+
# analyze stratify list
|
|
1153
|
+
stratify_exprs: list[exprs.Expr] = []
|
|
1154
|
+
if stratify_by is not None:
|
|
1155
|
+
if isinstance(stratify_by, exprs.Expr):
|
|
1156
|
+
stratify_by = [stratify_by]
|
|
1157
|
+
if not isinstance(stratify_by, (list, tuple)):
|
|
1158
|
+
raise excs.Error('`stratify_by` must be a list of scalar expressions')
|
|
1159
|
+
for expr in stratify_by:
|
|
1160
|
+
if expr is None or not isinstance(expr, exprs.Expr):
|
|
1161
|
+
raise excs.Error(f'Invalid expression: {expr}')
|
|
1162
|
+
if not expr.col_type.is_scalar_type():
|
|
1163
|
+
raise excs.Error(f'Invalid type: expression must be a scalar type (not `{expr.col_type}`)')
|
|
1164
|
+
if not expr.is_bound_by(self._from_clause.tbls):
|
|
1165
|
+
raise excs.Error(
|
|
1166
|
+
f"That expression cannot be evaluated in the context of this query's tables "
|
|
1167
|
+
f'({",".join(tbl.tbl_name() for tbl in self._from_clause.tbls)}): {expr}'
|
|
1168
|
+
)
|
|
1169
|
+
stratify_exprs.append(expr)
|
|
1170
|
+
|
|
1171
|
+
sample_clause = SampleClause(None, n, n_per_stratum, fraction, seed, stratify_exprs)
|
|
1172
|
+
|
|
1173
|
+
return Query(
|
|
953
1174
|
from_clause=self._from_clause,
|
|
954
1175
|
select_list=self.select_list,
|
|
955
1176
|
where_clause=self.where_clause,
|
|
956
1177
|
group_by_clause=self.group_by_clause,
|
|
957
1178
|
grouping_tbl=self.grouping_tbl,
|
|
958
1179
|
order_by_clause=self.order_by_clause,
|
|
959
|
-
limit=
|
|
1180
|
+
limit=self.limit_val,
|
|
1181
|
+
sample_clause=sample_clause,
|
|
960
1182
|
)
|
|
961
1183
|
|
|
962
1184
|
def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
|
|
963
|
-
"""Update rows in the underlying table of the
|
|
1185
|
+
"""Update rows in the underlying table of the Query.
|
|
964
1186
|
|
|
965
1187
|
Update rows in the table with the specified value_spec.
|
|
966
1188
|
|
|
@@ -973,70 +1195,105 @@ class DataFrame:
|
|
|
973
1195
|
UpdateStatus: the status of the update operation.
|
|
974
1196
|
|
|
975
1197
|
Example:
|
|
976
|
-
Given the
|
|
1198
|
+
Given the Query person from a table t with all its columns and rows:
|
|
977
1199
|
|
|
978
1200
|
>>> person = t.select()
|
|
979
1201
|
|
|
980
|
-
Via the above
|
|
1202
|
+
Via the above Query person, update the column 'city' to 'Oakland'
|
|
981
1203
|
and 'state' to 'CA' in the table t:
|
|
982
1204
|
|
|
983
|
-
>>>
|
|
1205
|
+
>>> person.update({'city': 'Oakland', 'state': 'CA'})
|
|
984
1206
|
|
|
985
|
-
Via the above
|
|
1207
|
+
Via the above Query person, update the column 'age' to 30 for any
|
|
986
1208
|
rows where 'year' is 2014 in the table t:
|
|
987
1209
|
|
|
988
|
-
>>>
|
|
1210
|
+
>>> person.where(t.year == 2014).update({'age': 30})
|
|
989
1211
|
"""
|
|
990
1212
|
self._validate_mutable('update', False)
|
|
991
|
-
with
|
|
1213
|
+
with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
|
|
992
1214
|
return self._first_tbl.tbl_version.get().update(value_spec, where=self.where_clause, cascade=cascade)
|
|
993
1215
|
|
|
994
|
-
def
|
|
995
|
-
|
|
1216
|
+
def recompute_columns(
|
|
1217
|
+
self, *columns: str | exprs.ColumnRef, errors_only: bool = False, cascade: bool = True
|
|
1218
|
+
) -> UpdateStatus:
|
|
1219
|
+
"""Recompute one or more computed columns of the underlying table of the Query.
|
|
996
1220
|
|
|
997
|
-
|
|
1221
|
+
Args:
|
|
1222
|
+
columns: The names or references of the computed columns to recompute.
|
|
1223
|
+
errors_only: If True, only run the recomputation for rows that have errors in the column (ie, the column's
|
|
1224
|
+
`errortype` property indicates that an error occurred). Only allowed for recomputing a single column.
|
|
1225
|
+
cascade: if True, also update all computed columns that transitively depend on the recomputed columns.
|
|
998
1226
|
|
|
999
1227
|
Returns:
|
|
1000
|
-
UpdateStatus: the status of the
|
|
1228
|
+
UpdateStatus: the status of the operation.
|
|
1001
1229
|
|
|
1002
1230
|
Example:
|
|
1003
|
-
|
|
1231
|
+
For table `person` with column `age` and computed column `height`, recompute the value of `height` for all
|
|
1232
|
+
rows where `age` is less than 18:
|
|
1004
1233
|
|
|
1005
|
-
>>>
|
|
1234
|
+
>>> query = person.where(t.age < 18).recompute_columns(person.height)
|
|
1235
|
+
"""
|
|
1236
|
+
self._validate_mutable('recompute_columns', False)
|
|
1237
|
+
with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
|
|
1238
|
+
tbl = Catalog.get().get_table_by_id(self._first_tbl.tbl_id)
|
|
1239
|
+
return tbl.recompute_columns(*columns, where=self.where_clause, errors_only=errors_only, cascade=cascade)
|
|
1006
1240
|
|
|
1007
|
-
|
|
1241
|
+
def delete(self) -> UpdateStatus:
|
|
1242
|
+
"""Delete rows form the underlying table of the Query.
|
|
1243
|
+
|
|
1244
|
+
The delete operation is only allowed for Queries on base tables.
|
|
1245
|
+
|
|
1246
|
+
Returns:
|
|
1247
|
+
UpdateStatus: the status of the delete operation.
|
|
1008
1248
|
|
|
1009
|
-
|
|
1249
|
+
Example:
|
|
1250
|
+
For a table `person` with column `age`, delete all rows where 'age' is less than 18:
|
|
1251
|
+
|
|
1252
|
+
>>> person.where(t.age < 18).delete()
|
|
1010
1253
|
"""
|
|
1011
1254
|
self._validate_mutable('delete', False)
|
|
1012
1255
|
if not self._first_tbl.is_insertable():
|
|
1013
|
-
raise excs.Error('Cannot delete
|
|
1014
|
-
with
|
|
1256
|
+
raise excs.Error('Cannot use `delete` on a view.')
|
|
1257
|
+
with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True):
|
|
1015
1258
|
return self._first_tbl.tbl_version.get().delete(where=self.where_clause)
|
|
1016
1259
|
|
|
1017
1260
|
def _validate_mutable(self, op_name: str, allow_select: bool) -> None:
|
|
1018
|
-
"""Tests whether this
|
|
1261
|
+
"""Tests whether this Query can be mutated (such as by an update operation).
|
|
1019
1262
|
|
|
1020
1263
|
Args:
|
|
1021
1264
|
op_name: The name of the operation for which the test is being performed.
|
|
1022
|
-
allow_select: If True, allow a select() specification in the
|
|
1265
|
+
allow_select: If True, allow a select() specification in the Query.
|
|
1023
1266
|
"""
|
|
1267
|
+
self._validate_mutable_op_sequence(op_name, allow_select)
|
|
1268
|
+
|
|
1269
|
+
# TODO: Reconcile these with Table.__check_mutable()
|
|
1270
|
+
assert len(self._from_clause.tbls) == 1
|
|
1271
|
+
# First check if it's a replica, since every replica handle is also a snapshot
|
|
1272
|
+
if self._first_tbl.is_replica():
|
|
1273
|
+
raise excs.Error(f'Cannot use `{op_name}` on a replica.')
|
|
1274
|
+
if self._first_tbl.is_snapshot():
|
|
1275
|
+
raise excs.Error(f'Cannot use `{op_name}` on a snapshot.')
|
|
1276
|
+
|
|
1277
|
+
def _validate_mutable_op_sequence(self, op_name: str, allow_select: bool) -> None:
|
|
1278
|
+
"""Tests whether the sequence of operations on this Query is valid for a mutation operation."""
|
|
1024
1279
|
if self.group_by_clause is not None or self.grouping_tbl is not None:
|
|
1025
|
-
raise excs.Error(f'Cannot use `{op_name}` after `group_by
|
|
1280
|
+
raise excs.Error(f'Cannot use `{op_name}` after `group_by`.')
|
|
1026
1281
|
if self.order_by_clause is not None:
|
|
1027
|
-
raise excs.Error(f'Cannot use `{op_name}` after `order_by
|
|
1282
|
+
raise excs.Error(f'Cannot use `{op_name}` after `order_by`.')
|
|
1028
1283
|
if self.select_list is not None and not allow_select:
|
|
1029
|
-
raise excs.Error(f'Cannot use `{op_name}` after `select
|
|
1284
|
+
raise excs.Error(f'Cannot use `{op_name}` after `select`.')
|
|
1030
1285
|
if self.limit_val is not None:
|
|
1031
|
-
raise excs.Error(f'Cannot use `{op_name}` after `limit
|
|
1286
|
+
raise excs.Error(f'Cannot use `{op_name}` after `limit`.')
|
|
1287
|
+
if self._has_joins():
|
|
1288
|
+
raise excs.Error(f'Cannot use `{op_name}` after `join`.')
|
|
1032
1289
|
|
|
1033
1290
|
def as_dict(self) -> dict[str, Any]:
|
|
1034
1291
|
"""
|
|
1035
1292
|
Returns:
|
|
1036
|
-
Dictionary representing this
|
|
1293
|
+
Dictionary representing this Query.
|
|
1037
1294
|
"""
|
|
1038
1295
|
d = {
|
|
1039
|
-
'_classname': '
|
|
1296
|
+
'_classname': 'Query',
|
|
1040
1297
|
'from_clause': {
|
|
1041
1298
|
'tbls': [tbl.as_dict() for tbl in self._from_clause.tbls],
|
|
1042
1299
|
'join_clauses': [dataclasses.asdict(clause) for clause in self._from_clause.join_clauses],
|
|
@@ -1053,13 +1310,14 @@ class DataFrame:
|
|
|
1053
1310
|
if self.order_by_clause is not None
|
|
1054
1311
|
else None,
|
|
1055
1312
|
'limit_val': self.limit_val.as_dict() if self.limit_val is not None else None,
|
|
1313
|
+
'sample_clause': self.sample_clause.as_dict() if self.sample_clause is not None else None,
|
|
1056
1314
|
}
|
|
1057
1315
|
return d
|
|
1058
1316
|
|
|
1059
1317
|
@classmethod
|
|
1060
|
-
def from_dict(cls, d: dict[str, Any]) -> '
|
|
1318
|
+
def from_dict(cls, d: dict[str, Any]) -> 'Query':
|
|
1061
1319
|
# we need to wrap the construction with a transaction, because it might need to load metadata
|
|
1062
|
-
with
|
|
1320
|
+
with Catalog.get().begin_xact(for_write=False):
|
|
1063
1321
|
tbls = [catalog.TableVersionPath.from_dict(tbl_dict) for tbl_dict in d['from_clause']['tbls']]
|
|
1064
1322
|
join_clauses = [plan.JoinClause(**clause_dict) for clause_dict in d['from_clause']['join_clauses']]
|
|
1065
1323
|
from_clause = plan.FromClause(tbls=tbls, join_clauses=join_clauses)
|
|
@@ -1079,8 +1337,9 @@ class DataFrame:
|
|
|
1079
1337
|
else None
|
|
1080
1338
|
)
|
|
1081
1339
|
limit_val = exprs.Expr.from_dict(d['limit_val']) if d['limit_val'] is not None else None
|
|
1340
|
+
sample_clause = SampleClause.from_dict(d['sample_clause']) if d['sample_clause'] is not None else None
|
|
1082
1341
|
|
|
1083
|
-
return
|
|
1342
|
+
return Query(
|
|
1084
1343
|
from_clause=from_clause,
|
|
1085
1344
|
select_list=select_list,
|
|
1086
1345
|
where_clause=where_clause,
|
|
@@ -1088,6 +1347,7 @@ class DataFrame:
|
|
|
1088
1347
|
grouping_tbl=grouping_tbl,
|
|
1089
1348
|
order_by_clause=order_by_clause,
|
|
1090
1349
|
limit=limit_val,
|
|
1350
|
+
sample_clause=sample_clause,
|
|
1091
1351
|
)
|
|
1092
1352
|
|
|
1093
1353
|
def _hash_result_set(self) -> str:
|
|
@@ -1102,8 +1362,10 @@ class DataFrame:
|
|
|
1102
1362
|
return hashlib.sha256(summary_string.encode()).hexdigest()
|
|
1103
1363
|
|
|
1104
1364
|
def to_coco_dataset(self) -> Path:
|
|
1105
|
-
"""Convert the
|
|
1106
|
-
This
|
|
1365
|
+
"""Convert the Query to a COCO dataset.
|
|
1366
|
+
This Query must return a single json-typed output column in the following format:
|
|
1367
|
+
|
|
1368
|
+
```python
|
|
1107
1369
|
{
|
|
1108
1370
|
'image': PIL.Image.Image,
|
|
1109
1371
|
'annotations': [
|
|
@@ -1114,6 +1376,7 @@ class DataFrame:
|
|
|
1114
1376
|
...
|
|
1115
1377
|
],
|
|
1116
1378
|
}
|
|
1379
|
+
```
|
|
1117
1380
|
|
|
1118
1381
|
Returns:
|
|
1119
1382
|
Path to the COCO dataset file.
|
|
@@ -1129,12 +1392,13 @@ class DataFrame:
|
|
|
1129
1392
|
assert data_file_path.is_file()
|
|
1130
1393
|
return data_file_path
|
|
1131
1394
|
else:
|
|
1132
|
-
|
|
1395
|
+
# TODO: extend begin_xact() to accept multiple TVPs for joins
|
|
1396
|
+
with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
|
|
1133
1397
|
return write_coco_dataset(self, dest_path)
|
|
1134
1398
|
|
|
1135
1399
|
def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
|
|
1136
1400
|
"""
|
|
1137
|
-
Convert the
|
|
1401
|
+
Convert the Query to a pytorch IterableDataset suitable for parallel loading
|
|
1138
1402
|
with torch.utils.data.DataLoader.
|
|
1139
1403
|
|
|
1140
1404
|
This method requires pyarrow >= 13, torch and torchvision to work.
|
|
@@ -1174,7 +1438,7 @@ class DataFrame:
|
|
|
1174
1438
|
if dest_path.exists(): # fast path: use cache
|
|
1175
1439
|
assert dest_path.is_dir()
|
|
1176
1440
|
else:
|
|
1177
|
-
with
|
|
1441
|
+
with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
|
|
1178
1442
|
export_parquet(self, dest_path, inline_images=True)
|
|
1179
1443
|
|
|
1180
1444
|
return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
|