pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +83 -19
- pixeltable/_query.py +1444 -0
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +7 -4
- pixeltable/catalog/catalog.py +2394 -119
- pixeltable/catalog/column.py +225 -104
- pixeltable/catalog/dir.py +38 -9
- pixeltable/catalog/globals.py +53 -34
- pixeltable/catalog/insertable_table.py +265 -115
- pixeltable/catalog/path.py +80 -17
- pixeltable/catalog/schema_object.py +28 -43
- pixeltable/catalog/table.py +1270 -677
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +1270 -751
- pixeltable/catalog/table_version_handle.py +109 -0
- pixeltable/catalog/table_version_path.py +137 -42
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +251 -134
- pixeltable/config.py +215 -0
- pixeltable/env.py +736 -285
- pixeltable/exceptions.py +26 -2
- pixeltable/exec/__init__.py +7 -2
- pixeltable/exec/aggregation_node.py +39 -21
- pixeltable/exec/cache_prefetch_node.py +87 -109
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +25 -28
- pixeltable/exec/data_row_batch.py +11 -46
- pixeltable/exec/exec_context.py +26 -11
- pixeltable/exec/exec_node.py +35 -27
- pixeltable/exec/expr_eval/__init__.py +3 -0
- pixeltable/exec/expr_eval/evaluators.py +365 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
- pixeltable/exec/expr_eval/globals.py +200 -0
- pixeltable/exec/expr_eval/row_buffer.py +74 -0
- pixeltable/exec/expr_eval/schedulers.py +413 -0
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +35 -27
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +44 -29
- pixeltable/exec/sql_node.py +414 -115
- pixeltable/exprs/__init__.py +8 -5
- pixeltable/exprs/arithmetic_expr.py +79 -45
- pixeltable/exprs/array_slice.py +5 -5
- pixeltable/exprs/column_property_ref.py +40 -26
- pixeltable/exprs/column_ref.py +254 -61
- pixeltable/exprs/comparison.py +14 -9
- pixeltable/exprs/compound_predicate.py +9 -10
- pixeltable/exprs/data_row.py +213 -72
- pixeltable/exprs/expr.py +270 -104
- pixeltable/exprs/expr_dict.py +6 -5
- pixeltable/exprs/expr_set.py +20 -11
- pixeltable/exprs/function_call.py +383 -284
- pixeltable/exprs/globals.py +18 -5
- pixeltable/exprs/in_predicate.py +7 -7
- pixeltable/exprs/inline_expr.py +37 -37
- pixeltable/exprs/is_null.py +8 -4
- pixeltable/exprs/json_mapper.py +120 -54
- pixeltable/exprs/json_path.py +90 -60
- pixeltable/exprs/literal.py +61 -16
- pixeltable/exprs/method_ref.py +7 -6
- pixeltable/exprs/object_ref.py +19 -8
- pixeltable/exprs/row_builder.py +238 -75
- pixeltable/exprs/rowid_ref.py +53 -15
- pixeltable/exprs/similarity_expr.py +65 -50
- pixeltable/exprs/sql_element_cache.py +5 -5
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/exprs/type_cast.py +25 -13
- pixeltable/exprs/variable.py +2 -2
- pixeltable/func/__init__.py +9 -5
- pixeltable/func/aggregate_function.py +197 -92
- pixeltable/func/callable_function.py +119 -35
- pixeltable/func/expr_template_function.py +101 -48
- pixeltable/func/function.py +375 -62
- pixeltable/func/function_registry.py +20 -19
- pixeltable/func/globals.py +6 -5
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +151 -35
- pixeltable/func/signature.py +178 -49
- pixeltable/func/tools.py +164 -0
- pixeltable/func/udf.py +176 -53
- pixeltable/functions/__init__.py +44 -4
- pixeltable/functions/anthropic.py +226 -47
- pixeltable/functions/audio.py +148 -11
- pixeltable/functions/bedrock.py +137 -0
- pixeltable/functions/date.py +188 -0
- pixeltable/functions/deepseek.py +113 -0
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +72 -20
- pixeltable/functions/gemini.py +249 -0
- pixeltable/functions/globals.py +208 -53
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1088 -95
- pixeltable/functions/image.py +155 -84
- pixeltable/functions/json.py +8 -11
- pixeltable/functions/llama_cpp.py +31 -19
- pixeltable/functions/math.py +169 -0
- pixeltable/functions/mistralai.py +50 -75
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +29 -36
- pixeltable/functions/openai.py +548 -160
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +15 -14
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +310 -85
- pixeltable/functions/timestamp.py +37 -19
- pixeltable/functions/together.py +77 -120
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +7 -2
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1528 -117
- pixeltable/functions/vision.py +26 -26
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +19 -10
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/functions/yolox.py +112 -0
- pixeltable/globals.py +716 -236
- pixeltable/index/__init__.py +3 -1
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +32 -22
- pixeltable/index/embedding_index.py +155 -92
- pixeltable/io/__init__.py +12 -7
- pixeltable/io/datarows.py +140 -0
- pixeltable/io/external_store.py +83 -125
- pixeltable/io/fiftyone.py +24 -33
- pixeltable/io/globals.py +47 -182
- pixeltable/io/hf_datasets.py +96 -127
- pixeltable/io/label_studio.py +171 -156
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +136 -115
- pixeltable/io/parquet.py +40 -153
- pixeltable/io/table_data_conduit.py +702 -0
- pixeltable/io/utils.py +100 -0
- pixeltable/iterators/__init__.py +8 -4
- pixeltable/iterators/audio.py +207 -0
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +144 -87
- pixeltable/iterators/image.py +17 -38
- pixeltable/iterators/string.py +15 -12
- pixeltable/iterators/video.py +523 -127
- pixeltable/metadata/__init__.py +33 -8
- pixeltable/metadata/converters/convert_10.py +2 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_15.py +15 -11
- pixeltable/metadata/converters/convert_16.py +4 -5
- pixeltable/metadata/converters/convert_17.py +4 -5
- pixeltable/metadata/converters/convert_18.py +4 -6
- pixeltable/metadata/converters/convert_19.py +6 -9
- pixeltable/metadata/converters/convert_20.py +3 -6
- pixeltable/metadata/converters/convert_21.py +6 -8
- pixeltable/metadata/converters/convert_22.py +3 -2
- pixeltable/metadata/converters/convert_23.py +33 -0
- pixeltable/metadata/converters/convert_24.py +55 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/convert_26.py +23 -0
- pixeltable/metadata/converters/convert_27.py +29 -0
- pixeltable/metadata/converters/convert_28.py +13 -0
- pixeltable/metadata/converters/convert_29.py +110 -0
- pixeltable/metadata/converters/convert_30.py +63 -0
- pixeltable/metadata/converters/convert_31.py +11 -0
- pixeltable/metadata/converters/convert_32.py +15 -0
- pixeltable/metadata/converters/convert_33.py +17 -0
- pixeltable/metadata/converters/convert_34.py +21 -0
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +44 -18
- pixeltable/metadata/notes.py +21 -0
- pixeltable/metadata/schema.py +185 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +616 -225
- pixeltable/share/__init__.py +3 -0
- pixeltable/share/packager.py +797 -0
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +349 -0
- pixeltable/store.py +398 -232
- pixeltable/type_system.py +730 -267
- pixeltable/utils/__init__.py +40 -0
- pixeltable/utils/arrow.py +201 -29
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +26 -27
- pixeltable/utils/code.py +4 -4
- pixeltable/utils/console_output.py +46 -0
- pixeltable/utils/coroutine.py +24 -0
- pixeltable/utils/dbms.py +92 -0
- pixeltable/utils/description_helper.py +11 -12
- pixeltable/utils/documents.py +60 -61
- pixeltable/utils/exception_handler.py +36 -0
- pixeltable/utils/filecache.py +38 -22
- pixeltable/utils/formatter.py +88 -51
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +14 -13
- pixeltable/utils/iceberg.py +13 -0
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +20 -20
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +32 -5
- pixeltable/utils/system.py +30 -0
- pixeltable/utils/transactional_directory.py +4 -3
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -36
- pixeltable/catalog/path_dict.py +0 -141
- pixeltable/dataframe.py +0 -894
- pixeltable/exec/expr_eval_node.py +0 -232
- pixeltable/ext/__init__.py +0 -14
- pixeltable/ext/functions/__init__.py +0 -8
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/ext/functions/yolox.py +0 -157
- pixeltable/tool/create_test_db_dump.py +0 -311
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable/utils/media_store.py +0 -76
- pixeltable/utils/s3.py +0 -16
- pixeltable-0.2.26.dist-info/METADATA +0 -400
- pixeltable-0.2.26.dist-info/RECORD +0 -156
- pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/_query.py
ADDED
|
@@ -0,0 +1,1444 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
import copy
|
|
5
|
+
import dataclasses
|
|
6
|
+
import hashlib
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import traceback
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Hashable, Iterator, NoReturn, Sequence, TypeVar
|
|
12
|
+
|
|
13
|
+
import pandas as pd
|
|
14
|
+
import pydantic
|
|
15
|
+
import sqlalchemy.exc as sql_exc
|
|
16
|
+
|
|
17
|
+
from pixeltable import catalog, exceptions as excs, exec, exprs, plan, type_system as ts
|
|
18
|
+
from pixeltable.catalog import Catalog, is_valid_identifier
|
|
19
|
+
from pixeltable.catalog.update_status import UpdateStatus
|
|
20
|
+
from pixeltable.env import Env
|
|
21
|
+
from pixeltable.plan import Planner, SampleClause
|
|
22
|
+
from pixeltable.type_system import ColumnType
|
|
23
|
+
from pixeltable.utils.description_helper import DescriptionHelper
|
|
24
|
+
from pixeltable.utils.formatter import Formatter
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
import torch
|
|
28
|
+
import torch.utils.data
|
|
29
|
+
|
|
30
|
+
__all__ = ['Query']
|
|
31
|
+
|
|
32
|
+
_logger = logging.getLogger('pixeltable')
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ResultSet:
    """Materialized rows of a query together with their column names and types.

    Rows are stored positionally; the column order is the key order of the schema passed to the constructor.
    """

    _rows: list[list[Any]]
    _col_names: list[str]
    __schema: dict[str, ColumnType]
    __formatter: Formatter

    def __init__(self, rows: list[list[Any]], schema: dict[str, ColumnType]):
        """
        Args:
            rows: One list of cell values per row; the list is stored as-is, not copied.
            schema: Column name -> column type; its key order defines the column order.
        """
        self._rows = rows
        self._col_names = list(schema.keys())
        self.__schema = schema
        self.__formatter = Formatter(len(self._rows), len(self._col_names), Env.get().http_address)

    @property
    def schema(self) -> dict[str, ColumnType]:
        """Column names mapped to their types."""
        return self.__schema

    def __len__(self) -> int:
        """Number of rows."""
        return len(self._rows)

    def __repr__(self) -> str:
        return self.to_pandas().__repr__()

    def _repr_html_(self) -> str:
        """Render as an HTML table, applying a per-column formatter wherever the Formatter supplies one."""
        formatters: dict[Hashable, Callable[[object], str]] = {}
        for col_name, col_type in self.schema.items():
            formatter = self.__formatter.get_pandas_formatter(col_type)
            if formatter is not None:
                formatters[col_name] = formatter
        return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)

    def __str__(self) -> str:
        return self.to_pandas().to_string()

    def _reverse(self) -> None:
        """Reverse order of rows (in place)."""
        self._rows.reverse()

    def to_pandas(self) -> pd.DataFrame:
        """Return a new DataFrame built from the rows, with the result set's column names."""
        return pd.DataFrame.from_records(self._rows, columns=self._col_names)

    BaseModelT = TypeVar('BaseModelT', bound=pydantic.BaseModel)

    def to_pydantic(self, model: type[BaseModelT]) -> Iterator[BaseModelT]:
        """
        Convert the rows of the ResultSet to Pydantic model instances.

        This is a generator: the schema checks below run when iteration starts, not when the method is called.

        Args:
            model: A Pydantic model class.

        Returns:
            An iterator over Pydantic model instances, one for each row in the result set.

        Raises:
            Error: If a required model field has no matching column, if the model forbids extra fields and the
                result set has columns that are not model fields, or if a row fails model validation.
        """
        model_fields = model.model_fields
        model_config = getattr(model, 'model_config', {})
        forbid_extra_fields = model_config.get('extra') == 'forbid'

        # schema validation
        required_fields = {name for name, field in model_fields.items() if field.is_required()}
        col_names = set(self._col_names)
        missing_fields = required_fields - col_names
        if missing_fields:
            raise excs.Error(
                f'Required model fields {missing_fields} are missing from result set columns {self._col_names}'
            )
        if forbid_extra_fields:
            extra_fields = col_names - set(model_fields.keys())
            if extra_fields:
                raise excs.Error(f"Extra fields {extra_fields} are not allowed in model with extra='forbid'")

        for row in self:
            try:
                yield model(**row)
            except pydantic.ValidationError as e:
                raise excs.Error(str(e)) from e

    def _row_to_dict(self, row_idx: int) -> dict[str, Any]:
        """Return the row at `row_idx` as a dict keyed by column name."""
        row = self._rows[row_idx]
        return {col_name: row[i] for i, col_name in enumerate(self._col_names)}

    def __getitem__(self, index: Any) -> Any:
        """Index by column name, row index, or a (row index, column name | column index) pair.

        - `rs['col']` returns the list of values of that column, one per row
        - `rs[i]` returns row `i` as a dict keyed by column name
        - `rs[i, 'col']` / `rs[i, j]` returns a single cell value

        Raises:
            Error: If the column name is unknown or the index has an unsupported form.
        """
        if isinstance(index, str):
            if index not in self._col_names:
                raise excs.Error(f'Invalid column name: {index}')
            col_idx = self._col_names.index(index)
            return [row[col_idx] for row in self._rows]
        if isinstance(index, int):
            return self._row_to_dict(index)
        if isinstance(index, tuple) and len(index) == 2:
            if not isinstance(index[0], int) or not isinstance(index[1], (str, int)):
                raise excs.Error(f'Bad index, expected [<row idx>, <column name | column index>]: {index}')
            if isinstance(index[1], str) and index[1] not in self._col_names:
                raise excs.Error(f'Invalid column name: {index[1]}')
            col_idx = self._col_names.index(index[1]) if isinstance(index[1], str) else index[1]
            return self._rows[index[0]][col_idx]
        raise excs.Error(f'Bad index: {index}')

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Iterate over the rows as dicts keyed by column name."""
        return (self._row_to_dict(i) for i in range(len(self)))

    def __eq__(self, other: object) -> bool:
        """Two ResultSets are equal if their DataFrame representations are equal."""
        if not isinstance(other, ResultSet):
            return False
        return self.to_pandas().equals(other.to_pandas())

    def __hash__(self) -> int:
        # DataFrames are unhashable, so hashing the to_pandas() result always raised TypeError. Hash only the
        # column names and the row count instead: both are identical for any two ResultSets that compare equal,
        # which keeps __hash__ consistent with __eq__. Cell values are arbitrary objects and may themselves be
        # unhashable, so they are deliberately left out.
        return hash((tuple(self._col_names), len(self._rows)))
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
|
|
146
|
+
# class AnalysisInfo:
|
|
147
|
+
# def __init__(self, tbl: catalog.TableVersion):
|
|
148
|
+
# self.tbl = tbl
|
|
149
|
+
# # output of the SQL scan stage
|
|
150
|
+
# self.sql_scan_output_exprs: list[exprs.Expr] = []
|
|
151
|
+
# # output of the agg stage
|
|
152
|
+
# self.agg_output_exprs: list[exprs.Expr] = []
|
|
153
|
+
# # Where clause of the Select stmt of the SQL scan stage
|
|
154
|
+
# self.sql_where_clause: sql.ClauseElement | None = None
|
|
155
|
+
# # filter predicate applied to input rows of the SQL scan stage
|
|
156
|
+
# self.filter: exprs.Predicate | None = None
|
|
157
|
+
# self.similarity_clause: exprs.ImageSimilarityPredicate | None = None
|
|
158
|
+
# self.agg_fn_calls: list[exprs.FunctionCall] = [] # derived from unique_exprs
|
|
159
|
+
# self.has_frame_col: bool = False # True if we're referencing the frame col
|
|
160
|
+
#
|
|
161
|
+
# self.evaluator: exprs.Evaluator | None = None
|
|
162
|
+
# self.sql_scan_eval_ctx: list[exprs.Expr] = [] # needed to materialize output of SQL scan stage
|
|
163
|
+
# self.agg_eval_ctx: list[exprs.Expr] = [] # needed to materialize output of agg stage
|
|
164
|
+
# self.filter_eval_ctx: list[exprs.Expr] = []
|
|
165
|
+
# self.group_by_eval_ctx: list[exprs.Expr] = []
|
|
166
|
+
#
|
|
167
|
+
# def finalize_exec(self) -> None:
|
|
168
|
+
# """
|
|
169
|
+
# Call release() on all collected Exprs.
|
|
170
|
+
# """
|
|
171
|
+
# exprs.Expr.release_list(self.sql_scan_output_exprs)
|
|
172
|
+
# exprs.Expr.release_list(self.agg_output_exprs)
|
|
173
|
+
# if self.filter is not None:
|
|
174
|
+
# self.filter.release()
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class Query:
|
|
178
|
+
"""Represents a query for retrieving and transforming data from Pixeltable tables."""
|
|
179
|
+
|
|
180
|
+
_from_clause: plan.FromClause
|
|
181
|
+
_select_list_exprs: list[exprs.Expr]
|
|
182
|
+
_schema: dict[str, ts.ColumnType]
|
|
183
|
+
select_list: list[tuple[exprs.Expr, str | None]] | None
|
|
184
|
+
where_clause: exprs.Expr | None
|
|
185
|
+
group_by_clause: list[exprs.Expr] | None
|
|
186
|
+
grouping_tbl: catalog.TableVersion | None
|
|
187
|
+
order_by_clause: list[tuple[exprs.Expr, bool]] | None
|
|
188
|
+
limit_val: exprs.Expr | None
|
|
189
|
+
sample_clause: SampleClause | None
|
|
190
|
+
|
|
191
|
+
def __init__(
|
|
192
|
+
self,
|
|
193
|
+
from_clause: plan.FromClause | None = None,
|
|
194
|
+
select_list: list[tuple[exprs.Expr, str | None]] | None = None,
|
|
195
|
+
where_clause: exprs.Expr | None = None,
|
|
196
|
+
group_by_clause: list[exprs.Expr] | None = None,
|
|
197
|
+
grouping_tbl: catalog.TableVersion | None = None,
|
|
198
|
+
order_by_clause: list[tuple[exprs.Expr, bool]] | None = None, # list[(expr, asc)]
|
|
199
|
+
limit: exprs.Expr | None = None,
|
|
200
|
+
sample_clause: SampleClause | None = None,
|
|
201
|
+
):
|
|
202
|
+
self._from_clause = from_clause
|
|
203
|
+
|
|
204
|
+
# exprs contain execution state and therefore cannot be shared
|
|
205
|
+
select_list = copy.deepcopy(select_list)
|
|
206
|
+
select_list_exprs, column_names = Query._normalize_select_list(self._from_clause.tbls, select_list)
|
|
207
|
+
# check select list after expansion to catch early
|
|
208
|
+
# the following two lists are always non empty, even if select list is None.
|
|
209
|
+
assert len(column_names) == len(select_list_exprs)
|
|
210
|
+
self._select_list_exprs = select_list_exprs
|
|
211
|
+
self._schema = {column_names[i]: select_list_exprs[i].col_type for i in range(len(column_names))}
|
|
212
|
+
self.select_list = select_list
|
|
213
|
+
|
|
214
|
+
self.where_clause = copy.deepcopy(where_clause)
|
|
215
|
+
assert group_by_clause is None or grouping_tbl is None
|
|
216
|
+
self.group_by_clause = copy.deepcopy(group_by_clause)
|
|
217
|
+
self.grouping_tbl = grouping_tbl
|
|
218
|
+
self.order_by_clause = copy.deepcopy(order_by_clause)
|
|
219
|
+
self.limit_val = limit
|
|
220
|
+
self.sample_clause = sample_clause
|
|
221
|
+
|
|
222
|
+
@classmethod
|
|
223
|
+
def _normalize_select_list(
|
|
224
|
+
cls, tbls: list[catalog.TableVersionPath], select_list: list[tuple[exprs.Expr, str | None]] | None
|
|
225
|
+
) -> tuple[list[exprs.Expr], list[str]]:
|
|
226
|
+
"""
|
|
227
|
+
Expand select list information with all columns and their names
|
|
228
|
+
Returns:
|
|
229
|
+
a pair composed of the list of expressions and the list of corresponding names
|
|
230
|
+
"""
|
|
231
|
+
if select_list is None:
|
|
232
|
+
select_list = [(exprs.ColumnRef(col), None) for tbl in tbls for col in tbl.columns()]
|
|
233
|
+
|
|
234
|
+
out_exprs: list[exprs.Expr] = []
|
|
235
|
+
out_names: list[str] = [] # keep track of order
|
|
236
|
+
seen_out_names: set[str] = set() # use to check for duplicates in loop, avoid square complexity
|
|
237
|
+
for i, (expr, name) in enumerate(select_list):
|
|
238
|
+
if name is None:
|
|
239
|
+
# use default, add suffix if needed so default adds no duplicates
|
|
240
|
+
default_name = expr.default_column_name()
|
|
241
|
+
if default_name is not None:
|
|
242
|
+
column_name = default_name
|
|
243
|
+
if default_name in seen_out_names:
|
|
244
|
+
# already used, then add suffix until unique name is found
|
|
245
|
+
for j in range(1, len(out_names) + 1):
|
|
246
|
+
column_name = f'{default_name}_{j}'
|
|
247
|
+
if column_name not in seen_out_names:
|
|
248
|
+
break
|
|
249
|
+
else: # no default name, eg some expressions
|
|
250
|
+
column_name = f'col_{i}'
|
|
251
|
+
else: # user provided name, no attempt to rename
|
|
252
|
+
column_name = name
|
|
253
|
+
|
|
254
|
+
out_exprs.append(expr)
|
|
255
|
+
out_names.append(column_name)
|
|
256
|
+
seen_out_names.add(column_name)
|
|
257
|
+
assert len(out_exprs) == len(out_names)
|
|
258
|
+
assert set(out_names) == seen_out_names
|
|
259
|
+
return out_exprs, out_names
|
|
260
|
+
|
|
261
|
+
@property
|
|
262
|
+
def _first_tbl(self) -> catalog.TableVersionPath:
|
|
263
|
+
return self._from_clause._first_tbl
|
|
264
|
+
|
|
265
|
+
def _vars(self) -> dict[str, exprs.Variable]:
    """
    Collect the Variables used anywhere in the Query (select list, where, group-by, order-by, limit).

    Returns:
        a dict mapping each variable name to the first Variable found with that name

    Raises:
        Error: if two Variables share a name but differ in type
    """
    components: list[exprs.Expr] = list(self._select_list_exprs)
    if self.where_clause is not None:
        components.append(self.where_clause)
    if self.group_by_clause is not None:
        components.extend(self.group_by_clause)
    if self.order_by_clause is not None:
        components.extend(order_expr for order_expr, _ in self.order_by_clause)
    if self.limit_val is not None:
        components.append(self.limit_val)

    by_name: dict[str, exprs.Variable] = {}
    for variable in exprs.Expr.list_subexprs(components, expr_class=exprs.Variable):
        known = by_name.get(variable.name)
        if known is None:
            by_name[variable.name] = variable
        elif known.col_type != variable.col_type:
            raise excs.Error(f'Multiple definitions of parameter {variable.name!r}')
    return by_name
