pixeltable 0.2.24__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +2 -2
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +1 -1
- pixeltable/catalog/dir.py +6 -0
- pixeltable/catalog/globals.py +25 -0
- pixeltable/catalog/named_function.py +4 -0
- pixeltable/catalog/path_dict.py +37 -11
- pixeltable/catalog/schema_object.py +6 -0
- pixeltable/catalog/table.py +531 -251
- pixeltable/catalog/table_version.py +22 -8
- pixeltable/catalog/view.py +8 -7
- pixeltable/dataframe.py +439 -105
- pixeltable/env.py +19 -5
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/exec_node.py +6 -7
- pixeltable/exec/expr_eval_node.py +1 -1
- pixeltable/exec/sql_node.py +92 -45
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +1 -1
- pixeltable/exprs/array_slice.py +1 -1
- pixeltable/exprs/column_property_ref.py +1 -1
- pixeltable/exprs/column_ref.py +29 -2
- pixeltable/exprs/comparison.py +1 -1
- pixeltable/exprs/compound_predicate.py +1 -1
- pixeltable/exprs/expr.py +12 -5
- pixeltable/exprs/expr_set.py +8 -0
- pixeltable/exprs/function_call.py +147 -39
- pixeltable/exprs/in_predicate.py +1 -1
- pixeltable/exprs/inline_expr.py +25 -5
- pixeltable/exprs/is_null.py +1 -1
- pixeltable/exprs/json_mapper.py +1 -1
- pixeltable/exprs/json_path.py +1 -1
- pixeltable/exprs/method_ref.py +1 -1
- pixeltable/exprs/row_builder.py +1 -1
- pixeltable/exprs/rowid_ref.py +1 -1
- pixeltable/exprs/similarity_expr.py +17 -7
- pixeltable/exprs/sql_element_cache.py +4 -0
- pixeltable/exprs/type_cast.py +2 -2
- pixeltable/exprs/variable.py +3 -0
- pixeltable/func/__init__.py +5 -4
- pixeltable/func/aggregate_function.py +151 -68
- pixeltable/func/callable_function.py +48 -16
- pixeltable/func/expr_template_function.py +64 -23
- pixeltable/func/function.py +227 -23
- pixeltable/func/function_registry.py +2 -1
- pixeltable/func/query_template_function.py +51 -9
- pixeltable/func/signature.py +65 -7
- pixeltable/func/tools.py +153 -0
- pixeltable/func/udf.py +57 -35
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/anthropic.py +51 -4
- pixeltable/functions/gemini.py +85 -0
- pixeltable/functions/globals.py +54 -34
- pixeltable/functions/huggingface.py +10 -28
- pixeltable/functions/json.py +3 -8
- pixeltable/functions/math.py +67 -0
- pixeltable/functions/mistralai.py +0 -2
- pixeltable/functions/ollama.py +8 -8
- pixeltable/functions/openai.py +51 -4
- pixeltable/functions/timestamp.py +1 -1
- pixeltable/functions/video.py +3 -9
- pixeltable/functions/vision.py +1 -1
- pixeltable/globals.py +374 -89
- pixeltable/index/embedding_index.py +106 -29
- pixeltable/io/__init__.py +1 -1
- pixeltable/io/label_studio.py +1 -1
- pixeltable/io/parquet.py +39 -19
- pixeltable/iterators/__init__.py +1 -0
- pixeltable/iterators/document.py +12 -0
- pixeltable/iterators/image.py +100 -0
- pixeltable/iterators/video.py +7 -8
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_16.py +2 -1
- pixeltable/metadata/converters/convert_17.py +2 -1
- pixeltable/metadata/converters/convert_22.py +17 -0
- pixeltable/metadata/converters/convert_23.py +35 -0
- pixeltable/metadata/converters/convert_24.py +56 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/util.py +4 -2
- pixeltable/metadata/notes.py +4 -0
- pixeltable/metadata/schema.py +1 -0
- pixeltable/plan.py +129 -51
- pixeltable/store.py +1 -1
- pixeltable/type_system.py +196 -54
- pixeltable/utils/arrow.py +8 -3
- pixeltable/utils/description_helper.py +89 -0
- pixeltable/utils/documents.py +14 -0
- {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/METADATA +32 -22
- pixeltable-0.3.0.dist-info/RECORD +155 -0
- {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
- pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
- pixeltable/tool/create_test_db_dump.py +0 -308
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable-0.2.24.dist-info/RECORD +0 -153
- pixeltable-0.2.24.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.24.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
pixeltable/env.py
CHANGED
|
@@ -8,6 +8,7 @@ import importlib.util
|
|
|
8
8
|
import inspect
|
|
9
9
|
import logging
|
|
10
10
|
import os
|
|
11
|
+
import platform
|
|
11
12
|
import shutil
|
|
12
13
|
import subprocess
|
|
13
14
|
import sys
|
|
@@ -275,6 +276,7 @@ class Env:
|
|
|
275
276
|
if self._config.get_bool_value('hide_warnings'):
|
|
276
277
|
# Disable more warnings
|
|
277
278
|
warnings.simplefilter('ignore', category=UserWarning)
|
|
279
|
+
warnings.simplefilter('ignore', category=FutureWarning)
|
|
278
280
|
|
|
279
281
|
# configure _logger to log to a file
|
|
280
282
|
self._logfilename = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + '.log'
|
|
@@ -311,8 +313,12 @@ class Env:
|
|
|
311
313
|
self._db_name = os.environ.get('PIXELTABLE_DB', 'pixeltable')
|
|
312
314
|
self._pgdata_dir = Path(os.environ.get('PIXELTABLE_PGDATA', str(self._home / 'pgdata')))
|
|
313
315
|
|
|
314
|
-
#
|
|
315
|
-
|
|
316
|
+
# cleanup_mode=None will leave the postgres process running after Python exits
|
|
317
|
+
# cleanup_mode='stop' will terminate the postgres process when Python exits
|
|
318
|
+
# On Windows, we need cleanup_mode='stop' because child processes are killed automatically when the parent
|
|
319
|
+
# process (such as Terminal or VSCode) exits, potentially leaving it in an unusable state.
|
|
320
|
+
cleanup_mode = 'stop' if platform.system() == 'Windows' else None
|
|
321
|
+
self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=cleanup_mode)
|
|
316
322
|
self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')
|
|
317
323
|
|
|
318
324
|
tz_name = self.config.get_string_value('time_zone')
|
|
@@ -357,7 +363,7 @@ class Env:
|
|
|
357
363
|
self.db_url,
|
|
358
364
|
echo=echo,
|
|
359
365
|
future=True,
|
|
360
|
-
isolation_level='
|
|
366
|
+
isolation_level='REPEATABLE READ',
|
|
361
367
|
connect_args=connect_args,
|
|
362
368
|
)
|
|
363
369
|
self._logger.info(f'Created SQLAlchemy engine at: {self.db_url}')
|
|
@@ -496,6 +502,7 @@ class Env:
|
|
|
496
502
|
self.__register_package('datasets')
|
|
497
503
|
self.__register_package('fiftyone')
|
|
498
504
|
self.__register_package('fireworks', library_name='fireworks-ai')
|
|
505
|
+
self.__register_package('google.generativeai', library_name='google-generativeai')
|
|
499
506
|
self.__register_package('huggingface_hub', library_name='huggingface-hub')
|
|
500
507
|
self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
|
|
501
508
|
self.__register_package('llama_cpp', library_name='llama-cpp-python')
|
|
@@ -505,6 +512,7 @@ class Env:
|
|
|
505
512
|
self.__register_package('openai')
|
|
506
513
|
self.__register_package('openpyxl')
|
|
507
514
|
self.__register_package('pyarrow')
|
|
515
|
+
self.__register_package('pydantic')
|
|
508
516
|
self.__register_package('replicate')
|
|
509
517
|
self.__register_package('sentencepiece')
|
|
510
518
|
self.__register_package('sentence_transformers', library_name='sentence-transformers')
|
|
@@ -520,8 +528,14 @@ class Env:
|
|
|
520
528
|
self.__register_package('yolox', library_name='git+https://github.com/Megvii-BaseDetection/YOLOX@ac58e0a')
|
|
521
529
|
|
|
522
530
|
def __register_package(self, package_name: str, library_name: Optional[str] = None) -> None:
|
|
531
|
+
is_installed: bool
|
|
532
|
+
try:
|
|
533
|
+
is_installed = importlib.util.find_spec(package_name) is not None
|
|
534
|
+
except ModuleNotFoundError:
|
|
535
|
+
# This can happen if the parent of `package_name` is not installed.
|
|
536
|
+
is_installed = False
|
|
523
537
|
self.__optional_packages[package_name] = PackageInfo(
|
|
524
|
-
is_installed=
|
|
538
|
+
is_installed=is_installed,
|
|
525
539
|
library_name=library_name or package_name # defaults to package_name unless specified otherwise
|
|
526
540
|
)
|
|
527
541
|
|
|
@@ -577,7 +591,7 @@ class Env:
|
|
|
577
591
|
self._logger.info(f'Ensuring spaCy model is installed: {filename}')
|
|
578
592
|
ret = subprocess.run([sys.executable, '-m', 'pip', 'install', '-qU', url], check=False)
|
|
579
593
|
if ret.returncode != 0:
|
|
580
|
-
self._logger.
|
|
594
|
+
self._logger.warning(f'pip install failed for spaCy model: {filename}')
|
|
581
595
|
try:
|
|
582
596
|
self._logger.info(f'Loading spaCy model: {spacy_model}')
|
|
583
597
|
self._spacy_nlp = spacy.load(spacy_model)
|
pixeltable/exec/__init__.py
CHANGED
|
@@ -7,4 +7,4 @@ from .exec_node import ExecNode
|
|
|
7
7
|
from .expr_eval_node import ExprEvalNode
|
|
8
8
|
from .in_memory_data_node import InMemoryDataNode
|
|
9
9
|
from .row_update_node import RowUpdateNode
|
|
10
|
-
from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
|
|
10
|
+
from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode, SqlJoinNode
|
pixeltable/exec/exec_node.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
-
from typing import TYPE_CHECKING, Iterable, Iterator, Optional
|
|
4
|
+
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar
|
|
5
5
|
|
|
6
6
|
import pixeltable.exprs as exprs
|
|
7
7
|
|
|
8
8
|
from .data_row_batch import DataRowBatch
|
|
9
9
|
from .exec_context import ExecContext
|
|
10
10
|
|
|
11
|
-
if TYPE_CHECKING:
|
|
12
|
-
from pixeltable import exec
|
|
13
11
|
|
|
14
12
|
class ExecNode(abc.ABC):
|
|
15
13
|
"""Base class of all execution nodes"""
|
|
@@ -77,12 +75,13 @@ class ExecNode(abc.ABC):
|
|
|
77
75
|
def _close(self) -> None:
|
|
78
76
|
pass
|
|
79
77
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
78
|
+
T = TypeVar('T', bound='ExecNode')
|
|
79
|
+
|
|
80
|
+
def get_node(self, node_class: type[T]) -> Optional[T]:
|
|
81
|
+
if isinstance(self, node_class):
|
|
83
82
|
return self
|
|
84
83
|
if self.input is not None:
|
|
85
|
-
return self.input.
|
|
84
|
+
return self.input.get_node(node_class)
|
|
86
85
|
return None
|
|
87
86
|
|
|
88
87
|
def set_limit(self, limit: int) -> None:
|
|
@@ -208,7 +208,7 @@ class ExprEvalNode(ExecNode):
|
|
|
208
208
|
}
|
|
209
209
|
start_ts = time.perf_counter()
|
|
210
210
|
assert isinstance(fn_call.fn, CallableFunction)
|
|
211
|
-
result_batch = fn_call.fn.exec_batch(
|
|
211
|
+
result_batch = fn_call.fn.exec_batch(call_args, call_kwargs)
|
|
212
212
|
self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
|
|
213
213
|
self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
|
|
214
214
|
|
pixeltable/exec/sql_node.py
CHANGED
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import warnings
|
|
3
3
|
from decimal import Decimal
|
|
4
|
-
from typing import Iterable, Iterator, NamedTuple, Optional
|
|
4
|
+
from typing import Iterable, Iterator, NamedTuple, Optional, TYPE_CHECKING, Sequence
|
|
5
5
|
from uuid import UUID
|
|
6
6
|
|
|
7
7
|
import sqlalchemy as sql
|
|
8
8
|
|
|
9
9
|
import pixeltable.catalog as catalog
|
|
10
10
|
import pixeltable.exprs as exprs
|
|
11
|
-
|
|
12
11
|
from .data_row_batch import DataRowBatch
|
|
13
12
|
from .exec_node import ExecNode
|
|
14
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import pixeltable.plan
|
|
16
|
+
|
|
15
17
|
_logger = logging.getLogger('pixeltable')
|
|
16
18
|
|
|
17
19
|
|
|
@@ -67,12 +69,17 @@ class SqlNode(ExecNode):
|
|
|
67
69
|
select_list: exprs.ExprSet
|
|
68
70
|
set_pk: bool
|
|
69
71
|
num_pk_cols: int
|
|
70
|
-
|
|
71
|
-
|
|
72
|
+
py_filter: Optional[exprs.Expr] # a predicate that can only be run in Python
|
|
73
|
+
py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
|
|
72
74
|
cte: Optional[sql.CTE]
|
|
73
75
|
sql_elements: exprs.SqlElementCache
|
|
74
|
-
|
|
76
|
+
|
|
77
|
+
# where_clause/-_element: allow subclass to set one or the other (but not both)
|
|
78
|
+
where_clause: Optional[exprs.Expr]
|
|
79
|
+
where_clause_element: Optional[sql.ColumnElement]
|
|
80
|
+
|
|
75
81
|
order_by_clause: OrderByClause
|
|
82
|
+
limit: Optional[int]
|
|
76
83
|
|
|
77
84
|
def __init__(
|
|
78
85
|
self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
|
|
@@ -89,6 +96,7 @@ class SqlNode(ExecNode):
|
|
|
89
96
|
# create Select stmt
|
|
90
97
|
self.sql_elements = sql_elements
|
|
91
98
|
self.tbl = tbl
|
|
99
|
+
assert all(not isinstance(e, exprs.Literal) for e in select_list) # we're never asked to materialize literals
|
|
92
100
|
self.select_list = exprs.ExprSet(select_list)
|
|
93
101
|
# unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
|
|
94
102
|
for iter_arg in row_builder.unstored_iter_args.values():
|
|
@@ -112,10 +120,12 @@ class SqlNode(ExecNode):
|
|
|
112
120
|
# additional state
|
|
113
121
|
self.result_cursor = None
|
|
114
122
|
# the filter is provided by the subclass
|
|
115
|
-
self.
|
|
116
|
-
self.
|
|
123
|
+
self.py_filter = None
|
|
124
|
+
self.py_filter_eval_ctx = None
|
|
117
125
|
self.cte = None
|
|
118
126
|
self.limit = None
|
|
127
|
+
self.where_clause = None
|
|
128
|
+
self.where_clause_element = None
|
|
119
129
|
self.order_by_clause = []
|
|
120
130
|
|
|
121
131
|
def _create_stmt(self) -> sql.Select:
|
|
@@ -124,9 +134,16 @@ class SqlNode(ExecNode):
|
|
|
124
134
|
assert self.sql_elements.contains_all(self.select_list)
|
|
125
135
|
sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
|
|
126
136
|
if self.set_pk:
|
|
137
|
+
assert self.tbl is not None
|
|
127
138
|
sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
|
|
128
139
|
stmt = sql.select(*sql_select_list)
|
|
129
140
|
|
|
141
|
+
where_clause_element = (
|
|
142
|
+
self.sql_elements.get(self.where_clause) if self.where_clause is not None else self.where_clause_element
|
|
143
|
+
)
|
|
144
|
+
if where_clause_element is not None:
|
|
145
|
+
stmt = stmt.where(where_clause_element)
|
|
146
|
+
|
|
130
147
|
order_by_clause: list[sql.ColumnElement] = []
|
|
131
148
|
for e, asc in self.order_by_clause:
|
|
132
149
|
if isinstance(e, exprs.SimilarityExpr):
|
|
@@ -135,7 +152,7 @@ class SqlNode(ExecNode):
|
|
|
135
152
|
order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
|
|
136
153
|
stmt = stmt.order_by(*order_by_clause)
|
|
137
154
|
|
|
138
|
-
if self.
|
|
155
|
+
if self.py_filter is None and self.limit is not None:
|
|
139
156
|
# if we don't have a Python filter, we can apply the limit to stmt
|
|
140
157
|
stmt = stmt.limit(self.limit)
|
|
141
158
|
|
|
@@ -151,7 +168,7 @@ class SqlNode(ExecNode):
|
|
|
151
168
|
Returns:
|
|
152
169
|
(CTE, dict from Expr to output column)
|
|
153
170
|
"""
|
|
154
|
-
if self.
|
|
171
|
+
if self.py_filter is not None:
|
|
155
172
|
# the filter needs to run in Python
|
|
156
173
|
return None
|
|
157
174
|
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
@@ -215,8 +232,17 @@ class SqlNode(ExecNode):
|
|
|
215
232
|
prev_tbl = tbl
|
|
216
233
|
return stmt
|
|
217
234
|
|
|
218
|
-
def
|
|
219
|
-
|
|
235
|
+
def set_where(self, where_clause: exprs.Expr) -> None:
|
|
236
|
+
assert self.where_clause_element is None
|
|
237
|
+
self.where_clause = where_clause
|
|
238
|
+
|
|
239
|
+
def set_py_filter(self, py_filter: exprs.Expr) -> None:
|
|
240
|
+
assert self.py_filter is None
|
|
241
|
+
self.py_filter = py_filter
|
|
242
|
+
self.py_filter_eval_ctx = self.row_builder.create_eval_ctx([py_filter], exclude=self.select_list)
|
|
243
|
+
|
|
244
|
+
def set_order_by(self, ordering: OrderByClause) -> None:
|
|
245
|
+
"""Add Order By clause"""
|
|
220
246
|
if self.tbl is not None:
|
|
221
247
|
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
222
248
|
# the number of tables that need to be joined to the target table
|
|
@@ -236,7 +262,7 @@ class SqlNode(ExecNode):
|
|
|
236
262
|
explain_str = '\n'.join([str(row) for row in explain_result])
|
|
237
263
|
_logger.debug(f'SqlScanNode explain:\n{explain_str}')
|
|
238
264
|
except Exception as e:
|
|
239
|
-
_logger.warning(f'EXPLAIN failed')
|
|
265
|
+
_logger.warning(f'EXPLAIN failed with error: {e}')
|
|
240
266
|
|
|
241
267
|
def __iter__(self) -> Iterator[DataRowBatch]:
|
|
242
268
|
# run the query; do this here rather than in _open(), exceptions are only expected during iteration
|
|
@@ -280,10 +306,10 @@ class SqlNode(ExecNode):
|
|
|
280
306
|
else:
|
|
281
307
|
output_row[slot_idx] = sql_row[i]
|
|
282
308
|
|
|
283
|
-
if self.
|
|
309
|
+
if self.py_filter is not None:
|
|
284
310
|
# evaluate filter
|
|
285
|
-
self.row_builder.eval(output_row, self.
|
|
286
|
-
if self.
|
|
311
|
+
self.row_builder.eval(output_row, self.py_filter_eval_ctx, profile=self.ctx.profile)
|
|
312
|
+
if self.py_filter is not None and not output_row[self.py_filter.slot_idx]:
|
|
287
313
|
# we re-use this row for the next sql row since it didn't pass the filter
|
|
288
314
|
output_row = output_batch.pop_row()
|
|
289
315
|
output_row.clear()
|
|
@@ -315,21 +341,16 @@ class SqlScanNode(SqlNode):
|
|
|
315
341
|
|
|
316
342
|
Supports filtering and ordering.
|
|
317
343
|
"""
|
|
318
|
-
where_clause: Optional[exprs.Expr]
|
|
319
344
|
exact_version_only: list[catalog.TableVersion]
|
|
320
345
|
|
|
321
346
|
def __init__(
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
347
|
+
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
348
|
+
select_list: Iterable[exprs.Expr],
|
|
349
|
+
set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
326
350
|
):
|
|
327
351
|
"""
|
|
328
352
|
Args:
|
|
329
353
|
select_list: output of the query
|
|
330
|
-
sql_where_clause: SQL Where clause
|
|
331
|
-
filter: additional Where-clause predicate that can't be evaluated via SQL
|
|
332
|
-
limit: max number of rows to return: 0 = no limit
|
|
333
354
|
set_pk: if True, sets the primary for each DataRow
|
|
334
355
|
exact_version_only: tables for which we only want to see rows created at the current version
|
|
335
356
|
"""
|
|
@@ -338,12 +359,7 @@ class SqlScanNode(SqlNode):
|
|
|
338
359
|
# create Select stmt
|
|
339
360
|
if exact_version_only is None:
|
|
340
361
|
exact_version_only = []
|
|
341
|
-
target = tbl.tbl_version # the stored table we're scanning
|
|
342
|
-
self.filter = filter
|
|
343
|
-
self.filter_eval_ctx = \
|
|
344
|
-
row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
|
|
345
362
|
|
|
346
|
-
self.where_clause = where_clause
|
|
347
363
|
self.exact_version_only = exact_version_only
|
|
348
364
|
|
|
349
365
|
def _create_stmt(self) -> sql.Select:
|
|
@@ -352,12 +368,6 @@ class SqlScanNode(SqlNode):
|
|
|
352
368
|
refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
|
|
353
369
|
stmt = self.create_from_clause(
|
|
354
370
|
self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
|
|
355
|
-
|
|
356
|
-
if self.where_clause is not None:
|
|
357
|
-
sql_where_clause = self.sql_elements.get(self.where_clause)
|
|
358
|
-
assert sql_where_clause is not None
|
|
359
|
-
stmt = stmt.where(sql_where_clause)
|
|
360
|
-
|
|
361
371
|
return stmt
|
|
362
372
|
|
|
363
373
|
|
|
@@ -366,11 +376,9 @@ class SqlLookupNode(SqlNode):
|
|
|
366
376
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
367
377
|
"""
|
|
368
378
|
|
|
369
|
-
where_clause: sql.ColumnElement
|
|
370
|
-
|
|
371
379
|
def __init__(
|
|
372
|
-
|
|
373
|
-
|
|
380
|
+
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
381
|
+
select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
|
|
374
382
|
):
|
|
375
383
|
"""
|
|
376
384
|
Args:
|
|
@@ -381,15 +389,15 @@ class SqlLookupNode(SqlNode):
|
|
|
381
389
|
sql_elements = exprs.SqlElementCache()
|
|
382
390
|
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
|
|
383
391
|
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
384
|
-
self.
|
|
392
|
+
self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
385
393
|
|
|
386
394
|
def _create_stmt(self) -> sql.Select:
|
|
387
395
|
stmt = super()._create_stmt()
|
|
388
396
|
refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
|
|
389
397
|
stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
|
|
390
|
-
stmt = stmt.where(self.where_clause)
|
|
391
398
|
return stmt
|
|
392
399
|
|
|
400
|
+
|
|
393
401
|
class SqlAggregationNode(SqlNode):
|
|
394
402
|
"""
|
|
395
403
|
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
@@ -398,11 +406,11 @@ class SqlAggregationNode(SqlNode):
|
|
|
398
406
|
group_by_items: Optional[list[exprs.Expr]]
|
|
399
407
|
|
|
400
408
|
def __init__(
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
409
|
+
self, row_builder: exprs.RowBuilder,
|
|
410
|
+
input: SqlNode,
|
|
411
|
+
select_list: Iterable[exprs.Expr],
|
|
412
|
+
group_by_items: Optional[list[exprs.Expr]] = None,
|
|
413
|
+
limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
|
|
406
414
|
):
|
|
407
415
|
"""
|
|
408
416
|
Args:
|
|
@@ -422,3 +430,42 @@ class SqlAggregationNode(SqlNode):
|
|
|
422
430
|
assert all(e is not None for e in sql_group_by_items)
|
|
423
431
|
stmt = stmt.group_by(*sql_group_by_items)
|
|
424
432
|
return stmt
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class SqlJoinNode(SqlNode):
|
|
436
|
+
"""
|
|
437
|
+
Materializes data from the store via a Select ... From ... that contains joins
|
|
438
|
+
"""
|
|
439
|
+
input_ctes: list[sql.CTE]
|
|
440
|
+
join_clauses: list['pixeltable.plan.JoinClause']
|
|
441
|
+
|
|
442
|
+
def __init__(
|
|
443
|
+
self, row_builder: exprs.RowBuilder,
|
|
444
|
+
inputs: Sequence[SqlNode], join_clauses: list['pixeltable.plan.JoinClause'], select_list: Iterable[exprs.Expr]
|
|
445
|
+
):
|
|
446
|
+
assert len(inputs) > 1
|
|
447
|
+
assert len(inputs) == len(join_clauses) + 1
|
|
448
|
+
self.input_ctes = []
|
|
449
|
+
self.join_clauses = join_clauses
|
|
450
|
+
sql_elements = exprs.SqlElementCache()
|
|
451
|
+
for input_node in inputs:
|
|
452
|
+
input_cte, input_col_map = input_node.to_cte()
|
|
453
|
+
self.input_ctes.append(input_cte)
|
|
454
|
+
sql_elements.extend(input_col_map)
|
|
455
|
+
super().__init__(None, row_builder, select_list, sql_elements)
|
|
456
|
+
|
|
457
|
+
def _create_stmt(self) -> sql.Select:
|
|
458
|
+
from pixeltable import plan
|
|
459
|
+
stmt = super()._create_stmt()
|
|
460
|
+
stmt = stmt.select_from(self.input_ctes[0])
|
|
461
|
+
for i in range(len(self.join_clauses)):
|
|
462
|
+
join_clause = self.join_clauses[i]
|
|
463
|
+
on_clause = (
|
|
464
|
+
self.sql_elements.get(join_clause.join_predicate) if join_clause.join_type != plan.JoinType.CROSS
|
|
465
|
+
else sql.sql.expression.literal(True)
|
|
466
|
+
)
|
|
467
|
+
is_outer = join_clause.join_type == plan.JoinType.LEFT or join_clause.join_type == plan.JoinType.FULL_OUTER
|
|
468
|
+
stmt = stmt.join(
|
|
469
|
+
self.input_ctes[i + 1], onclause=on_clause, isouter=is_outer,
|
|
470
|
+
full=join_clause == plan.JoinType.FULL_OUTER)
|
|
471
|
+
return stmt
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -35,7 +35,7 @@ class ArithmeticExpr(Expr):
|
|
|
35
35
|
|
|
36
36
|
self.id = self._create_id()
|
|
37
37
|
|
|
38
|
-
def
|
|
38
|
+
def __repr__(self) -> str:
|
|
39
39
|
# add parentheses around operands that are ArithmeticExprs to express precedence
|
|
40
40
|
op1_str = f'({self._op1})' if isinstance(self._op1, ArithmeticExpr) else str(self._op1)
|
|
41
41
|
op2_str = f'({self._op2})' if isinstance(self._op2, ArithmeticExpr) else str(self._op2)
|
pixeltable/exprs/array_slice.py
CHANGED
pixeltable/exprs/column_ref.py
CHANGED
|
@@ -5,10 +5,12 @@ from uuid import UUID
|
|
|
5
5
|
|
|
6
6
|
import sqlalchemy as sql
|
|
7
7
|
|
|
8
|
+
import pixeltable as pxt
|
|
8
9
|
import pixeltable.catalog as catalog
|
|
9
10
|
import pixeltable.exceptions as excs
|
|
10
11
|
import pixeltable.iterators as iters
|
|
11
12
|
|
|
13
|
+
from ..utils.description_helper import DescriptionHelper
|
|
12
14
|
from .data_row import DataRow
|
|
13
15
|
from .expr import Expr
|
|
14
16
|
from .row_builder import RowBuilder
|
|
@@ -126,6 +128,22 @@ class ColumnRef(Expr):
|
|
|
126
128
|
def _equals(self, other: ColumnRef) -> bool:
|
|
127
129
|
return self.col == other.col and self.perform_validation == other.perform_validation
|
|
128
130
|
|
|
131
|
+
def _df(self) -> 'pxt.dataframe.DataFrame':
|
|
132
|
+
tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
|
|
133
|
+
return tbl.select(self)
|
|
134
|
+
|
|
135
|
+
def show(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
|
|
136
|
+
return self._df().show(*args, **kwargs)
|
|
137
|
+
|
|
138
|
+
def head(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
|
|
139
|
+
return self._df().head(*args, **kwargs)
|
|
140
|
+
|
|
141
|
+
def tail(self, *args, **kwargs) -> 'pxt.dataframe.DataFrameResultSet':
|
|
142
|
+
return self._df().tail(*args, **kwargs)
|
|
143
|
+
|
|
144
|
+
def count(self) -> int:
|
|
145
|
+
return self._df().count()
|
|
146
|
+
|
|
129
147
|
def __str__(self) -> str:
|
|
130
148
|
if self.col.name is None:
|
|
131
149
|
return f'<unnamed column {self.col.id}>'
|
|
@@ -133,11 +151,20 @@ class ColumnRef(Expr):
|
|
|
133
151
|
return self.col.name
|
|
134
152
|
|
|
135
153
|
def __repr__(self) -> str:
|
|
136
|
-
return
|
|
154
|
+
return self._descriptors().to_string()
|
|
137
155
|
|
|
138
156
|
def _repr_html_(self) -> str:
|
|
157
|
+
return self._descriptors().to_html()
|
|
158
|
+
|
|
159
|
+
def _descriptors(self) -> DescriptionHelper:
|
|
139
160
|
tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
|
|
140
|
-
|
|
161
|
+
helper = DescriptionHelper()
|
|
162
|
+
helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
|
|
163
|
+
helper.append(tbl._col_descriptor([self.col.name]))
|
|
164
|
+
idxs = tbl._index_descriptor([self.col.name])
|
|
165
|
+
if len(idxs) > 0:
|
|
166
|
+
helper.append(idxs)
|
|
167
|
+
return helper
|
|
141
168
|
|
|
142
169
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
143
170
|
return None if self.perform_validation else self.col.sa_col
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -30,7 +30,7 @@ class CompoundPredicate(Expr):
|
|
|
30
30
|
|
|
31
31
|
self.id = self._create_id()
|
|
32
32
|
|
|
33
|
-
def
|
|
33
|
+
def __repr__(self) -> str:
|
|
34
34
|
if self.operator == LogicalOperator.NOT:
|
|
35
35
|
return f'~({self.components[0]})'
|
|
36
36
|
return f' {self.operator} '.join([f'({e})' for e in self.components])
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -190,6 +190,7 @@ class Expr(abc.ABC):
|
|
|
190
190
|
return new.copy()
|
|
191
191
|
for i in range(len(self.components)):
|
|
192
192
|
self.components[i] = self.components[i].substitute(spec)
|
|
193
|
+
self.id = self._create_id()
|
|
193
194
|
return self
|
|
194
195
|
|
|
195
196
|
@classmethod
|
|
@@ -216,12 +217,12 @@ class Expr(abc.ABC):
|
|
|
216
217
|
return result
|
|
217
218
|
result = result.substitute({ref: ref.col.value_expr for ref in target_col_refs})
|
|
218
219
|
|
|
219
|
-
def is_bound_by(self,
|
|
220
|
-
"""Returns True if this expr can be evaluated in the context of
|
|
220
|
+
def is_bound_by(self, tbls: list[catalog.TableVersionPath]) -> bool:
|
|
221
|
+
"""Returns True if this expr can be evaluated in the context of tbls."""
|
|
221
222
|
from .column_ref import ColumnRef
|
|
222
223
|
col_refs = self.subexprs(ColumnRef)
|
|
223
224
|
for col_ref in col_refs:
|
|
224
|
-
if not tbl.has_column(col_ref.col):
|
|
225
|
+
if not any(tbl.has_column(col_ref.col) for tbl in tbls):
|
|
225
226
|
return False
|
|
226
227
|
return True
|
|
227
228
|
|
|
@@ -235,7 +236,7 @@ class Expr(abc.ABC):
|
|
|
235
236
|
self.components[i] = self.components[i]._retarget(tbl_versions)
|
|
236
237
|
return self
|
|
237
238
|
|
|
238
|
-
def
|
|
239
|
+
def __repr__(self) -> str:
|
|
239
240
|
return f'<Expression of type {type(self)}>'
|
|
240
241
|
|
|
241
242
|
def display_str(self, inline: bool = True) -> str:
|
|
@@ -450,7 +451,13 @@ class Expr(abc.ABC):
|
|
|
450
451
|
|
|
451
452
|
def astype(self, new_type: Union[ts.ColumnType, type, _AnnotatedAlias]) -> 'exprs.TypeCast':
|
|
452
453
|
from pixeltable.exprs import TypeCast
|
|
453
|
-
|
|
454
|
+
# Interpret the type argument the same way we would if given in a schema
|
|
455
|
+
col_type = ts.ColumnType.normalize_type(new_type, nullable_default=True, allow_builtin_types=False)
|
|
456
|
+
if not self.col_type.nullable:
|
|
457
|
+
# This expression is non-nullable; we can prove that the output is non-nullable, regardless of
|
|
458
|
+
# whether new_type is given as nullable.
|
|
459
|
+
col_type = col_type.copy(nullable=False)
|
|
460
|
+
return TypeCast(self, col_type)
|
|
454
461
|
|
|
455
462
|
def apply(self, fn: Callable, *, col_type: Union[ts.ColumnType, type, _AnnotatedAlias, None] = None) -> 'exprs.FunctionCall':
|
|
456
463
|
if col_type is not None:
|
pixeltable/exprs/expr_set.py
CHANGED
|
@@ -60,6 +60,14 @@ class ExprSet(Generic[T]):
|
|
|
60
60
|
def __le__(self, other: ExprSet[T]) -> bool:
|
|
61
61
|
return other.issuperset(self)
|
|
62
62
|
|
|
63
|
+
def union(self, *others: Iterable[T]) -> ExprSet[T]:
|
|
64
|
+
result = ExprSet(self.exprs.values())
|
|
65
|
+
result.update(*others)
|
|
66
|
+
return result
|
|
67
|
+
|
|
68
|
+
def __or__(self, other: ExprSet[T]) -> ExprSet[T]:
|
|
69
|
+
return self.union(other)
|
|
70
|
+
|
|
63
71
|
def difference(self, *others: Iterable[T]) -> ExprSet[T]:
|
|
64
72
|
id_diff = set(self.exprs.keys()).difference(e.id for other_set in others for e in other_set)
|
|
65
73
|
return ExprSet(self.exprs[id] for id in id_diff)
|