pixeltable 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +1 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/column.py +5 -0
- pixeltable/catalog/globals.py +8 -0
- pixeltable/catalog/insertable_table.py +2 -2
- pixeltable/catalog/table.py +27 -9
- pixeltable/catalog/table_version.py +41 -68
- pixeltable/catalog/view.py +3 -3
- pixeltable/dataframe.py +7 -6
- pixeltable/exec/__init__.py +2 -1
- pixeltable/exec/expr_eval_node.py +8 -1
- pixeltable/exec/row_update_node.py +61 -0
- pixeltable/exec/{sql_scan_node.py → sql_node.py} +120 -56
- pixeltable/exprs/__init__.py +1 -2
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +12 -12
- pixeltable/exprs/expr.py +67 -22
- pixeltable/exprs/function_call.py +60 -29
- pixeltable/exprs/globals.py +2 -0
- pixeltable/exprs/in_predicate.py +3 -3
- pixeltable/exprs/inline_array.py +18 -11
- pixeltable/exprs/is_null.py +5 -5
- pixeltable/exprs/method_ref.py +63 -0
- pixeltable/ext/__init__.py +9 -0
- pixeltable/ext/functions/__init__.py +8 -0
- pixeltable/ext/functions/whisperx.py +45 -5
- pixeltable/ext/functions/yolox.py +60 -14
- pixeltable/func/aggregate_function.py +10 -4
- pixeltable/func/callable_function.py +16 -4
- pixeltable/func/expr_template_function.py +1 -1
- pixeltable/func/function.py +12 -2
- pixeltable/func/function_registry.py +26 -9
- pixeltable/func/udf.py +32 -4
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/fireworks.py +33 -0
- pixeltable/functions/globals.py +36 -1
- pixeltable/functions/huggingface.py +155 -7
- pixeltable/functions/image.py +242 -40
- pixeltable/functions/openai.py +214 -0
- pixeltable/functions/string.py +600 -8
- pixeltable/functions/timestamp.py +210 -0
- pixeltable/functions/together.py +106 -0
- pixeltable/functions/video.py +28 -10
- pixeltable/functions/whisper.py +32 -0
- pixeltable/globals.py +3 -3
- pixeltable/io/__init__.py +1 -1
- pixeltable/io/globals.py +186 -5
- pixeltable/io/label_studio.py +42 -2
- pixeltable/io/pandas.py +70 -34
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_18.py +39 -0
- pixeltable/metadata/notes.py +10 -0
- pixeltable/plan.py +82 -7
- pixeltable/tool/create_test_db_dump.py +4 -5
- pixeltable/tool/doc_plugins/griffe.py +81 -0
- pixeltable/tool/doc_plugins/mkdocstrings.py +6 -0
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +135 -0
- pixeltable/type_system.py +15 -14
- pixeltable/utils/s3.py +1 -1
- pixeltable-0.2.14.dist-info/METADATA +206 -0
- {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/RECORD +64 -56
- pixeltable-0.2.14.dist-info/entry_points.txt +3 -0
- pixeltable/exprs/image_member_access.py +0 -96
- pixeltable/exprs/predicate.py +0 -44
- pixeltable-0.2.12.dist-info/METADATA +0 -137
- {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/WHEEL +0 -0
|
@@ -13,30 +13,23 @@ import pixeltable.catalog as catalog
|
|
|
13
13
|
|
|
14
14
|
_logger = logging.getLogger('pixeltable')
|
|
15
15
|
|
|
16
|
-
class
|
|
17
|
-
"""Materializes data from the store via
|
|
18
|
-
|
|
16
|
+
class SqlNode(ExecNode):
|
|
17
|
+
"""Materializes data from the store via a Select stmt."""
|
|
18
|
+
|
|
19
19
|
def __init__(
|
|
20
20
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
21
|
-
select_list: Iterable[exprs.Expr],
|
|
22
|
-
where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Predicate] = None,
|
|
23
|
-
order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
|
|
24
|
-
limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
|
|
21
|
+
select_list: Iterable[exprs.Expr], set_pk: bool = False
|
|
25
22
|
):
|
|
26
23
|
"""
|
|
24
|
+
Initialize self.stmt with expressions derived from select_list.
|
|
25
|
+
|
|
26
|
+
This only provides the select list. The subclass is responsible for the From clause and any additional clauses.
|
|
27
|
+
|
|
27
28
|
Args:
|
|
28
29
|
select_list: output of the query
|
|
29
|
-
sql_where_clause: SQL Where clause
|
|
30
|
-
filter: additional Where-clause predicate that can't be evaluated via SQL
|
|
31
|
-
limit: max number of rows to return: 0 = no limit
|
|
32
30
|
set_pk: if True, sets the primary for each DataRow
|
|
33
|
-
exact_version_only: tables for which we only want to see rows created at the current version
|
|
34
31
|
"""
|
|
35
32
|
# create Select stmt
|
|
36
|
-
if order_by_items is None:
|
|
37
|
-
order_by_items = []
|
|
38
|
-
if exact_version_only is None:
|
|
39
|
-
exact_version_only = []
|
|
40
33
|
self.tbl = tbl
|
|
41
34
|
target = tbl.tbl_version # the stored table we're scanning
|
|
42
35
|
self.sql_exprs = exprs.ExprSet(select_list)
|
|
@@ -45,21 +38,15 @@ class SqlScanNode(ExecNode):
|
|
|
45
38
|
sql_subexprs = iter_arg.subexprs(filter=lambda e: e.sql_expr() is not None, traverse_matches=False)
|
|
46
39
|
[self.sql_exprs.append(e) for e in sql_subexprs]
|
|
47
40
|
super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
|
|
48
|
-
self.filter = filter
|
|
49
|
-
self.filter_eval_ctx = \
|
|
50
|
-
row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
|
|
51
|
-
self.limit = limit
|
|
52
41
|
|
|
53
42
|
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
54
43
|
# the number of tables that need to be joined to the target table
|
|
55
44
|
for rowid_ref in [e for e in self.sql_exprs if isinstance(e, exprs.RowidRef)]:
|
|
56
45
|
rowid_ref.set_tbl(tbl)
|
|
57
46
|
|
|
58
|
-
where_clause_tbl_ids = where_clause.tbl_ids() if where_clause is not None else set()
|
|
59
|
-
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs) | where_clause_tbl_ids
|
|
60
47
|
sql_select_list = [e.sql_expr() for e in self.sql_exprs]
|
|
61
48
|
assert len(sql_select_list) == len(self.sql_exprs)
|
|
62
|
-
assert all(
|
|
49
|
+
assert all(e is not None for e in sql_select_list)
|
|
63
50
|
self.set_pk = set_pk
|
|
64
51
|
self.num_pk_cols = 0
|
|
65
52
|
if set_pk:
|
|
@@ -69,42 +56,12 @@ class SqlScanNode(ExecNode):
|
|
|
69
56
|
sql_select_list += pk_columns
|
|
70
57
|
|
|
71
58
|
self.stmt = sql.select(*sql_select_list)
|
|
72
|
-
self.stmt = self.create_from_clause(
|
|
73
|
-
tbl, self.stmt, refd_tbl_ids, exact_version_only={t.id for t in exact_version_only})
|
|
74
|
-
|
|
75
|
-
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
76
|
-
# the number of tables that need to be joined to the target table
|
|
77
|
-
for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
|
|
78
|
-
rowid_ref.set_tbl(tbl)
|
|
79
|
-
order_by_clause: List[sql.ClauseElement] = []
|
|
80
|
-
for e, asc in order_by_items:
|
|
81
|
-
if isinstance(e, exprs.SimilarityExpr):
|
|
82
|
-
order_by_clause.append(e.as_order_by_clause(asc))
|
|
83
|
-
else:
|
|
84
|
-
order_by_clause.append(e.sql_expr().desc() if not asc else e.sql_expr())
|
|
85
|
-
|
|
86
|
-
if where_clause is not None:
|
|
87
|
-
sql_where_clause = where_clause.sql_expr()
|
|
88
|
-
assert sql_where_clause is not None
|
|
89
|
-
self.stmt = self.stmt.where(sql_where_clause)
|
|
90
|
-
if len(order_by_clause) > 0:
|
|
91
|
-
self.stmt = self.stmt.order_by(*order_by_clause)
|
|
92
|
-
elif target.id in row_builder.unstored_iter_args:
|
|
93
|
-
# we are referencing unstored iter columns from this view and try to order by our primary key,
|
|
94
|
-
# which ensures that iterators will see monotonically increasing pos values
|
|
95
|
-
self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
|
|
96
|
-
if limit != 0 and self.filter is None:
|
|
97
|
-
# if we need to do post-SQL filtering, we can't use LIMIT
|
|
98
|
-
self.stmt = self.stmt.limit(limit)
|
|
99
59
|
|
|
60
|
+
# additional state
|
|
100
61
|
self.result_cursor: Optional[sql.engine.CursorResult] = None
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
105
|
-
_logger.debug(f'SqlScanNode stmt:\n{stmt_str}')
|
|
106
|
-
except Exception as e:
|
|
107
|
-
pass
|
|
62
|
+
# the filter is provided by the subclass
|
|
63
|
+
self.filter: Optional[exprs.Expr] = None
|
|
64
|
+
self.filter_eval_ctx: Optional[exprs.EvalContext] = None
|
|
108
65
|
|
|
109
66
|
@classmethod
|
|
110
67
|
def create_from_clause(
|
|
@@ -224,3 +181,110 @@ class SqlScanNode(ExecNode):
|
|
|
224
181
|
if self.result_cursor is not None:
|
|
225
182
|
self.result_cursor.close()
|
|
226
183
|
|
|
184
|
+
|
|
185
|
+
class SqlScanNode(SqlNode):
|
|
186
|
+
"""
|
|
187
|
+
Materializes data from the store via a Select stmt.
|
|
188
|
+
|
|
189
|
+
Supports filtering and ordering.
|
|
190
|
+
"""
|
|
191
|
+
def __init__(
|
|
192
|
+
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
193
|
+
select_list: Iterable[exprs.Expr],
|
|
194
|
+
where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
|
|
195
|
+
order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
|
|
196
|
+
limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
|
|
197
|
+
):
|
|
198
|
+
"""
|
|
199
|
+
Args:
|
|
200
|
+
select_list: output of the query
|
|
201
|
+
sql_where_clause: SQL Where clause
|
|
202
|
+
filter: additional Where-clause predicate that can't be evaluated via SQL
|
|
203
|
+
limit: max number of rows to return: 0 = no limit
|
|
204
|
+
set_pk: if True, sets the primary for each DataRow
|
|
205
|
+
exact_version_only: tables for which we only want to see rows created at the current version
|
|
206
|
+
"""
|
|
207
|
+
super().__init__(tbl, row_builder, select_list, set_pk=set_pk)
|
|
208
|
+
# create Select stmt
|
|
209
|
+
if order_by_items is None:
|
|
210
|
+
order_by_items = []
|
|
211
|
+
if exact_version_only is None:
|
|
212
|
+
exact_version_only = []
|
|
213
|
+
target = tbl.tbl_version # the stored table we're scanning
|
|
214
|
+
self.filter = filter
|
|
215
|
+
self.filter_eval_ctx = \
|
|
216
|
+
row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
|
|
217
|
+
self.limit = limit
|
|
218
|
+
|
|
219
|
+
where_clause_tbl_ids = where_clause.tbl_ids() if where_clause is not None else set()
|
|
220
|
+
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs) | where_clause_tbl_ids
|
|
221
|
+
self.stmt = self.create_from_clause(
|
|
222
|
+
tbl, self.stmt, refd_tbl_ids, exact_version_only={t.id for t in exact_version_only})
|
|
223
|
+
|
|
224
|
+
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
225
|
+
# the number of tables that need to be joined to the target table
|
|
226
|
+
for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
|
|
227
|
+
rowid_ref.set_tbl(tbl)
|
|
228
|
+
order_by_clause: List[sql.ClauseElement] = []
|
|
229
|
+
for e, asc in order_by_items:
|
|
230
|
+
if isinstance(e, exprs.SimilarityExpr):
|
|
231
|
+
order_by_clause.append(e.as_order_by_clause(asc))
|
|
232
|
+
else:
|
|
233
|
+
order_by_clause.append(e.sql_expr().desc() if not asc else e.sql_expr())
|
|
234
|
+
|
|
235
|
+
if where_clause is not None:
|
|
236
|
+
sql_where_clause = where_clause.sql_expr()
|
|
237
|
+
assert sql_where_clause is not None
|
|
238
|
+
self.stmt = self.stmt.where(sql_where_clause)
|
|
239
|
+
if len(order_by_clause) > 0:
|
|
240
|
+
self.stmt = self.stmt.order_by(*order_by_clause)
|
|
241
|
+
elif target.id in row_builder.unstored_iter_args:
|
|
242
|
+
# we are referencing unstored iter columns from this view and try to order by our primary key,
|
|
243
|
+
# which ensures that iterators will see monotonically increasing pos values
|
|
244
|
+
self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
|
|
245
|
+
if limit != 0 and self.filter is None:
|
|
246
|
+
# if we need to do post-SQL filtering, we can't use LIMIT
|
|
247
|
+
self.stmt = self.stmt.limit(limit)
|
|
248
|
+
|
|
249
|
+
try:
|
|
250
|
+
# log stmt, if possible
|
|
251
|
+
stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
252
|
+
_logger.debug(f'SqlScanNode stmt:\n{stmt_str}')
|
|
253
|
+
except Exception as e:
|
|
254
|
+
pass
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class SqlLookupNode(SqlNode):
|
|
258
|
+
"""
|
|
259
|
+
Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
|
|
260
|
+
"""
|
|
261
|
+
def __init__(
|
|
262
|
+
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
263
|
+
select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
|
|
264
|
+
):
|
|
265
|
+
"""
|
|
266
|
+
Args:
|
|
267
|
+
select_list: output of the query
|
|
268
|
+
sa_key_cols: list of key columns in the store table
|
|
269
|
+
key_vals: list of key values to look up
|
|
270
|
+
"""
|
|
271
|
+
super().__init__(tbl, row_builder, select_list, set_pk=True)
|
|
272
|
+
target = tbl.tbl_version # the stored table we're scanning
|
|
273
|
+
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
|
|
274
|
+
self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
|
|
275
|
+
# Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
|
|
276
|
+
self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
|
|
277
|
+
self.stmt = self.stmt.where(self.where_clause)
|
|
278
|
+
|
|
279
|
+
if target.id in row_builder.unstored_iter_args:
|
|
280
|
+
# we are referencing unstored iter columns from this view and try to order by our primary key,
|
|
281
|
+
# which ensures that iterators will see monotonically increasing pos values
|
|
282
|
+
self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
|
|
283
|
+
|
|
284
|
+
try:
|
|
285
|
+
# log stmt, if possible
|
|
286
|
+
stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
287
|
+
_logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
|
|
288
|
+
except Exception as e:
|
|
289
|
+
pass
|
|
290
|
+
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -8,7 +8,6 @@ from .data_row import DataRow
|
|
|
8
8
|
from .expr import Expr
|
|
9
9
|
from .expr_set import ExprSet
|
|
10
10
|
from .function_call import FunctionCall
|
|
11
|
-
from .image_member_access import ImageMemberAccess
|
|
12
11
|
from .in_predicate import InPredicate
|
|
13
12
|
from .inline_array import InlineArray
|
|
14
13
|
from .inline_dict import InlineDict
|
|
@@ -16,8 +15,8 @@ from .is_null import IsNull
|
|
|
16
15
|
from .json_mapper import JsonMapper
|
|
17
16
|
from .json_path import RELATIVE_PATH_ROOT, JsonPath
|
|
18
17
|
from .literal import Literal
|
|
18
|
+
from .method_ref import MethodRef
|
|
19
19
|
from .object_ref import ObjectRef
|
|
20
|
-
from .predicate import Predicate
|
|
21
20
|
from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
|
|
22
21
|
from .rowid_ref import RowidRef
|
|
23
22
|
from .similarity_expr import SimilarityExpr
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Optional, List, Any, Dict
|
|
3
|
+
from typing import Optional, List, Any, Dict
|
|
4
4
|
|
|
5
5
|
import sqlalchemy as sql
|
|
6
6
|
|
|
@@ -9,15 +9,15 @@ from .data_row import DataRow
|
|
|
9
9
|
from .expr import Expr
|
|
10
10
|
from .globals import ComparisonOperator
|
|
11
11
|
from .literal import Literal
|
|
12
|
-
from .predicate import Predicate
|
|
13
12
|
from .row_builder import RowBuilder
|
|
14
13
|
import pixeltable.exceptions as excs
|
|
15
14
|
import pixeltable.index as index
|
|
15
|
+
import pixeltable.type_system as ts
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
class Comparison(
|
|
18
|
+
class Comparison(Expr):
|
|
19
19
|
def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
|
|
20
|
-
super().__init__()
|
|
20
|
+
super().__init__(ts.BoolType())
|
|
21
21
|
self.operator = operator
|
|
22
22
|
|
|
23
23
|
# if this is a comparison of a column to a literal (ie, could be used as a search argument in an index lookup),
|
|
@@ -50,7 +50,7 @@ class Comparison(Predicate):
|
|
|
50
50
|
def _equals(self, other: Comparison) -> bool:
|
|
51
51
|
return self.operator == other.operator
|
|
52
52
|
|
|
53
|
-
def _id_attrs(self) ->
|
|
53
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
54
54
|
return super()._id_attrs() + [('operator', self.operator.value)]
|
|
55
55
|
|
|
56
56
|
@property
|
|
@@ -1,20 +1,20 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import operator
|
|
4
|
+
from typing import Optional, List, Any, Dict, Callable
|
|
4
5
|
|
|
5
6
|
import sqlalchemy as sql
|
|
6
7
|
|
|
8
|
+
from .data_row import DataRow
|
|
7
9
|
from .expr import Expr
|
|
8
10
|
from .globals import LogicalOperator
|
|
9
|
-
from .predicate import Predicate
|
|
10
|
-
from .data_row import DataRow
|
|
11
11
|
from .row_builder import RowBuilder
|
|
12
|
-
import pixeltable.
|
|
12
|
+
import pixeltable.type_system as ts
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class CompoundPredicate(
|
|
16
|
-
def __init__(self, operator: LogicalOperator, operands: List[
|
|
17
|
-
super().__init__()
|
|
15
|
+
class CompoundPredicate(Expr):
|
|
16
|
+
def __init__(self, operator: LogicalOperator, operands: List[Expr]):
|
|
17
|
+
super().__init__(ts.BoolType())
|
|
18
18
|
self.operator = operator
|
|
19
19
|
# operands are stored in self.components
|
|
20
20
|
if self.operator == LogicalOperator.NOT:
|
|
@@ -22,7 +22,7 @@ class CompoundPredicate(Predicate):
|
|
|
22
22
|
self.components = operands
|
|
23
23
|
else:
|
|
24
24
|
assert len(operands) > 1
|
|
25
|
-
self.operands: List[
|
|
25
|
+
self.operands: List[Expr] = []
|
|
26
26
|
for operand in operands:
|
|
27
27
|
self._merge_operand(operand)
|
|
28
28
|
|
|
@@ -34,14 +34,14 @@ class CompoundPredicate(Predicate):
|
|
|
34
34
|
return f' {self.operator} '.join([f'({e})' for e in self.components])
|
|
35
35
|
|
|
36
36
|
@classmethod
|
|
37
|
-
def make_conjunction(cls, operands: List[
|
|
37
|
+
def make_conjunction(cls, operands: List[Expr]) -> Optional[Expr]:
|
|
38
38
|
if len(operands) == 0:
|
|
39
39
|
return None
|
|
40
40
|
if len(operands) == 1:
|
|
41
41
|
return operands[0]
|
|
42
42
|
return CompoundPredicate(LogicalOperator.AND, operands)
|
|
43
43
|
|
|
44
|
-
def _merge_operand(self, operand:
|
|
44
|
+
def _merge_operand(self, operand: Expr) -> None:
|
|
45
45
|
"""
|
|
46
46
|
Merge this operand, if possible, otherwise simply record it.
|
|
47
47
|
"""
|
|
@@ -55,11 +55,11 @@ class CompoundPredicate(Predicate):
|
|
|
55
55
|
def _equals(self, other: CompoundPredicate) -> bool:
|
|
56
56
|
return self.operator == other.operator
|
|
57
57
|
|
|
58
|
-
def _id_attrs(self) ->
|
|
58
|
+
def _id_attrs(self) -> list[tuple[str, Any]]:
|
|
59
59
|
return super()._id_attrs() + [('operator', self.operator.value)]
|
|
60
60
|
|
|
61
61
|
def split_conjuncts(
|
|
62
|
-
self, condition: Callable[[
|
|
62
|
+
self, condition: Callable[[Expr], bool]) -> tuple[list[Expr], Optional[Expr]]:
|
|
63
63
|
if self.operator == LogicalOperator.OR or self.operator == LogicalOperator.NOT:
|
|
64
64
|
return super().split_conjuncts(condition)
|
|
65
65
|
matches = [op for op in self.components if condition(op)]
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -7,7 +7,6 @@ import inspect
|
|
|
7
7
|
import json
|
|
8
8
|
import sys
|
|
9
9
|
import typing
|
|
10
|
-
from itertools import islice
|
|
11
10
|
from typing import Union, Optional, List, Callable, Any, Dict, Tuple, Set, Generator, Type
|
|
12
11
|
from uuid import UUID
|
|
13
12
|
|
|
@@ -16,8 +15,8 @@ import sqlalchemy as sql
|
|
|
16
15
|
import pixeltable
|
|
17
16
|
import pixeltable.catalog as catalog
|
|
18
17
|
import pixeltable.exceptions as excs
|
|
19
|
-
import pixeltable.type_system as ts
|
|
20
18
|
import pixeltable.func as func
|
|
19
|
+
import pixeltable.type_system as ts
|
|
21
20
|
from .data_row import DataRow
|
|
22
21
|
from .globals import ComparisonOperator, LogicalOperator, LiteralPythonTypes, ArithmeticOperator
|
|
23
22
|
|
|
@@ -91,8 +90,8 @@ class Expr(abc.ABC):
|
|
|
91
90
|
|
|
92
91
|
def default_column_name(self) -> Optional[str]:
|
|
93
92
|
"""
|
|
94
|
-
Returns:
|
|
95
|
-
None if this expression lacks a default name,
|
|
93
|
+
Returns:
|
|
94
|
+
None if this expression lacks a default name,
|
|
96
95
|
or a valid identifier (according to catalog.is_valid_identifer) otherwise.
|
|
97
96
|
"""
|
|
98
97
|
return None
|
|
@@ -231,9 +230,8 @@ class Expr(abc.ABC):
|
|
|
231
230
|
self.components[i] = self.components[i]._retarget(tbl_versions)
|
|
232
231
|
return self
|
|
233
232
|
|
|
234
|
-
@abc.abstractmethod
|
|
235
233
|
def __str__(self) -> str:
|
|
236
|
-
|
|
234
|
+
return f'<Expression of type {type(self)}>'
|
|
237
235
|
|
|
238
236
|
def display_str(self, inline: bool = True) -> str:
|
|
239
237
|
"""
|
|
@@ -264,7 +262,7 @@ class Expr(abc.ABC):
|
|
|
264
262
|
if is_match:
|
|
265
263
|
yield self
|
|
266
264
|
|
|
267
|
-
def
|
|
265
|
+
def _contains(self, cls: Optional[Type[Expr]] = None, filter: Optional[Callable[[Expr], bool]] = None) -> bool:
|
|
268
266
|
"""
|
|
269
267
|
Returns True if any subexpr is an instance of cls.
|
|
270
268
|
"""
|
|
@@ -319,17 +317,20 @@ class Expr(abc.ABC):
|
|
|
319
317
|
"""
|
|
320
318
|
if isinstance(o, Expr):
|
|
321
319
|
return o
|
|
322
|
-
#
|
|
320
|
+
# Try to create a literal. We need to check for InlineArray/InlineDict
|
|
321
|
+
# first, to prevent arrays from inappropriately being interpreted as JsonType
|
|
322
|
+
# literals.
|
|
323
|
+
# TODO: general cleanup of InlineArray/InlineDict
|
|
324
|
+
if isinstance(o, list):
|
|
325
|
+
from .inline_array import InlineArray
|
|
326
|
+
return InlineArray(tuple(o))
|
|
327
|
+
if isinstance(o, dict):
|
|
328
|
+
from .inline_dict import InlineDict
|
|
329
|
+
return InlineDict(o)
|
|
323
330
|
obj_type = ts.ColumnType.infer_literal_type(o)
|
|
324
331
|
if obj_type is not None:
|
|
325
332
|
from .literal import Literal
|
|
326
333
|
return Literal(o, col_type=obj_type)
|
|
327
|
-
if isinstance(o, dict):
|
|
328
|
-
from .inline_dict import InlineDict
|
|
329
|
-
return InlineDict(o)
|
|
330
|
-
elif isinstance(o, list):
|
|
331
|
-
from .inline_array import InlineArray
|
|
332
|
-
return InlineArray(tuple(o))
|
|
333
334
|
return None
|
|
334
335
|
|
|
335
336
|
@abc.abstractmethod
|
|
@@ -427,6 +428,14 @@ class Expr(abc.ABC):
|
|
|
427
428
|
# Return a `FunctionCall` obtained by passing this `Expr` to the new `function`.
|
|
428
429
|
return function(self)
|
|
429
430
|
|
|
431
|
+
def __dir__(self) -> list[str]:
|
|
432
|
+
attrs = ['isin', 'astype', 'apply']
|
|
433
|
+
attrs += [
|
|
434
|
+
f.name
|
|
435
|
+
for f in func.FunctionRegistry.get().get_type_methods(self.col_type.type_enum)
|
|
436
|
+
]
|
|
437
|
+
return attrs
|
|
438
|
+
|
|
430
439
|
def __getitem__(self, index: object) -> Expr:
|
|
431
440
|
if self.col_type.is_json_type():
|
|
432
441
|
from .json_path import JsonPath
|
|
@@ -434,19 +443,23 @@ class Expr(abc.ABC):
|
|
|
434
443
|
if self.col_type.is_array_type():
|
|
435
444
|
from .array_slice import ArraySlice
|
|
436
445
|
return ArraySlice(self, index)
|
|
437
|
-
raise
|
|
446
|
+
raise AttributeError(f'Type {self.col_type} is not subscriptable')
|
|
438
447
|
|
|
439
|
-
def __getattr__(self, name: str) -> Union['pixeltable.exprs.
|
|
448
|
+
def __getattr__(self, name: str) -> Union['pixeltable.exprs.MethodRef', 'pixeltable.exprs.FunctionCall', 'pixeltable.exprs.JsonPath']:
|
|
440
449
|
"""
|
|
441
450
|
ex.: <img col>.rotate(60)
|
|
442
451
|
"""
|
|
443
|
-
if self.col_type.is_image_type():
|
|
444
|
-
from .image_member_access import ImageMemberAccess
|
|
445
|
-
return ImageMemberAccess(name, self)
|
|
446
452
|
if self.col_type.is_json_type():
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
453
|
+
return pixeltable.exprs.JsonPath(self).__getattr__(name)
|
|
454
|
+
else:
|
|
455
|
+
method_ref = pixeltable.exprs.MethodRef(self, name)
|
|
456
|
+
if method_ref.fn.is_property:
|
|
457
|
+
# Marked as a property, so autoinvoke the method to obtain a `FunctionCall`
|
|
458
|
+
assert method_ref.fn.arity == 1
|
|
459
|
+
return method_ref.fn(method_ref.base_expr)
|
|
460
|
+
else:
|
|
461
|
+
# Return the `MethodRef` object itself; it requires arguments to become a `FunctionCall`
|
|
462
|
+
return method_ref
|
|
450
463
|
|
|
451
464
|
def __bool__(self) -> bool:
|
|
452
465
|
raise TypeError(
|
|
@@ -518,6 +531,38 @@ class Expr(abc.ABC):
|
|
|
518
531
|
return ArithmeticExpr(op, self, Literal(other)) # type: ignore[arg-type]
|
|
519
532
|
raise TypeError(f'Other must be Expr or literal: {type(other)}')
|
|
520
533
|
|
|
534
|
+
def __and__(self, other: object) -> Expr:
|
|
535
|
+
if not isinstance(other, Expr):
|
|
536
|
+
raise TypeError(f'Other needs to be an expression: {type(other)}')
|
|
537
|
+
if not other.col_type.is_bool_type():
|
|
538
|
+
raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
|
|
539
|
+
from .compound_predicate import CompoundPredicate
|
|
540
|
+
return CompoundPredicate(LogicalOperator.AND, [self, other])
|
|
541
|
+
|
|
542
|
+
def __or__(self, other: object) -> Expr:
|
|
543
|
+
if not isinstance(other, Expr):
|
|
544
|
+
raise TypeError(f'Other needs to be an expression: {type(other)}')
|
|
545
|
+
if not other.col_type.is_bool_type():
|
|
546
|
+
raise TypeError(f'Other needs to be an expression that returns a boolean: {other.col_type}')
|
|
547
|
+
from .compound_predicate import CompoundPredicate
|
|
548
|
+
return CompoundPredicate(LogicalOperator.OR, [self, other])
|
|
549
|
+
|
|
550
|
+
def __invert__(self) -> Expr:
|
|
551
|
+
from .compound_predicate import CompoundPredicate
|
|
552
|
+
return CompoundPredicate(LogicalOperator.NOT, [self])
|
|
553
|
+
|
|
554
|
+
def split_conjuncts(
|
|
555
|
+
self, condition: Callable[[Expr], bool]) -> tuple[list[Expr], Optional[Expr]]:
|
|
556
|
+
"""
|
|
557
|
+
Returns clauses of a conjunction that meet condition in the first element.
|
|
558
|
+
The second element contains remaining clauses, rolled into a conjunction.
|
|
559
|
+
"""
|
|
560
|
+
assert self.col_type.is_bool_type() # only valid for predicates
|
|
561
|
+
if condition(self):
|
|
562
|
+
return [self], None
|
|
563
|
+
else:
|
|
564
|
+
return [], self
|
|
565
|
+
|
|
521
566
|
def _make_applicator_function(self, fn: Callable, col_type: Optional[ts.ColumnType]) -> 'pixeltable.func.Function':
|
|
522
567
|
"""
|
|
523
568
|
Creates a unary pixeltable `Function` that encapsulates a python `Callable`. The result type of
|
|
@@ -20,6 +20,22 @@ from .rowid_ref import RowidRef
|
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class FunctionCall(Expr):
|
|
23
|
+
|
|
24
|
+
fn: func.Function
|
|
25
|
+
is_method_call: bool
|
|
26
|
+
agg_init_args: Dict[str, Any]
|
|
27
|
+
args: List[Tuple[Optional[int], Optional[Any]]]
|
|
28
|
+
kwargs: Dict[str, Tuple[Optional[int], Optional[Any]]]
|
|
29
|
+
arg_types: List[ts.ColumnType]
|
|
30
|
+
kwarg_types: Dict[str, ts.ColumnType]
|
|
31
|
+
group_by_start_idx: int
|
|
32
|
+
group_by_stop_idx: int
|
|
33
|
+
fn_expr_idx: int
|
|
34
|
+
order_by_start_idx: int
|
|
35
|
+
constant_args: set[str]
|
|
36
|
+
aggregator: Optional[Any]
|
|
37
|
+
current_partition_vals: Optional[List[Any]]
|
|
38
|
+
|
|
23
39
|
def __init__(
|
|
24
40
|
self, fn: func.Function, bound_args: Dict[str, Any], order_by_clause: Optional[List[Any]] = None,
|
|
25
41
|
group_by_clause: Optional[List[Any]] = None, is_method_call: bool = False):
|
|
@@ -31,9 +47,9 @@ class FunctionCall(Expr):
|
|
|
31
47
|
super().__init__(fn.call_return_type(bound_args))
|
|
32
48
|
self.fn = fn
|
|
33
49
|
self.is_method_call = is_method_call
|
|
34
|
-
self.
|
|
50
|
+
self.normalize_args(signature, bound_args)
|
|
35
51
|
|
|
36
|
-
self.agg_init_args
|
|
52
|
+
self.agg_init_args = {}
|
|
37
53
|
if self.is_agg_fn_call:
|
|
38
54
|
# we separate out the init args for the aggregator
|
|
39
55
|
self.agg_init_args = {
|
|
@@ -42,17 +58,16 @@ class FunctionCall(Expr):
|
|
|
42
58
|
bound_args = {arg_name: arg for arg_name, arg in bound_args.items() if arg_name not in fn.init_param_names}
|
|
43
59
|
|
|
44
60
|
# construct components, args, kwargs
|
|
45
|
-
self.components: List[Expr] = []
|
|
46
61
|
|
|
47
62
|
# Tuple[int, Any]:
|
|
48
63
|
# - for Exprs: (index into components, None)
|
|
49
64
|
# - otherwise: (None, val)
|
|
50
|
-
self.args
|
|
51
|
-
self.kwargs
|
|
65
|
+
self.args = []
|
|
66
|
+
self.kwargs = {}
|
|
52
67
|
|
|
53
68
|
# we record the types of non-variable parameters for runtime type checks
|
|
54
|
-
self.arg_types
|
|
55
|
-
self.kwarg_types
|
|
69
|
+
self.arg_types = []
|
|
70
|
+
self.kwarg_types = {}
|
|
56
71
|
# the prefix of parameters that are bound can be passed by position
|
|
57
72
|
for param in fn.signature.py_signature.parameters.values():
|
|
58
73
|
if param.name not in bound_args or param.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
@@ -111,8 +126,8 @@ class FunctionCall(Expr):
|
|
|
111
126
|
|
|
112
127
|
self.constant_args = {param_name for param_name, arg in bound_args.items() if not isinstance(arg, Expr)}
|
|
113
128
|
# execution state for aggregate functions
|
|
114
|
-
self.aggregator
|
|
115
|
-
self.current_partition_vals
|
|
129
|
+
self.aggregator = None
|
|
130
|
+
self.current_partition_vals = None
|
|
116
131
|
|
|
117
132
|
self.id = self._create_id()
|
|
118
133
|
|
|
@@ -120,26 +135,37 @@ class FunctionCall(Expr):
|
|
|
120
135
|
target = tbl._tbl_version_path.tbl_version
|
|
121
136
|
return [RowidRef(target, i) for i in range(target.num_rowid_columns())]
|
|
122
137
|
|
|
138
|
+
def default_column_name(self) -> Optional[str]:
|
|
139
|
+
if self.fn.is_property:
|
|
140
|
+
return self.fn.name
|
|
141
|
+
return super().default_column_name()
|
|
142
|
+
|
|
123
143
|
@classmethod
|
|
124
|
-
def
|
|
125
|
-
"""
|
|
144
|
+
def normalize_args(cls, signature: func.Signature, bound_args: Dict[str, Any]) -> None:
|
|
145
|
+
"""Converts all args to Exprs and checks that they are compatible with signature.
|
|
126
146
|
|
|
127
|
-
|
|
147
|
+
Updates bound_args in place, where necessary.
|
|
128
148
|
"""
|
|
129
149
|
for param_name, arg in bound_args.items():
|
|
130
150
|
param = signature.parameters[param_name]
|
|
151
|
+
is_var_param = param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
|
|
152
|
+
|
|
131
153
|
if isinstance(arg, dict):
|
|
132
154
|
try:
|
|
133
155
|
arg = InlineDict(arg)
|
|
134
156
|
bound_args[param_name] = arg
|
|
157
|
+
continue
|
|
135
158
|
except excs.Error:
|
|
136
159
|
# this didn't work, but it might be a literal
|
|
137
160
|
pass
|
|
161
|
+
|
|
138
162
|
if isinstance(arg, list) or isinstance(arg, tuple):
|
|
139
163
|
try:
|
|
140
164
|
# If the column type is JsonType, force the literal to be JSON
|
|
141
|
-
|
|
165
|
+
is_json = is_var_param or (param.col_type is not None and param.col_type.is_json_type())
|
|
166
|
+
arg = InlineArray(arg, force_json=is_json)
|
|
142
167
|
bound_args[param_name] = arg
|
|
168
|
+
continue
|
|
143
169
|
except excs.Error:
|
|
144
170
|
# this didn't work, but it might be a literal
|
|
145
171
|
pass
|
|
@@ -149,30 +175,35 @@ class FunctionCall(Expr):
|
|
|
149
175
|
try:
|
|
150
176
|
_ = json.dumps(arg)
|
|
151
177
|
except TypeError:
|
|
152
|
-
raise excs.Error(f
|
|
178
|
+
raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg}')
|
|
153
179
|
if arg is not None:
|
|
154
180
|
try:
|
|
155
181
|
param_type = param.col_type
|
|
156
182
|
bound_args[param_name] = param_type.create_literal(arg)
|
|
157
183
|
except TypeError as e:
|
|
158
184
|
msg = str(e)
|
|
159
|
-
raise excs.Error(f
|
|
185
|
+
raise excs.Error(f'Argument for parameter {param_name!r}: {msg[0].lower() + msg[1:]}')
|
|
160
186
|
continue
|
|
161
187
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
188
|
+
# these checks break the db migration test, because InlineArray isn't serialized correctly (it looses
|
|
189
|
+
# the type information)
|
|
190
|
+
# if is_var_param:
|
|
191
|
+
# if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
192
|
+
# if not isinstance(arg, InlineArray) or not arg.col_type.is_json_type():
|
|
193
|
+
# pass
|
|
194
|
+
# assert isinstance(arg, InlineArray), type(arg)
|
|
195
|
+
# assert arg.col_type.is_json_type()
|
|
196
|
+
# if param.kind == inspect.Parameter.VAR_KEYWORD:
|
|
197
|
+
# if not isinstance(arg, InlineDict):
|
|
198
|
+
# pass
|
|
199
|
+
# assert isinstance(arg, InlineDict), type(arg)
|
|
200
|
+
if is_var_param:
|
|
201
|
+
pass
|
|
202
|
+
else:
|
|
203
|
+
if not param.col_type.is_supertype_of(arg.col_type):
|
|
204
|
+
raise excs.Error(
|
|
205
|
+
f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
|
|
206
|
+
f'{param.col_type}')
|
|
176
207
|
|
|
177
208
|
def _equals(self, other: FunctionCall) -> bool:
|
|
178
209
|
if self.fn != other.fn:
|