|
|
287
|
+
|
|
288
|
+
@classmethod
|
|
289
|
+
def _convert_param_to_typed_expr(
|
|
290
|
+
cls, v: Any, required_type: ts.ColumnType, required: bool, name: str, range: tuple[Any, Any] | None = None
|
|
291
|
+
) -> exprs.Expr | None:
|
|
292
|
+
if v is None:
|
|
293
|
+
if required:
|
|
294
|
+
raise excs.Error(f'{name!r} parameter must be present')
|
|
295
|
+
return v
|
|
296
|
+
v_expr = exprs.Expr.from_object(v)
|
|
297
|
+
if not v_expr.col_type.matches(required_type):
|
|
298
|
+
raise excs.Error(f'{name!r} parameter must be of type `{required_type}`; got `{v_expr.col_type}`')
|
|
299
|
+
if range is not None:
|
|
300
|
+
if not isinstance(v_expr, exprs.Literal):
|
|
301
|
+
raise excs.Error(f'{name!r} parameter must be a constant; got: {v_expr}')
|
|
302
|
+
if range[0] is not None and not (v_expr.val >= range[0]):
|
|
303
|
+
raise excs.Error(f'{name!r} parameter must be >= {range[0]}')
|
|
304
|
+
if range[1] is not None and not (v_expr.val <= range[1]):
|
|
305
|
+
raise excs.Error(f'{name!r} parameter must be <= {range[1]}')
|
|
306
|
+
return v_expr
|
|
307
|
+
|
|
308
|
+
@classmethod
|
|
309
|
+
def validate_constant_type_range(
|
|
310
|
+
cls, v: Any, required_type: ts.ColumnType, required: bool, name: str, range: tuple[Any, Any] | None = None
|
|
311
|
+
) -> Any:
|
|
312
|
+
"""Validate that the given named parameter is a constant of the required type and within the specified range."""
|
|
313
|
+
v_expr = cls._convert_param_to_typed_expr(v, required_type, required, name, range)
|
|
314
|
+
if v_expr is None:
|
|
315
|
+
return None
|
|
316
|
+
return v_expr.val
|
|
317
|
+
|
|
318
|
+
def parameters(self) -> dict[str, ColumnType]:
|
|
319
|
+
"""Return a dict mapping parameter name to parameter type.
|
|
320
|
+
|
|
321
|
+
Parameters are Variables contained in any component of the Query.
|
|
322
|
+
"""
|
|
323
|
+
return {name: var.col_type for name, var in self._vars().items()}
|
|
324
|
+
|
|
325
|
+
def _exec(self) -> Iterator[exprs.DataRow]:
|
|
326
|
+
"""Run the query and return rows as a generator.
|
|
327
|
+
This function must not modify the state of the Query, otherwise it breaks dataset caching.
|
|
328
|
+
"""
|
|
329
|
+
plan = self._create_query_plan()
|
|
330
|
+
|
|
331
|
+
def exec_plan() -> Iterator[exprs.DataRow]:
|
|
332
|
+
plan.open()
|
|
333
|
+
try:
|
|
334
|
+
for row_batch in plan:
|
|
335
|
+
yield from row_batch
|
|
336
|
+
finally:
|
|
337
|
+
plan.close()
|
|
338
|
+
|
|
339
|
+
yield from exec_plan()
|
|
340
|
+
|
|
341
|
+
async def _aexec(self) -> AsyncIterator[exprs.DataRow]:
|
|
342
|
+
"""Run the query and return rows as a generator.
|
|
343
|
+
This function must not modify the state of the Query, otherwise it breaks dataset caching.
|
|
344
|
+
"""
|
|
345
|
+
plan = self._create_query_plan()
|
|
346
|
+
plan.open()
|
|
347
|
+
try:
|
|
348
|
+
async for row_batch in plan:
|
|
349
|
+
for row in row_batch:
|
|
350
|
+
yield row
|
|
351
|
+
finally:
|
|
352
|
+
plan.close()
|
|
353
|
+
|
|
354
|
+
def _create_query_plan(self) -> exec.ExecNode:
    """Build the execution plan for this Query via the Planner.

    When the Query groups by a table, the group-by clause is derived from that table's rowid columns;
    otherwise the explicit group-by clause (if any) is used.
    """
    if self.grouping_tbl is None:
        # explicit group-by clause, or None when the query does not group at all
        effective_group_by: list[exprs.Expr] | None = self.group_by_clause
    else:
        # grouping by a table: group on as many rowid columns as that table has
        assert self.group_by_clause is None
        rowid_count = len(self.grouping_tbl.store_tbl.rowid_columns())
        # the grouping table must be a base of self.tbl
        assert rowid_count <= len(self._first_tbl.tbl_version.get().store_tbl.rowid_columns())
        effective_group_by = self.__rowid_columns(rowid_count)

    for select_expr in self._select_list_exprs:
        select_expr.bind_rel_paths()

    return Planner.create_query_plan(
        self._from_clause,
        self._select_list_exprs,
        where_clause=self.where_clause,
        group_by_clause=effective_group_by,
        order_by_clause=self.order_by_clause,
        limit=self.limit_val,
        sample_clause=self.sample_clause,
    )
|
|
378
|
+
|
|
379
|
+
def __rowid_columns(self, num_rowid_cols: int | None = None) -> list[exprs.Expr]:
    """Ask the Planner for the rowid expressions of the first table, passing through `num_rowid_cols`."""
    tbl_version = self._first_tbl.tbl_version
    return Planner.rowid_columns(tbl_version, num_rowid_cols)
|
|
382
|
+
|
|
383
|
+
def _has_joins(self) -> bool:
|
|
384
|
+
return len(self._from_clause.join_clauses) > 0
|
|
385
|
+
|
|
386
|
+
def show(self, n: int = 20) -> ResultSet:
|
|
387
|
+
if self.sample_clause is not None:
|
|
388
|
+
raise excs.Error('show() cannot be used with sample()')
|
|
389
|
+
assert n is not None
|
|
390
|
+
return self.limit(n).collect()
|
|
391
|
+
|
|
392
|
+
def head(self, n: int = 10) -> ResultSet:
    """Return the first n rows of the Query, in insertion order of the underlying Table.

    The rows are ordered by the first table's rowid columns, ascending, and then limited to n.

    head() is not supported for joins.

    Args:
        n: Number of rows to select. Default is 10.

    Returns:
        A ResultSet with the first n rows of the Query.

    Raises:
        Error: If the Query has an order_by clause, is the result of a join, has a sample clause,
            or has a group_by clause.
    """
    # checked lazily and in this order, so the first applicable message is the one reported
    blockers = (
        (lambda: self.order_by_clause is not None, 'head() cannot be used with order_by()'),
        (self._has_joins, 'head() not supported for joins'),
        (lambda: self.sample_clause is not None, 'head() cannot be used with sample()'),
        (lambda: self.group_by_clause is not None, 'head() cannot be used with group_by()'),
    )
    for is_blocked, message in blockers:
        if is_blocked():
            raise excs.Error(message)

    tbl_version = self._first_tbl.tbl_version
    rowid_count = len(tbl_version.get().store_tbl.rowid_columns())
    rowid_order = [exprs.RowidRef(tbl_version, idx) for idx in range(rowid_count)]
    return self.order_by(*rowid_order, asc=True).limit(n).collect()
|
|
418
|
+
|
|
419
|
+
def tail(self, n: int = 10) -> ResultSet:
    """Return the last n rows of the Query, in insertion order of the underlying Table.

    The rows are ordered by the first table's rowid columns, descending, limited to n, and the collected
    result is then reversed so that it reads in ascending order again.

    tail() is not supported for joins.

    Args:
        n: Number of rows to select. Default is 10.

    Returns:
        A ResultSet with the last n rows of the Query.

    Raises:
        Error: If the Query has an order_by clause, is the result of a join, has a sample clause,
            or has a group_by clause.
    """
    # checked lazily and in this order, so the first applicable message is the one reported
    blockers = (
        (lambda: self.order_by_clause is not None, 'tail() cannot be used with order_by()'),
        (self._has_joins, 'tail() not supported for joins'),
        (lambda: self.sample_clause is not None, 'tail() cannot be used with sample()'),
        (lambda: self.group_by_clause is not None, 'tail() cannot be used with group_by()'),
    )
    for is_blocked, message in blockers:
        if is_blocked():
            raise excs.Error(message)

    tbl_version = self._first_tbl.tbl_version
    rowid_count = len(tbl_version.get().store_tbl.rowid_columns())
    rowid_order = [exprs.RowidRef(tbl_version, idx) for idx in range(rowid_count)]
    last_rows = self.order_by(*rowid_order, asc=False).limit(n).collect()
    last_rows._reverse()
    return last_rows
|
|
447
|
+
|
|
448
|
+
@property
|
|
449
|
+
def schema(self) -> dict[str, ColumnType]:
|
|
450
|
+
"""Column names and types in this Query."""
|
|
451
|
+
return self._schema
|
|
452
|
+
|
|
453
|
+
def bind(self, args: dict[str, Any]) -> Query:
    """Bind arguments to parameters and return a new Query.

    Every Variable of this Query that has an entry in `args` is replaced, in the select list, where
    clause, group-by clause, order-by clause and limit, by the corresponding value converted to a
    Pixeltable expression. Entries of `args` that do not name a Variable of this Query are ignored.

    Args:
        args: Maps parameter names to the values to bind.

    Returns:
        A new Query with the substitutions applied. The substitutions are performed on deep copies, so
        the expressions held by this Query are not modified.

    Raises:
        Error: If a value cannot be converted to a Pixeltable expression, or if the limit is not a
            constant after substitution.
    """
    # substitute Variables with the corresponding values according to 'args', converted to Literals;
    # work on deep copies so this Query's own expressions stay untouched
    select_list_exprs = copy.deepcopy(self._select_list_exprs)
    where_clause = copy.deepcopy(self.where_clause)
    group_by_clause = copy.deepcopy(self.group_by_clause)
    order_by_exprs = (
        [copy.deepcopy(order_by_expr) for order_by_expr, _ in self.order_by_clause]
        if self.order_by_clause is not None
        else None
    )
    limit_val = copy.deepcopy(self.limit_val)

    # map each bound Variable to the expression that replaces it
    var_exprs: dict[exprs.Expr, exprs.Expr] = {}
    query_vars = self._vars()  # named 'query_vars' to avoid shadowing the builtin vars()
    for arg_name, arg_val in args.items():
        if arg_name not in query_vars:
            # ignore unused variables
            continue
        var_expr = query_vars[arg_name]
        arg_expr = exprs.Expr.from_object(arg_val)
        if arg_expr is None:
            raise excs.Error(f'That argument cannot be converted to a Pixeltable expression: {arg_val}')
        var_exprs[var_expr] = arg_expr

    exprs.Expr.list_substitute(select_list_exprs, var_exprs)
    if where_clause is not None:
        where_clause = where_clause.substitute(var_exprs)
    if group_by_clause is not None:
        exprs.Expr.list_substitute(group_by_clause, var_exprs)
    if order_by_exprs is not None:
        exprs.Expr.list_substitute(order_by_exprs, var_exprs)

    select_list = list(zip(select_list_exprs, self.schema.keys()))
    order_by_clause: list[tuple[exprs.Expr, bool]] | None = None
    if order_by_exprs is not None:
        # re-attach the original asc/desc flags to the substituted expressions
        order_by_clause = list(zip(order_by_exprs, (asc for _, asc in self.order_by_clause)))
    if limit_val is not None:
        limit_val = limit_val.substitute(var_exprs)
        if limit_val is not None and not isinstance(limit_val, exprs.Literal):
            raise excs.Error(f'limit(): parameter must be a constant; got: {limit_val}')

    return Query(
        from_clause=self._from_clause,
        select_list=select_list,
        where_clause=where_clause,
        group_by_clause=group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=order_by_clause,
        limit=limit_val,
        # carry the sample clause over; omitting it would silently turn a sampled Query into a full scan
        sample_clause=self.sample_clause,
    )
|
|
506
|
+
|
|
507
|
+
def _raise_expr_eval_err(self, e: excs.ExprEvalError) -> NoReturn:
    """Re-raise an expression evaluation failure as an `excs.Error` with a descriptive message.

    The message names the failing row and expression and the underlying exception. If input values
    were recorded, they are listed next to the expression's dependencies. If the traceback has more
    than two frames, the remaining frames are appended, most recent first.
    """
    pieces = [f'In row {e.row_num} the {e.expr_msg} encountered exception {type(e.exc).__name__}:\n{e.exc}']
    if len(e.input_vals) > 0:
        rendered_inputs = []
        for i, dep in enumerate(e.expr.dependencies()):
            rendered_inputs.append(f"'{dep}' = {dep.col_type.print_value(e.input_vals[i])}")
        pieces.append(f'\nwith {", ".join(rendered_inputs)}')
    assert e.exc_tb is not None
    frames = traceback.format_tb(e.exc_tb)
    if len(frames) > 2:
        # append a stack trace if the exception happened in user code
        # (frame 0 is ExprEvaluator and frame 1 is some expr's eval())
        # [-1:1:-1]: leave out entries 0 and 1 and reverse order, so that the most recent frame is at the top
        user_frames = frames[-1:1:-1]
        pieces.append('\nStack:\n' + '\n'.join(user_frames))
    raise excs.Error(''.join(pieces)) from e
|
|
523
|
+
|
|
524
|
+
def _output_row_iterator(self) -> Iterator[list]:
    """Execute the Query in a read-only transaction and yield one list of select-list values per row.

    Expression evaluation errors are re-raised via `_raise_expr_eval_err()`. SQL errors are first handed
    to `Catalog.convert_sql_exc()` and re-raised unchanged if that call returns.
    """
    # TODO: extend begin_xact() to accept multiple TVPs for joins
    single_tbl = None
    if len(self._from_clause.tbls) == 1:
        single_tbl = self._first_tbl
    with Catalog.get().begin_xact(tbl=single_tbl, for_write=False):
        try:
            for data_row in self._exec():
                # project the row onto the select list, using each expression's slot index
                yield [data_row[expr.slot_idx] for expr in self._select_list_exprs]
        except excs.ExprEvalError as eval_err:
            self._raise_expr_eval_err(eval_err)
        except (sql_exc.DBAPIError, sql_exc.OperationalError, sql_exc.InternalError) as sql_err:
            tbl_version = None if single_tbl is None else single_tbl.tbl_version
            Catalog.get().convert_sql_exc(sql_err, tbl=tbl_version)
            raise  # just re-raise if not converted to a Pixeltable error
|
|
536
|
+
|
|
537
|
+
def collect(self) -> ResultSet:
    """Run the Query to completion and return all of its rows, with the Query's schema, as a ResultSet."""
    rows = list(self._output_row_iterator())
    return ResultSet(rows, self.schema)
|
|
539
|
+
|
|
540
|
+
async def _acollect(self) -> ResultSet:
    """Asynchronously gather all rows produced by `_aexec()` into a ResultSet.

    Error handling mirrors `_output_row_iterator()`: evaluation errors go through
    `_raise_expr_eval_err()`, SQL errors through `Catalog.convert_sql_exc()` before being re-raised.
    """
    single_tbl = None
    if len(self._from_clause.tbls) == 1:
        single_tbl = self._first_tbl
    try:
        rows = []
        async for row in self._aexec():
            # project the row onto the select list, using each expression's slot index
            rows.append([row[expr.slot_idx] for expr in self._select_list_exprs])
        return ResultSet(rows, self.schema)
    except excs.ExprEvalError as eval_err:
        self._raise_expr_eval_err(eval_err)
    except (sql_exc.DBAPIError, sql_exc.OperationalError, sql_exc.InternalError) as sql_err:
        tbl_version = None if single_tbl is None else single_tbl.tbl_version
        Catalog.get().convert_sql_exc(sql_err, tbl=tbl_version)
        raise  # just re-raise if not converted to a Pixeltable error
|
|
550
|
+
|
|
551
|
+
def count(self) -> int:
    """Count the rows of this Query.

    The count is obtained by executing the statement built by `Planner.create_count_stmt()` inside a
    read-only transaction on the Query's first table.

    Returns:
        The number of rows in the Query.
    """
    with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False) as conn:
        stmt = Planner.create_count_stmt(self)
        num_rows: int = conn.execute(stmt).scalar_one()
        assert isinstance(num_rows, int)
        return num_rows
|
|
562
|
+
|
|
563
|
+
def _descriptors(self) -> DescriptionHelper:
    """Build the DescriptionHelper behind `__repr__()` and `_repr_html_()`.

    The column descriptor is always included; the clause descriptor is added only when it is non-empty.
    """
    description = DescriptionHelper()
    description.append(self._col_descriptor())
    clause_table = self._query_descriptor()
    if clause_table.empty:
        return description
    description.append(clause_table, show_index=True, show_header=False)
    return description
|
|
570
|
+
|
|
571
|
+
def _col_descriptor(self) -> pd.DataFrame:
|
|
572
|
+
return pd.DataFrame(
|
|
573
|
+
[
|
|
574
|
+
{
|
|
575
|
+
'Name': name,
|
|
576
|
+
'Type': expr.col_type._to_str(as_schema=True),
|
|
577
|
+
'Expression': expr.display_str(inline=False),
|
|
578
|
+
}
|
|
579
|
+
for name, expr in zip(self.schema.keys(), self._select_list_exprs)
|
|
580
|
+
]
|
|
581
|
+
)
|
|
582
|
+
|
|
583
|
+
def _query_descriptor(self) -> pd.DataFrame:
|
|
584
|
+
heading_vals: list[str] = []
|
|
585
|
+
info_vals: list[str] = []
|
|
586
|
+
heading_vals.append('From')
|
|
587
|
+
info_vals.extend(tbl.tbl_name() for tbl in self._from_clause.tbls)
|
|
588
|
+
if self.where_clause is not None:
|
|
589
|
+
heading_vals.append('Where')
|
|
590
|
+
info_vals.append(self.where_clause.display_str(inline=False))
|
|
591
|
+
if self.group_by_clause is not None:
|
|
592
|
+
heading_vals.append('Group By')
|
|
593
|
+
heading_vals.extend([''] * (len(self.group_by_clause) - 1))
|
|
594
|
+
info_vals.extend(e.display_str(inline=False) for e in self.group_by_clause)
|
|
595
|
+
if self.order_by_clause is not None:
|
|
596
|
+
heading_vals.append('Order By')
|
|
597
|
+
heading_vals.extend([''] * (len(self.order_by_clause) - 1))
|
|
598
|
+
info_vals.extend(
|
|
599
|
+
[f'{e[0].display_str(inline=False)} {"asc" if e[1] else "desc"}' for e in self.order_by_clause]
|
|
600
|
+
)
|
|
601
|
+
if self.limit_val is not None:
|
|
602
|
+
heading_vals.append('Limit')
|
|
603
|
+
info_vals.append(self.limit_val.display_str(inline=False))
|
|
604
|
+
if self.sample_clause is not None:
|
|
605
|
+
heading_vals.append('Sample')
|
|
606
|
+
info_vals.append(self.sample_clause.display_str(inline=False))
|
|
607
|
+
assert len(heading_vals) == len(info_vals)
|
|
608
|
+
return pd.DataFrame(info_vals, index=heading_vals)
|
|
609
|
+
|
|
610
|
+
def describe(self) -> None:
    """
    Print a tabular description of this Query.

    The description has two columns, heading and info, which list the contents of each 'component'
    (select list, where clause, ...) vertically. When the `__IPYTHON__` builtin is set, the HTML
    representation is displayed through IPython; otherwise the text representation is printed.
    """
    running_in_ipython = getattr(builtins, '__IPYTHON__', False)
    if not running_in_ipython:
        print(repr(self))
        return

    # imported lazily: IPython is only needed (and only assumed to be present) in this branch
    from IPython.display import Markdown, display

    display(Markdown(self._repr_html_()))
|
|
622
|
+
|
|
623
|
+
def __repr__(self) -> str:
|
|
624
|
+
return self._descriptors().to_string()
|
|
625
|
+
|
|
626
|
+
def _repr_html_(self) -> str:
|
|
627
|
+
return self._descriptors().to_html()
|
|
628
|
+
|
|
629
|
+
def select(self, *items: Any, **named_items: Any) -> Query:
    """Select columns or expressions from the Query.

    Args:
        items: expressions to be selected
        named_items: named expressions to be selected

    Returns:
        A new Query with the specified select list. If neither `items` nor `named_items` are given,
        this Query itself is returned. All other clauses, including a sample() clause, are carried
        over unchanged.

    Raises:
        Error: If the select list is already specified,
            or if a name is not a valid identifier,
            or if any of the specified expressions are invalid,
            or refer to tables not in the Query,
            or if the resulting column names are not unique.

    Examples:
        Given the Query person from a table t with all its columns and rows:

        >>> person = t.select()

        Select the columns 'name' and 'age' (referenced in table t) from the Query person:

        >>> query = person.select(t.name, t.age)

        Select the columns 'name' (referenced in table t) from the Query person,
        and a named column 'is_adult' from the expression `age >= 18` where 'age' is
        another column in table t:

        >>> query = person.select(t.name, is_adult=(t.age >= 18))

    """
    if self.select_list is not None:
        raise excs.Error('Select list already specified')
    for name in named_items:
        if not isinstance(name, str) or not is_valid_identifier(name):
            raise excs.Error(f'Invalid name: {name}')
    base_list = [(expr, None) for expr in items] + [(expr, k) for (k, expr) in named_items.items()]
    if len(base_list) == 0:
        return self

    # analyze select list; wrap literals with the corresponding expressions
    select_list: list[tuple[exprs.Expr, str | None]] = []
    for raw_expr, name in base_list:
        expr = exprs.Expr.from_object(raw_expr)
        if expr is None:
            raise excs.Error(f'Invalid expression: {raw_expr}')
        # a literal None has no valid type of its own but is still allowed in the select list
        if expr.col_type.is_invalid_type() and not (isinstance(expr, exprs.Literal) and expr.val is None):
            raise excs.Error(f'Invalid type: {raw_expr}')
        if len(self._from_clause.tbls) == 1:
            # Select expressions need to be retargeted in order to handle snapshots correctly, as in expressions
            # such as `snapshot.select(base_tbl.col)`
            # TODO: For joins involving snapshots, we need a more sophisticated retarget() that can handle
            #     multiple TableVersionPaths.
            expr = expr.copy()
            try:
                expr.retarget(self._from_clause.tbls[0])
            except Exception:
                # If retarget() fails, then the succeeding is_bound_by() will raise an error.
                pass
        if not expr.is_bound_by(self._from_clause.tbls):
            raise excs.Error(
                f"That expression cannot be evaluated in the context of this query's tables "
                f'({",".join(tbl.tbl_version.get().versioned_name for tbl in self._from_clause.tbls)}): {expr}'
            )
        select_list.append((expr, name))

    # check user provided names do not conflict among themselves or with auto-generated ones
    seen: set[str] = set()
    _, names = Query._normalize_select_list(self._from_clause.tbls, select_list)
    for name in names:
        if name in seen:
            repeated_names = [j for j, x in enumerate(names) if x == name]
            pretty = ', '.join(map(str, repeated_names))
            raise excs.Error(f'Repeated column name {name!r} in select() at positions: {pretty}')
        seen.add(name)

    return Query(
        from_clause=self._from_clause,
        select_list=select_list,
        where_clause=self.where_clause,
        group_by_clause=self.group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=self.order_by_clause,
        limit=self.limit_val,
        # carry the sample clause over; omitting it would silently drop a preceding sample()
        sample_clause=self.sample_clause,
    )
|
|
714
|
+
|
|
715
|
+
def where(self, pred: exprs.Expr) -> Query:
    """Restrict the Query to the rows that satisfy a predicate.

    Args:
        pred: the boolean Pixeltable expression used to filter rows

    Returns:
        A new Query whose where-clause is `pred`.

    Raises:
        Error: If a where-clause or a sample() clause is already present,
            or if the predicate is not a Pixeltable expression,
            or if it does not return a boolean value.

    Examples:
        Given the Query person from a table t with all its columns and rows:

        >>> person = t.select()

        Filter the above Query person to only include rows where the column 'age'
        (referenced in table t) is greater than 30:

        >>> query = person.where(t.age > 30)
    """
    # state checks first: only one where(), and never after sample()
    if self.where_clause is not None:
        raise excs.Error('where() clause already specified')
    if self.sample_clause is not None:
        raise excs.Error('where() cannot be used after sample()')

    # argument checks: must be an expression, and that expression must be boolean
    is_expr = isinstance(pred, exprs.Expr)
    if not is_expr:
        raise excs.Error(f'where() expects a Pixeltable expression; got: {pred}')
    if not pred.col_type.is_bool_type():
        raise excs.Error(f'where() expression needs to return `Bool`, but instead returns `{pred.col_type}`')

    filtered = Query(
        from_clause=self._from_clause,
        select_list=self.select_list,
        where_clause=pred,
        group_by_clause=self.group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=self.order_by_clause,
        limit=self.limit_val,
    )
    return filtered
|
|
756
|
+
|
|
757
|
+
def _create_join_predicate(
    self, other: catalog.TableVersionPath, on: exprs.Expr | Sequence[exprs.ColumnRef]
) -> exprs.Expr:
    """Verifies user-specified 'on' argument and converts it into a join predicate.

    Args:
        other: The table being joined to the tables already in this Query's from-clause.
        on: Either a boolean expression, a single column reference, or a non-empty sequence of
            column references.

    Returns:
        `on` itself if it is a boolean expression (other than a bare column reference); otherwise an
        equality comparison per column reference, combined with AND when there is more than one.

    Raises:
        Error: If `on` has the wrong form or type, is not bound by the joined tables, or if a
            referenced column name cannot be matched (or matches ambiguously) on the other side.
    """
    col_refs: list[exprs.ColumnRef] = []
    joined_tbls = [*self._from_clause.tbls, other]

    # ColumnRef is checked before the generic Expr case, so that a single column reference is treated
    # like a one-element sequence rather than as a boolean expression
    if isinstance(on, exprs.ColumnRef):
        on = [on]
    elif isinstance(on, exprs.Expr):
        if not on.is_bound_by(joined_tbls):
            raise excs.Error(f'`on` expression cannot be evaluated in the context of the joined tables: {on}')
        if not on.col_type.is_bool_type():
            raise excs.Error(
                f'`on` expects an expression of type `Bool`, but got one of type `{on.col_type}`: {on}'
            )
        # an explicit boolean predicate is used as is
        return on
    elif not isinstance(on, Sequence) or len(on) == 0:
        raise excs.Error('`on` must be a sequence of column references or a boolean expression')

    # validate each element of the sequence before building any predicate
    assert isinstance(on, Sequence)
    for col_ref in on:
        if not isinstance(col_ref, exprs.ColumnRef):
            raise excs.Error('`on` must be a sequence of column references or a boolean expression')
        if not col_ref.is_bound_by(joined_tbls):
            raise excs.Error(f'`on` expression cannot be evaluated in the context of the joined tables: {col_ref}')
        col_refs.append(col_ref)

    predicates: list[exprs.Expr] = []
    # try to turn ColumnRefs into equality predicates
    assert len(col_refs) > 0 and len(joined_tbls) >= 2
    for col_ref in col_refs:
        # identify the referenced column by name in 'other'
        rhs_col = other.get_column(col_ref.col.name)
        if rhs_col is None:
            raise excs.Error(f'`on` column {col_ref.col.name!r} not found in joined table')
        rhs_col_ref = exprs.ColumnRef(rhs_col)

        lhs_col_ref: exprs.ColumnRef | None = None
        if any(tbl.has_column(col_ref.col) for tbl in self._from_clause.tbls):
            # col_ref comes from the existing from_clause, we use that directly
            lhs_col_ref = col_ref
        else:
            # col_ref comes from other, we need to look for a match in the existing from_clause by name
            for tbl in self._from_clause.tbls:
                col = tbl.get_column(col_ref.col.name)
                if col is None:
                    continue
                # a second table with a column of that name makes the reference ambiguous
                if lhs_col_ref is not None:
                    raise excs.Error(f'`on`: ambiguous column reference: {col_ref.col.name}')
                lhs_col_ref = exprs.ColumnRef(col)
            if lhs_col_ref is None:
                tbl_names = [tbl.tbl_name() for tbl in self._from_clause.tbls]
                raise excs.Error(f'`on`: column {col_ref.col.name!r} not found in any of: {" ".join(tbl_names)}')
        pred = exprs.Comparison(exprs.ComparisonOperator.EQ, lhs_col_ref, rhs_col_ref)
        predicates.append(pred)

    # a single predicate is returned bare; several are AND-ed together
    assert len(predicates) > 0
    if len(predicates) == 1:
        return predicates[0]
    else:
        return exprs.CompoundPredicate(operator=exprs.LogicalOperator.AND, operands=predicates)
|
|
819
|
+
|
|
820
|
+
def join(
    self,
    other: catalog.Table,
    on: exprs.Expr | Sequence[exprs.ColumnRef] | None = None,
    how: plan.JoinType.LiteralType = 'inner',
) -> Query:
    """
    Join this Query with a table.

    Args:
        other: the table to join with
        on: the join condition, which can be either a) references to one or more columns or b) a boolean
            expression.

            - column references: implies an equality predicate that matches columns in both this
                Query and `other` by name.

                - column in `other`: A column with that same name must be present in this Query, and **it must
                    be unique** (otherwise the join is ambiguous).
                - column in this Query: A column with that same name must be present in `other`.

            - boolean expression: The expressions must be valid in the context of the joined tables.
        how: the type of join to perform.

            - `'inner'`: only keep rows that have a match in both
            - `'left'`: keep all rows from this Query and only matching rows from the other table
            - `'right'`: keep all rows from the other table and only matching rows from this Query
            - `'full_outer'`: keep all rows from both this Query and the other table
            - `'cross'`: Cartesian product; no `on` condition allowed

    Returns:
        A new Query.

    Raises:
        Error: If the Query has a sample() clause, if `on` is given for a cross join, or if `on` is
            missing for any other join type.

    Examples:
        Perform an inner join between t1 and t2 on the column id:

        >>> join1 = t1.join(t2, on=t2.id)

        Perform a left outer join of join1 with t3, also on id (note that we can't specify `on=t3.id` here,
        because that would be ambiguous, since both t1 and t2 have a column named id):

        >>> join2 = join1.join(t3, on=t2.id, how='left')

        Do the same, but now with an explicit join predicate:

        >>> join2 = join1.join(t3, on=t2.id == t3.id, how='left')

        Join t with d, which has a composite primary key (columns pk1 and pk2, with corresponding foreign
        key columns d1 and d2 in t):

        >>> query = t.join(d, on=(t.d1 == d.pk1) & (t.d2 == d.pk2), how='left')
    """
    if self.sample_clause is not None:
        raise excs.Error('join() cannot be used with sample()')

    # a cross join has no predicate; every other join type needs one derived from `on`
    join_pred: exprs.Expr | None = None
    if how == 'cross':
        if on is not None:
            raise excs.Error('`on` not allowed for cross join')
    elif on is None:
        raise excs.Error(f'`how={how!r}` requires `on` to be present')
    else:
        join_pred = self._create_join_predicate(other._tbl_version_path, on)

    join_type = plan.JoinType.validated(how, '`how`')
    new_join_clause = plan.JoinClause(join_type=join_type, join_predicate=join_pred)
    # extend copies of the table and join-clause lists; this Query's from-clause is left as is
    extended_from = plan.FromClause(
        tbls=[*self._from_clause.tbls, other._tbl_version_path],
        join_clauses=[*self._from_clause.join_clauses, new_join_clause],
    )
    return Query(
        from_clause=extended_from,
        select_list=self.select_list,
        where_clause=self.where_clause,
        group_by_clause=self.group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=self.order_by_clause,
        limit=self.limit_val,
    )
|
|
897
|
+
|
|
898
|
+
def group_by(self, *grouping_items: Any) -> Query:
    """Add a group-by clause to this Query.

    Variants:
    - group_by(base_tbl): group a component view by their respective base table rows
    - group_by(expr1, expr2, expr3): group by the given expressions

    Note that grouping will be applied to the rows and take effect when
    used with an aggregation function like sum(), count() etc.

    Args:
        grouping_items: expressions to group by

    Returns:
        A new Query with the specified group-by clause.

    Raises:
        Error: If the group-by clause is already specified,
            or if the Query has a sample() clause,
            or if the specified expression is invalid,
            or if a Table is given together with other items, for a join, or is not a base table
            of the Query's table.

    Examples:
        Given the Query book from a table t with all its columns and rows:

        >>> book = t.select()

        Group the above Query book by the 'genre' column (referenced in table t):

        >>> query = book.group_by(t.genre)

        Use the above Query grouped by genre to count the number of
        books for each 'genre':

        >>> query = book.group_by(t.genre).select(t.genre, count=count(t.genre)).show()

        Use the above Query grouped by genre to the total price of
        books for each 'genre':

        >>> query = book.group_by(t.genre).select(t.genre, total=sum(t.price)).show()
    """
    if self.group_by_clause is not None:
        raise excs.Error('group_by() already specified')
    if self.sample_clause is not None:
        raise excs.Error('group_by() cannot be used with sample()')

    grouping_tbl: catalog.TableVersion | None = None
    for item in grouping_items:
        if not isinstance(item, (catalog.Table, catalog.TableVersion)):
            # plain grouping expression: only needs to be an Expr
            if not isinstance(item, exprs.Expr):
                raise excs.Error(f'Invalid expression in group_by(): {item}')
            continue

        # table variant: must be the only item and cannot be combined with a join
        if len(grouping_items) > 1:
            raise excs.Error('group_by(): only one Table can be specified')
        if len(self._from_clause.tbls) > 1:
            raise excs.Error('group_by() with Table not supported for joins')
        if isinstance(item, catalog.TableVersion):
            grouping_tbl = item
        else:
            grouping_tbl = item._tbl_version.get()
        # we need to make sure that the grouping table is a base of self.tbl
        base = self._first_tbl.find_tbl_version(grouping_tbl.id)
        if base is None or base.id == self._first_tbl.tbl_id:
            raise excs.Error(
                f'group_by(): {grouping_tbl.name!r} is not a base table of {self._first_tbl.tbl_name()!r}'
            )
        break

    # the expression list is only used when grouping is not done by a base table
    group_by_clause: list[exprs.Expr] | None = list(grouping_items) if grouping_tbl is None else None
    return Query(
        from_clause=self._from_clause,
        select_list=self.select_list,
        where_clause=self.where_clause,
        group_by_clause=group_by_clause,
        grouping_tbl=grouping_tbl,
        order_by_clause=self.order_by_clause,
        limit=self.limit_val,
    )
|
|
973
|
+
|
|
974
|
+
def distinct(self) -> Query:
    """
    Remove duplicate rows from this Query.

    This is implemented as a group_by() over the expressions of the normalized select list, so the
    grouping is based on the select clause of this Query. In the absence of a select clause, by
    default, all columns are selected in the grouping.

    Returns:
        A new Query grouped by its select-list expressions.

    Examples:
        Select unique addresses from table `addresses`.

        >>> results = addresses.distinct()

        Select unique cities in table `addresses`

        >>> results = addresses.city.distinct()

        Select unique locations (street, city) in the state of `CA`

        >>> results = addresses.select(addresses.street, addresses.city).where(addresses.state == 'CA').distinct()
    """
    select_exprs, _ = self._normalize_select_list(self._from_clause.tbls, self.select_list)
    deduplicated = self.group_by(*select_exprs)
    return deduplicated
|
|
996
|
+
|
|
997
|
+
def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> Query:
    """Add an order-by clause to this Query.

    order_by() can be called repeatedly: the new expressions are appended after the ones of any
    existing order-by clause, each call with its own sort direction. The Query it is called on is
    left unmodified.

    Args:
        expr_list: expressions to order by
        asc: whether to order in ascending order (True) or descending order (False).
            Default is True.

    Returns:
        A new Query with the specified order-by clause.

    Raises:
        Error: If the Query has a sample() clause,
            or if one of the specified items is not a Pixeltable expression.

    Examples:
        Given the Query book from a table t with all its columns and rows:

        >>> book = t.select()

        Order the above Query book by two columns (price, pages) in descending order:

        >>> query = book.order_by(t.price, t.pages, asc=False)

        Order the above Query book by price in descending order, but order the pages
        in ascending order:

        >>> query = book.order_by(t.price, asc=False).order_by(t.pages)
    """
    if self.sample_clause is not None:
        raise excs.Error('order_by() cannot be used with sample()')
    for e in expr_list:
        if not isinstance(e, exprs.Expr):
            raise excs.Error(f'Invalid expression in order_by(): {e}')
    # Start from a copy of the existing clause: extending self.order_by_clause in place would also
    # change the ordering of the Query this method was called on.
    order_by_clause = list(self.order_by_clause) if self.order_by_clause is not None else []
    order_by_clause.extend((e.copy(), asc) for e in expr_list)
    return Query(
        from_clause=self._from_clause,
        select_list=self.select_list,
        where_clause=self.where_clause,
        group_by_clause=self.group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=order_by_clause,
        limit=self.limit_val,
    )
|
|
1043
|
+
|
|
1044
|
+
def limit(self, n: int) -> Query:
    """Cap the number of rows returned by the Query.

    Args:
        n: Number of rows to select.

    Returns:
        A new Query with the specified limited rows.

    Raises:
        Error: If the Query has a sample() clause.
    """
    if self.sample_clause is not None:
        raise excs.Error('limit() cannot be used with sample()')

    # the limit is stored as a typed (non-nullable Int) expression rather than as a raw int
    int_type = ts.IntType(nullable=False)
    limit_expr = self._convert_param_to_typed_expr(n, int_type, True, 'limit()')
    limited = Query(
        from_clause=self._from_clause,
        select_list=self.select_list,
        where_clause=self.where_clause,
        group_by_clause=self.group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=self.order_by_clause,
        limit=limit_expr,
    )
    return limited
|
|
1066
|
+
|
|
1067
|
+
def sample(
    self,
    n: int | None = None,
    n_per_stratum: int | None = None,
    fraction: float | None = None,
    seed: int | None = None,
    stratify_by: Any = None,
) -> Query:
    """
    Return a new Query specifying a sample of rows from the Query, considered in a shuffled order.

    The size of the sample can be specified in three ways:

    - `n`: the total number of rows to produce as a sample
    - `n_per_stratum`: the number of rows to produce per stratum as a sample
    - `fraction`: the fraction of available rows to produce as a sample

    The sample can be stratified by one or more columns, which means that the sample will
    be selected from each stratum separately.

    The data is shuffled before creating the sample.

    Args:
        n: Total number of rows to produce as a sample.
        n_per_stratum: Number of rows to produce per stratum as a sample. This parameter is only valid if
            `stratify_by` is specified. Only one of `n` or `n_per_stratum` can be specified.
        fraction: Fraction of available rows to produce as a sample. This parameter is not usable with `n` or
            `n_per_stratum`. The fraction must be between 0.0 and 1.0.
        seed: Random seed for reproducible shuffling
        stratify_by: If specified, the sample will be stratified by these values.

    Returns:
        A new Query which specifies the sampled rows

    Raises:
        Error: If the Query already has a sample(), group_by(), order_by() or limit() clause or is a join,
            or if the size parameters are not given exactly once,
            or if `n_per_stratum` is used without `stratify_by`,
            or if a `stratify_by` entry is not a scalar expression bound by this Query's tables.

    Examples:
        Given the Table `person` containing the field 'age', we can create samples of the table in various ways:

        Sample 100 rows from the above Table:

        >>> query = person.sample(n=100)

        Sample 10% of the rows from the above Table:

        >>> query = person.sample(fraction=0.1)

        Sample 10% of the rows from the above Table, stratified by the column 'age':

        >>> query = person.sample(fraction=0.1, stratify_by=t.age)

        Equal allocation sampling: Sample 2 rows from each age present in the above Table:

        >>> query = person.sample(n_per_stratum=2, stratify_by=t.age)

        Sampling is compatible with the where clause, so we can also sample from a filtered Query:

        >>> query = person.where(t.age > 30).sample(n=100)
    """
    # Check context of usage
    if self.sample_clause is not None:
        raise excs.Error('Multiple sample() clauses not allowed')
    if self.group_by_clause is not None:
        raise excs.Error('sample() cannot be used with group_by()')
    if self.order_by_clause is not None:
        raise excs.Error('sample() cannot be used with order_by()')
    if self.limit_val is not None:
        raise excs.Error('sample() cannot be used with limit()')
    if self._has_joins():
        raise excs.Error('sample() cannot be used with join()')

    # Check parameter combinations
    num_size_params = sum(param is not None for param in (n, n_per_stratum, fraction))
    if num_size_params != 1:
        raise excs.Error('Exactly one of `n`, `n_per_stratum`, or `fraction` must be specified.')
    if n_per_stratum is not None and stratify_by is None:
        raise excs.Error('Must specify `stratify_by` to use `n_per_stratum`')

    # Check parameter types and values
    n = self.validate_constant_type_range(n, ts.IntType(nullable=False), False, 'n', (1, None))
    n_per_stratum = self.validate_constant_type_range(
        n_per_stratum, ts.IntType(nullable=False), False, 'n_per_stratum', (1, None)
    )
    fraction = self.validate_constant_type_range(
        fraction, ts.FloatType(nullable=False), False, 'fraction', (0.0, 1.0)
    )
    seed = self.validate_constant_type_range(seed, ts.IntType(nullable=False), False, 'seed')

    # analyze stratify list
    stratify_exprs: list[exprs.Expr] = []
    if stratify_by is not None:
        # a single expression is shorthand for a one-element list
        candidates = [stratify_by] if isinstance(stratify_by, exprs.Expr) else stratify_by
        if not isinstance(candidates, (list, tuple)):
            raise excs.Error('`stratify_by` must be a list of scalar expressions')
        for expr in candidates:
            if not isinstance(expr, exprs.Expr):
                raise excs.Error(f'Invalid expression: {expr}')
            if not expr.col_type.is_scalar_type():
                raise excs.Error(f'Invalid type: expression must be a scalar type (not `{expr.col_type}`)')
            if not expr.is_bound_by(self._from_clause.tbls):
                raise excs.Error(
                    f"That expression cannot be evaluated in the context of this query's tables "
                    f'({",".join(tbl.tbl_name() for tbl in self._from_clause.tbls)}): {expr}'
                )
            stratify_exprs.append(expr)

    sample_clause = SampleClause(None, n, n_per_stratum, fraction, seed, stratify_exprs)

    return Query(
        from_clause=self._from_clause,
        select_list=self.select_list,
        where_clause=self.where_clause,
        group_by_clause=self.group_by_clause,
        grouping_tbl=self.grouping_tbl,
        order_by_clause=self.order_by_clause,
        limit=self.limit_val,
        sample_clause=sample_clause,
    )
|
|
1183
|
+
|
|
1184
|
+
def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
    """Update the rows of the Query's underlying table that match this Query.

    The new values in `value_spec` are applied to the table; the Query's `where` clause (if any) is
    passed along as the row filter of the update.

    Args:
        value_spec: a dict that maps the names of the columns to update to their new values.
        cascade: if True, also update all computed columns that transitively depend
            on the updated columns, including within views. Default is True.

    Returns:
        UpdateStatus: the status of the update operation.

    Example:
        Given the Query person from a table t with all its columns and rows:

        >>> person = t.select()

        Via the above Query person, update the column 'city' to 'Oakland'
        and 'state' to 'CA' in the table t:

        >>> person.update({'city': 'Oakland', 'state': 'CA'})

        Via the above Query person, update the column 'age' to 30 for any
        rows where 'year' is 2014 in the table t:

        >>> person.where(t.year == 2014).update({'age': 30})
    """
    # reject Queries that cannot be used for a mutation (select() is not allowed here)
    self._validate_mutable('update', False)

    xact = Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True)
    with xact:
        tbl_version = self._first_tbl.tbl_version.get()
        return tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
|
|
1215
|
+
|
|
1216
|
+
def recompute_columns(
    self, *columns: str | exprs.ColumnRef, errors_only: bool = False, cascade: bool = True
) -> UpdateStatus:
    """Recompute computed columns of the Query's underlying table for the rows matching this Query.

    The Query's `where` clause (if any) is passed along as the row filter of the recomputation.

    Args:
        columns: The names or references of the computed columns to recompute.
        errors_only: If True, only run the recomputation for rows that have errors in the column (ie, the column's
            `errortype` property indicates that an error occurred). Only allowed for recomputing a single column.
        cascade: if True, also update all computed columns that transitively depend on the recomputed columns.

    Returns:
        UpdateStatus: the status of the operation.

    Example:
        For table `person` with column `age` and computed column `height`, recompute the value of `height` for all
        rows where `age` is less than 18:

        >>> status = person.where(person.age < 18).recompute_columns(person.height)
    """
    # reject Queries that cannot be used for a mutation (select() is not allowed here)
    self._validate_mutable('recompute_columns', False)

    xact = Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True)
    with xact:
        # the recomputation is delegated to the Table object, which is looked up by id inside the transaction
        target_tbl = Catalog.get().get_table_by_id(self._first_tbl.tbl_id)
        return target_tbl.recompute_columns(
            *columns, where=self.where_clause, errors_only=errors_only, cascade=cascade
        )
|
|
1240
|
+
|
|
1241
|
+
def delete(self) -> UpdateStatus:
    """Delete rows from the underlying table of the Query.

    The Query's `where` clause (if any) selects the rows to delete. The delete operation is only allowed
    for Queries on base tables; it raises an error for Queries on views.

    Returns:
        UpdateStatus: the status of the delete operation.

    Example:
        For a table `person` with column `age`, delete all rows where 'age' is less than 18:

        >>> person.where(person.age < 18).delete()
    """
    # reject Queries that cannot be used for a mutation (select() is not allowed here)
    self._validate_mutable('delete', False)

    on_base_table = self._first_tbl.is_insertable()
    if not on_base_table:
        raise excs.Error('Cannot use `delete` on a view.')

    xact = Catalog.get().begin_xact(tbl=self._first_tbl, for_write=True, lock_mutable_tree=True)
    with xact:
        tbl_version = self._first_tbl.tbl_version.get()
        return tbl_version.delete(where=self.where_clause)
|
|
1259
|
+
|
|
1260
|
+
def _validate_mutable(self, op_name: str, allow_select: bool) -> None:
    """Check that a mutation (such as an update operation) may be run through this Query.

    First validates the sequence of operations applied to the Query, then rejects Queries whose
    table is a replica or a snapshot. Returns normally if the mutation is allowed.

    Args:
        op_name: The name of the operation for which the test is being performed; used in error messages.
        allow_select: If True, allow a select() specification in the Query.
    """
    self._validate_mutable_op_sequence(op_name, allow_select)

    # TODO: Reconcile these with Table.__check_mutable()
    assert len(self._from_clause.tbls) == 1
    target = self._first_tbl

    # The replica test has to come first: every replica handle is also a snapshot, so the
    # snapshot test alone would produce the less specific error message.
    if target.is_replica():
        raise excs.Error(f'Cannot use `{op_name}` on a replica.')
    if target.is_snapshot():
        raise excs.Error(f'Cannot use `{op_name}` on a snapshot.')
|
|
1276
|
+
|
|
1277
|
+
def _validate_mutable_op_sequence(self, op_name: str, allow_select: bool) -> None:
    """Check that the operations already applied to this Query are compatible with a mutation.

    Raises an error naming the first offending clause, checked in this order: `group_by`, `order_by`,
    `select` (unless `allow_select` is True), `limit`, `join`.

    Args:
        op_name: The name of the mutation operation; used in error messages.
        allow_select: If True, a select() specification in the Query is tolerated.
    """
    is_grouped = not (self.group_by_clause is None and self.grouping_tbl is None)
    if is_grouped:
        raise excs.Error(f'Cannot use `{op_name}` after `group_by`.')

    is_ordered = self.order_by_clause is not None
    if is_ordered:
        raise excs.Error(f'Cannot use `{op_name}` after `order_by`.')

    has_disallowed_select = not allow_select and self.select_list is not None
    if has_disallowed_select:
        raise excs.Error(f'Cannot use `{op_name}` after `select`.')

    is_limited = self.limit_val is not None
    if is_limited:
        raise excs.Error(f'Cannot use `{op_name}` after `limit`.')

    if self._has_joins():
        raise excs.Error(f'Cannot use `{op_name}` after `join`.')
|
|
1289
|
+
|
|
1290
|
+
def as_dict(self) -> dict[str, Any]:
    """Serialize this Query into a plain dictionary.

    Every clause is converted via its own `as_dict()` (join clauses via `dataclasses.asdict()`);
    clauses that are not set are stored as None.

    Returns:
        Dictionary representing this Query.
    """
    from_clause = self._from_clause
    from_clause_dict = {
        'tbls': [tbl.as_dict() for tbl in from_clause.tbls],
        'join_clauses': [dataclasses.asdict(clause) for clause in from_clause.join_clauses],
    }

    select_list = None
    if self.select_list is not None:
        select_list = [(e.as_dict(), name) for e, name in self.select_list]

    where_clause = None
    if self.where_clause is not None:
        where_clause = self.where_clause.as_dict()

    group_by_clause = None
    if self.group_by_clause is not None:
        group_by_clause = [e.as_dict() for e in self.group_by_clause]

    grouping_tbl = None
    if self.grouping_tbl is not None:
        grouping_tbl = self.grouping_tbl.as_dict()

    order_by_clause = None
    if self.order_by_clause is not None:
        order_by_clause = [(e.as_dict(), asc) for e, asc in self.order_by_clause]

    limit_val = None
    if self.limit_val is not None:
        limit_val = self.limit_val.as_dict()

    sample_clause = None
    if self.sample_clause is not None:
        sample_clause = self.sample_clause.as_dict()

    return {
        '_classname': 'Query',
        'from_clause': from_clause_dict,
        'select_list': select_list,
        'where_clause': where_clause,
        'group_by_clause': group_by_clause,
        'grouping_tbl': grouping_tbl,
        'order_by_clause': order_by_clause,
        'limit_val': limit_val,
        'sample_clause': sample_clause,
    }
|
|
1316
|
+
|
|
1317
|
+
@classmethod
def from_dict(cls, d: dict[str, Any]) -> 'Query':
    """Reconstruct a Query from a dictionary with the layout produced by `as_dict()`.

    Entries that are None in `d` are passed on as None; all other entries are converted back
    via the `from_dict()` of the corresponding class.

    Args:
        d: dictionary with the keys 'from_clause' (containing 'tbls' and 'join_clauses'), 'select_list',
            'where_clause', 'group_by_clause', 'grouping_tbl', 'order_by_clause', 'limit_val' and
            'sample_clause'.

    Returns:
        The reconstructed Query.
    """
    # we need to wrap the construction with a transaction, because it might need to load metadata
    with Catalog.get().begin_xact(for_write=False):
        from_clause_dict = d['from_clause']
        tbls = [catalog.TableVersionPath.from_dict(tbl_dict) for tbl_dict in from_clause_dict['tbls']]
        join_clauses = [plan.JoinClause(**clause_dict) for clause_dict in from_clause_dict['join_clauses']]
        from_clause = plan.FromClause(tbls=tbls, join_clauses=join_clauses)

        select_list = None
        if d['select_list'] is not None:
            select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']]

        where_clause = None
        if d['where_clause'] is not None:
            where_clause = exprs.Expr.from_dict(d['where_clause'])

        group_by_clause = None
        if d['group_by_clause'] is not None:
            group_by_clause = [exprs.Expr.from_dict(e) for e in d['group_by_clause']]

        grouping_tbl = None
        if d['grouping_tbl'] is not None:
            grouping_tbl = catalog.TableVersion.from_dict(d['grouping_tbl'])

        order_by_clause = None
        if d['order_by_clause'] is not None:
            order_by_clause = [(exprs.Expr.from_dict(e), asc) for e, asc in d['order_by_clause']]

        limit_val = None
        if d['limit_val'] is not None:
            limit_val = exprs.Expr.from_dict(d['limit_val'])

        sample_clause = None
        if d['sample_clause'] is not None:
            sample_clause = SampleClause.from_dict(d['sample_clause'])

        return Query(
            from_clause=from_clause,
            select_list=select_list,
            where_clause=where_clause,
            group_by_clause=group_by_clause,
            grouping_tbl=grouping_tbl,
            order_by_clause=order_by_clause,
            limit=limit_val,
            sample_clause=sample_clause,
        )
|
|
1352
|
+
|
|
1353
|
+
def _hash_result_set(self) -> str:
    """Return a hash that changes when the result set changes.

    The hash is the SHA-256 hex digest of the JSON form of `as_dict()`, extended with the version
    numbers of all referenced table versions.
    """
    summary = self.as_dict()

    # Record the referenced table versions (the actual versions, not the effective ones), so that
    # a change to any of the referenced tables yields a different hash and invalidates the cache.
    versions = []
    for tbl in self._from_clause.tbls:
        for tbl_version in tbl.get_tbl_versions():
            versions.append(tbl_version.get().version)
    summary['tbl_versions'] = versions

    encoded_summary = json.dumps(summary).encode()
    return hashlib.sha256(encoded_summary).hexdigest()
|
|
1363
|
+
|
|
1364
|
+
def to_coco_dataset(self) -> Path:
    """Convert the Query to a COCO dataset.

    The dataset is stored in the dataset cache directory under a name derived from the hash of the
    result set; if that directory already exists, the cached data file is returned without
    re-running the Query.

    This Query must return a single json-typed output column in the following format:

    ```python
    {
        'image': PIL.Image.Image,
        'annotations': [
            {
                'bbox': [x: int, y: int, w: int, h: int],
                'category': str | int,
            },
            ...
        ],
    }
    ```

    Returns:
        Path to the COCO dataset file.
    """
    from pixeltable.utils.coco import write_coco_dataset

    cache_key = self._hash_result_set()
    dest_path = Env.get().dataset_cache_dir / f'coco_{cache_key}'

    if not dest_path.exists():
        # cache miss: run the Query and write the dataset
        # TODO: extend begin_xact() to accept multiple TVPs for joins
        with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
            return write_coco_dataset(self, dest_path)
    else:
        # cache hit: the directory is expected to contain a regular file 'data.json'
        assert dest_path.is_dir()
        data_file_path = dest_path / 'data.json'
        assert data_file_path.exists()
        assert data_file_path.is_file()
        return data_file_path
|
|
1398
|
+
|
|
1399
|
+
def to_pytorch_dataset(self, image_format: str = 'pt') -> 'torch.utils.data.IterableDataset':
    """
    Convert the Query to a pytorch IterableDataset suitable for parallel loading
    with torch.utils.data.DataLoader.

    This method requires pyarrow >= 13, torch and torchvision to work.

    This method serializes data so it can be read from disk efficiently and repeatedly without
    re-executing the query. This data is cached to disk for future re-use.

    Args:
        image_format: format of the images. Can be 'pt' (pytorch tensor) or 'np' (numpy array).
            'np' means image columns return as an RGB uint8 array of shape HxWxC.
            'pt' means image columns return as a CxHxW tensor with values in [0,1] and type torch.float32.
            (the format output by torchvision.transforms.ToTensor())

    Returns:
        A pytorch IterableDataset: Columns become fields of the dataset, where rows are returned as a dictionary
        compatible with torch.utils.data.DataLoader default collation.

    Raises:
        excs.Error: if `image_format` is neither 'pt' nor 'np'.

    Constraints:
        The default collate_fn for torch.utils.data.DataLoader cannot represent null values as part of a
        pytorch tensor when forming batches. These values will raise an exception while running the dataloader.

        If you have them, you can work around None values by providing your custom collate_fn to the DataLoader
        (and have your model handle it). Or, if these are not meaningful values within a minibatch, you can
        modify or remove any such values through selections and filters prior to calling to_pytorch_dataset().
    """
    # fail fast: reject an unsupported format before running the (potentially expensive) export below
    if image_format not in ('pt', 'np'):
        raise excs.Error(f"`image_format` must be 'pt' or 'np' (got {image_format!r})")

    # check dependencies
    Env.get().require_package('pyarrow', [13])
    Env.get().require_package('torch')
    Env.get().require_package('torchvision')

    from pixeltable.io import export_parquet
    from pixeltable.utils.pytorch import PixeltablePytorchDataset

    # the cache location is derived from the hash of the result set
    cache_key = self._hash_result_set()

    dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet')
    if dest_path.exists():  # fast path: use cache
        assert dest_path.is_dir()
    else:
        with Catalog.get().begin_xact(tbl=self._first_tbl, for_write=False):
            export_parquet(self, dest_path, inline_images=True)

    return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
|