pixeltable 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/table.py +247 -83
- pixeltable/catalog/view.py +5 -2
- pixeltable/dataframe.py +240 -92
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/exec_node.py +6 -7
- pixeltable/exec/sql_node.py +91 -44
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +1 -1
- pixeltable/exprs/array_slice.py +1 -1
- pixeltable/exprs/column_property_ref.py +1 -1
- pixeltable/exprs/column_ref.py +29 -2
- pixeltable/exprs/comparison.py +1 -1
- pixeltable/exprs/compound_predicate.py +1 -1
- pixeltable/exprs/expr.py +11 -5
- pixeltable/exprs/expr_set.py +8 -0
- pixeltable/exprs/function_call.py +14 -11
- pixeltable/exprs/in_predicate.py +1 -1
- pixeltable/exprs/inline_expr.py +3 -3
- pixeltable/exprs/is_null.py +1 -1
- pixeltable/exprs/json_mapper.py +1 -1
- pixeltable/exprs/json_path.py +1 -1
- pixeltable/exprs/method_ref.py +1 -1
- pixeltable/exprs/rowid_ref.py +1 -1
- pixeltable/exprs/similarity_expr.py +4 -1
- pixeltable/exprs/sql_element_cache.py +4 -0
- pixeltable/exprs/type_cast.py +2 -2
- pixeltable/exprs/variable.py +3 -0
- pixeltable/func/expr_template_function.py +3 -0
- pixeltable/func/function.py +37 -1
- pixeltable/func/signature.py +1 -0
- pixeltable/functions/mistralai.py +0 -2
- pixeltable/functions/ollama.py +4 -4
- pixeltable/globals.py +32 -18
- pixeltable/index/embedding_index.py +6 -1
- pixeltable/io/__init__.py +1 -1
- pixeltable/io/parquet.py +39 -19
- pixeltable/iterators/__init__.py +1 -0
- pixeltable/iterators/image.py +100 -0
- pixeltable/iterators/video.py +7 -8
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_22.py +17 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/plan.py +129 -51
- pixeltable/store.py +1 -1
- pixeltable/tool/create_test_db_dump.py +4 -1
- pixeltable/type_system.py +1 -1
- pixeltable/utils/arrow.py +8 -3
- pixeltable/utils/description_helper.py +89 -0
- {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/METADATA +28 -12
- {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/RECORD +54 -51
- {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/WHEEL +1 -1
- {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/entry_points.txt +0 -0
pixeltable/dataframe.py
CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
 
 import builtins
 import copy
+import dataclasses
 import hashlib
 import json
 import logging
-import mimetypes
 import traceback
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union, Literal
 
 import pandas as pd
 import pandas.io.formats.style
@@ -17,14 +17,15 @@ import sqlalchemy as sql
 import pixeltable.catalog as catalog
 import pixeltable.exceptions as excs
 import pixeltable.exprs as exprs
+import pixeltable.type_system as ts
 from pixeltable import exec
+from pixeltable import plan
 from pixeltable.catalog import is_valid_identifier
 from pixeltable.catalog.globals import UpdateStatus
 from pixeltable.env import Env
-from pixeltable.plan import Planner
 from pixeltable.type_system import ColumnType
+from pixeltable.utils.description_helper import DescriptionHelper
 from pixeltable.utils.formatter import Formatter
-from pixeltable.utils.http_server import get_file_uri
 
 if TYPE_CHECKING:
     import torch
@@ -131,9 +132,19 @@ class DataFrameResultSet:
 
 
 class DataFrame:
+    _from_clause: plan.FromClause
+    _select_list_exprs: list[exprs.Expr]
+    _schema: dict[str, ts.ColumnType]
+    select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]]
+    where_clause: Optional[exprs.Expr]
+    group_by_clause: Optional[list[exprs.Expr]]
+    grouping_tbl: Optional[catalog.TableVersion]
+    order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
+    limit_val: Optional[int]
+
     def __init__(
         self,
-        …
+        from_clause: Optional[plan.FromClause] = None,
         select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]] = None,
         where_clause: Optional[exprs.Expr] = None,
         group_by_clause: Optional[list[exprs.Expr]] = None,
@@ -141,14 +152,11 @@ class DataFrame:
         order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,  # list[(expr, asc)]
         limit: Optional[int] = None,
     ):
-        self.…
+        self._from_clause = from_clause
 
-        # select list logic
-        DataFrame._select_list_check_rep(select_list)  # check select list without expansion
         # exprs contain execution state and therefore cannot be shared
         select_list = copy.deepcopy(select_list)
-        select_list_exprs, column_names = DataFrame._normalize_select_list(…
-        DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))
+        select_list_exprs, column_names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
         # check select list after expansion to catch early
         # the following two lists are always non empty, even if select list is None.
         assert len(column_names) == len(select_list_exprs)
@@ -163,28 +171,10 @@ class DataFrame:
         self.order_by_clause = copy.deepcopy(order_by_clause)
         self.limit_val = limit
 
-    @classmethod
-    def _select_list_check_rep(
-        cls,
-        select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
-    ) -> None:
-        """Validate basic select list types."""
-        if select_list is None:  # basic check for valid select list
-            return
-
-        assert len(select_list) > 0
-        for ent in select_list:
-            assert isinstance(ent, tuple)
-            assert len(ent) == 2
-            assert isinstance(ent[0], exprs.Expr)
-            assert ent[1] is None or isinstance(ent[1], str)
-            if isinstance(ent[1], str):
-                assert is_valid_identifier(ent[1])
-
     @classmethod
     def _normalize_select_list(
         cls,
-        …
+        tbls: list[catalog.TableVersionPath],
         select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
     ) -> tuple[list[exprs.Expr], list[str]]:
         """
@@ -193,7 +183,7 @@ class DataFrame:
         a pair composed of the list of expressions and the list of corresponding names
         """
         if select_list is None:
-            select_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
+            select_list = [(exprs.ColumnRef(col), None) for tbl in tbls for col in tbl.columns()]
 
         out_exprs: list[exprs.Expr] = []
         out_names: list[str] = []  # keep track of order
@@ -222,6 +212,11 @@ class DataFrame:
         assert set(out_names) == seen_out_names
         return out_exprs, out_names
 
+    @property
+    def _first_tbl(self) -> catalog.TableVersionPath:
+        assert len(self._from_clause.tbls) == 1
+        return self._from_clause.tbls[0]
+
     def _vars(self) -> dict[str, exprs.Variable]:
         """
         Return a dict mapping variable name to Variable for all Variables contained in any component of the DataFrame
@@ -280,16 +275,16 @@ class DataFrame:
             assert self.group_by_clause is None
             num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
             # the grouping table must be a base of self.tbl
-            assert num_rowid_cols <= len(self.…
-            group_by_clause = [exprs.RowidRef(self.…
+            assert num_rowid_cols <= len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
+            group_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
         elif self.group_by_clause is not None:
             group_by_clause = self.group_by_clause
 
         for item in self._select_list_exprs:
             item.bind_rel_paths(None)
 
-        return Planner.create_query_plan(
-            self.…
+        return plan.Planner.create_query_plan(
+            self._from_clause,
             self._select_list_exprs,
             where_clause=self.where_clause,
             group_by_clause=group_by_clause,
@@ -297,6 +292,8 @@ class DataFrame:
             limit=self.limit_val
         )
 
+    def _has_joins(self) -> bool:
+        return len(self._from_clause.join_clauses) > 0
 
     def show(self, n: int = 20) -> DataFrameResultSet:
         assert n is not None
@@ -305,15 +302,19 @@ class DataFrame:
     def head(self, n: int = 10) -> DataFrameResultSet:
         if self.order_by_clause is not None:
            raise excs.Error(f'head() cannot be used with order_by()')
-        …
-        …
+        if self._has_joins():
+            raise excs.Error(f'head() not supported for joins')
+        num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
+        order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
         return self.order_by(*order_by_clause, asc=True).limit(n).collect()
 
     def tail(self, n: int = 10) -> DataFrameResultSet:
         if self.order_by_clause is not None:
             raise excs.Error(f'tail() cannot be used with order_by()')
-        …
-        …
+        if self._has_joins():
+            raise excs.Error(f'tail() not supported for joins')
+        num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
+        order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
         result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
         result._reverse()
         return result
@@ -359,7 +360,7 @@ class DataFrame:
         ]
 
         return DataFrame(
-            self.…
+            from_clause=self._from_clause, select_list=select_list, where_clause=where_clause,
             group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
             order_by_clause=order_by_clause, limit=self.limit_val)
 
@@ -395,28 +396,42 @@ class DataFrame:
     def count(self) -> int:
         from pixeltable.plan import Planner
 
-        stmt = Planner.create_count_stmt(self.…
+        stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
         with Env.get().engine.connect() as conn:
             result: int = conn.execute(stmt).scalar_one()
             assert isinstance(result, int)
             return result
 
-    def …
-    …
+    def _descriptors(self) -> DescriptionHelper:
+        helper = DescriptionHelper()
+        helper.append(self._col_descriptor())
+        qd = self._query_descriptor()
+        if not qd.empty:
+            helper.append(qd, show_index=True, show_header=False)
+        return helper
+
+    def _col_descriptor(self) -> pd.DataFrame:
+        return pd.DataFrame([
+            {
+                'Name': name,
+                'Type': expr.col_type._to_str(as_schema=True),
+                'Expression': expr.display_str(inline=False),
+            }
+            for name, expr in zip(self.schema.keys(), self._select_list_exprs)
+        ])
+
+    def _query_descriptor(self) -> pd.DataFrame:
         heading_vals: list[str] = []
         info_vals: list[str] = []
-        …
-        …
-        heading_vals.append('Select')
-        heading_vals.extend([''] * (len(self.select_list) - 1))
-        info_vals.extend(self.schema.keys())
+        heading_vals.append('From')
+        info_vals.extend(tbl.tbl_name() for tbl in self._from_clause.tbls)
         if self.where_clause is not None:
             heading_vals.append('Where')
             info_vals.append(self.where_clause.display_str(inline=False))
         if self.group_by_clause is not None:
             heading_vals.append('Group By')
             heading_vals.extend([''] * (len(self.group_by_clause) - 1))
-            info_vals.extend(…
+            info_vals.extend(e.display_str(inline=False) for e in self.group_by_clause)
         if self.order_by_clause is not None:
             heading_vals.append('Order By')
             heading_vals.extend([''] * (len(self.order_by_clause) - 1))
@@ -426,22 +441,8 @@ class DataFrame:
         if self.limit_val is not None:
             heading_vals.append('Limit')
             info_vals.append(str(self.limit_val))
-        assert len(heading_vals) > 0
-        assert len(info_vals) > 0
         assert len(heading_vals) == len(info_vals)
-        return pd.DataFrame(…
-
-    def _description_html(self) -> pandas.io.formats.style.Styler:
-        """Return the description in an ipython-friendly manner."""
-        pd_df = self._description()
-        # white-space: pre-wrap: print \n as newline
-        # th: center-align headings
-        return (
-            pd_df.style.set_properties(None, **{'white-space': 'pre-wrap', 'text-align': 'left'})
-            .set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
-            .hide(axis='index')
-            .hide(axis='columns')
-        )
+        return pd.DataFrame(info_vals, index=heading_vals)
 
     def describe(self) -> None:
         """
@@ -451,15 +452,15 @@ class DataFrame:
         """
         if getattr(builtins, '__IPYTHON__', False):
             from IPython.display import display
-            display(self.…
+            display(self._repr_html_())
         else:
-            print(self…
+            print(repr(self))
 
     def __repr__(self) -> str:
-        return self.…
+        return self._descriptors().to_string()
 
     def _repr_html_(self) -> str:
-        return self.…
+        return self._descriptors().to_html()
 
     def select(self, *items: Any, **named_items: Any) -> DataFrame:
         if self.select_list is not None:
@@ -472,7 +473,7 @@ class DataFrame:
             return self
 
         # analyze select list; wrap literals with the corresponding expressions
-        select_list = []
+        select_list: list[tuple[exprs.Expr, Optional[str]]] = []
        for raw_expr, name in base_list:
             if isinstance(raw_expr, exprs.Expr):
                 select_list.append((raw_expr, name))
@@ -485,12 +486,14 @@ class DataFrame:
             expr = select_list[-1][0]
             if expr.col_type.is_invalid_type():
                 raise excs.Error(f'Invalid type: {raw_expr}')
-            …
+            if not expr.is_bound_by(self._from_clause.tbls):
+                raise excs.Error(
+                    f"Expression '{expr}' cannot be evaluated in the context of this query's tables "
+                    f"({','.join(tbl.tbl_name() for tbl in self._from_clause.tbls)})")
 
-        # check user provided names do not conflict among themselves
-        # or with auto-generated ones
+        # check user provided names do not conflict among themselves or with auto-generated ones
         seen: set[str] = set()
-        _, names = DataFrame._normalize_select_list(self.…
+        _, names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
         for name in names:
             if name in seen:
                 repeated_names = [j for j, x in enumerate(names) if x == name]
@@ -499,7 +502,7 @@ class DataFrame:
             seen.add(name)
 
         return DataFrame(
-            self.…
+            from_clause=self._from_clause,
             select_list=select_list,
             where_clause=self.where_clause,
             group_by_clause=self.group_by_clause,
@@ -514,7 +517,7 @@ class DataFrame:
         if not pred.col_type.is_bool_type():
             raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
         return DataFrame(
-            self.…
+            from_clause=self._from_clause,
             select_list=self.select_list,
             where_clause=pred,
             group_by_clause=self.group_by_clause,
@@ -523,8 +526,144 @@ class DataFrame:
             limit=self.limit_val,
         )
 
+    def _create_join_predicate(
+        self, other: catalog.TableVersionPath, on: Union[exprs.Expr, Sequence[exprs.ColumnRef]]
+    ) -> exprs.Expr:
+        """Verifies user-specified 'on' argument and converts it into a join predicate."""
+        col_refs: list[exprs.ColumnRef] = []
+        joined_tbls = self._from_clause.tbls + [other]
+
+        if isinstance(on, exprs.ColumnRef):
+            on = [on]
+        elif isinstance(on, exprs.Expr):
+            if not on.is_bound_by(joined_tbls):
+                raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {on}")
+            if not on.col_type.is_bool_type():
+                raise excs.Error(f"'on': boolean expression expected, but got {on.col_type}: {on}")
+            return on
+        else:
+            if not isinstance(on, Sequence) or len(on) == 0:
+                raise excs.Error(
+                    f"'on': must be a sequence of column references or a boolean expression")
+
+        assert isinstance(on, Sequence)
+        for col_ref in on:
+            if not isinstance(col_ref, exprs.ColumnRef):
+                raise excs.Error(
+                    f"'on': must be a sequence of column references or a boolean expression")
+            if not col_ref.is_bound_by(joined_tbls):
+                raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {col_ref}")
+            col_refs.append(col_ref)
+
+        predicates: list[exprs.Expr] = []
+        # try to turn ColumnRefs into equality predicates
+        assert len(col_refs) > 0 and len(joined_tbls) >= 2
+        for col_ref in col_refs:
+            # identify the referenced column by name in 'other'
+            rhs_col = other.get_column(col_ref.col.name, include_bases=True)
+            if rhs_col is None:
+                raise excs.Error(f"'on': column {col_ref.col.name!r} not found in joined table")
+            rhs_col_ref = exprs.ColumnRef(rhs_col)
+
+            lhs_col_ref: Optional[exprs.ColumnRef] = None
+            if any(tbl.has_column(col_ref.col, include_bases=True) for tbl in self._from_clause.tbls):
+                # col_ref comes from the existing from_clause, we use that directly
+                lhs_col_ref = col_ref
+            else:
+                # col_ref comes from other, we need to look for a match in the existing from_clause by name
+                for tbl in self._from_clause.tbls:
+                    col = tbl.get_column(col_ref.col.name, include_bases=True)
+                    if col is None:
+                        continue
+                    if lhs_col_ref is not None:
+                        raise excs.Error(f"'on': ambiguous column reference: {col_ref.col.name!r}")
+                    lhs_col_ref = exprs.ColumnRef(col)
+                if lhs_col_ref is None:
+                    tbl_names = [tbl.tbl_name() for tbl in self._from_clause.tbls]
+                    raise excs.Error(
+                        f"'on': column {col_ref.col.name!r} not found in any of: {' '.join(tbl_names)}")
+            pred = exprs.Comparison(exprs.ComparisonOperator.EQ, lhs_col_ref, rhs_col_ref)
+            predicates.append(pred)
+
+        assert len(predicates) > 0
+        if len(predicates) == 1:
+            return predicates[0]
+        else:
+            return exprs.CompoundPredicate(operator=exprs.LogicalOperator.AND, operands=predicates)
+
+    def join(
+        self, other: catalog.Table, on: Optional[Union[exprs.Expr, Sequence[exprs.ColumnRef]]] = None,
+        how: plan.JoinType.LiteralType = 'inner'
+    ) -> DataFrame:
+        """
+        Join this DataFrame with a table.
+
+        Args:
+            other: the table to join with
+            on: the join condition, which can be either a) references to one or more columns or b) a boolean
+                expression.
+
+                - column references: implies an equality predicate that matches columns in both this
+                  DataFrame and `other` by name.
+
+                  - column in `other`: A column with that same name must be present in this DataFrame, and **it must
+                    be unique** (otherwise the join is ambiguous).
+                  - column in this DataFrame: A column with that same name must be present in `other`.
+
+                - boolean expression: The expressions must be valid in the context of the joined tables.
+            how: the type of join to perform.
+
+                - `'inner'`: only keep rows that have a match in both
+                - `'left'`: keep all rows from this DataFrame and only matching rows from the other table
+                - `'right'`: keep all rows from the other table and only matching rows from this DataFrame
+                - `'full_outer'`: keep all rows from both this DataFrame and the other table
+                - `'cross'`: Cartesian product; no `on` condition allowed
+
+        Returns:
+            A new DataFrame.
+
+        Examples:
+            Perform an inner join between t1 and t2 on the column id:
+
+            >>> join1 = t1.join(t2, on=t2.id)
+
+            Perform a left outer join of join1 with t3, also on id (note that we can't specify `on=t3.id` here,
+            because that would be ambiguous, since both t1 and t2 have a column named id):
+
+            >>> join2 = join1.join(t3, on=t2.id, how='left')
+
+            Do the same, but now with an explicit join predicate:
+
+            >>> join2 = join1.join(t3, on=t2.id == t3.id, how='left')
+
+            Join t with d, which has a composite primary key (columns pk1 and pk2, with corresponding foreign
+            key columns d1 and d2 in t):
+
+            >>> df = t.join(d, on=(t.d1 == d.pk1) & (t.d2 == d.pk2), how='left')
+        """
+        join_pred: Optional[exprs.Expr]
+        if how == 'cross':
+            if on is not None:
+                raise excs.Error(f"'on' not allowed for cross join")
+            join_pred = None
+        else:
+            if on is None:
+                raise excs.Error(f"how={how!r} requires 'on'")
+            join_pred = self._create_join_predicate(other._tbl_version_path, on)
+        join_clause = plan.JoinClause(join_type=plan.JoinType.validated(how, "'how'"), join_predicate=join_pred)
+        from_clause = plan.FromClause(
+            tbls=[*self._from_clause.tbls, other._tbl_version_path],
+            join_clauses=[*self._from_clause.join_clauses, join_clause])
+        return DataFrame(
+            from_clause=from_clause,
+            select_list=self.select_list, where_clause=self.where_clause,
+            group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl,
+            order_by_clause=self.order_by_clause, limit=self.limit_val,
+        )
+
     def group_by(self, *grouping_items: Any) -> DataFrame:
-        """
+        """
+        Add a group-by clause to this DataFrame.
         Variants:
         - group_by(<base table>): group a component view by their respective base table rows
         - group_by(<expr>, ...): group by the given expressions
@@ -537,10 +676,12 @@ class DataFrame:
             if isinstance(item, catalog.Table):
                 if len(grouping_items) > 1:
                     raise excs.Error(f'group_by(): only one table can be specified')
+                if len(self._from_clause.tbls) > 1:
+                    raise excs.Error(f'group_by() with Table not supported for joins')
                 # we need to make sure that the grouping table is a base of self.tbl
-                base = self.…
-                if base is None or base.id == self.…
-                    raise excs.Error(f'group_by(): {item._name} is not a base table of {self.…
+                base = self._first_tbl.find_tbl_version(item._tbl_version_path.tbl_id())
+                if base is None or base.id == self._first_tbl.tbl_id():
+                    raise excs.Error(f'group_by(): {item._name} is not a base table of {self._first_tbl.tbl_name()}')
                 grouping_tbl = item._tbl_version_path.tbl_version
                 break
             if not isinstance(item, exprs.Expr):
@@ -548,7 +689,7 @@ class DataFrame:
         if grouping_tbl is None:
             group_by_clause = list(grouping_items)
         return DataFrame(
-            self.…
+            from_clause=self._from_clause,
             select_list=self.select_list,
             where_clause=self.where_clause,
             group_by_clause=group_by_clause,
@@ -564,7 +705,7 @@ class DataFrame:
         order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
         order_by_clause.extend([(e.copy(), asc) for e in expr_list])
         return DataFrame(
-            self.…
+            from_clause=self._from_clause,
             select_list=self.select_list,
             where_clause=self.where_clause,
             group_by_clause=self.group_by_clause,
@@ -577,7 +718,7 @@ class DataFrame:
         # TODO: allow n to be a Variable that can be substituted in bind()
         assert n is not None and isinstance(n, int)
         return DataFrame(
-            self.…
+            from_clause=self._from_clause,
             select_list=self.select_list,
             where_clause=self.where_clause,
             group_by_clause=self.group_by_clause,
@@ -588,13 +729,13 @@ class DataFrame:
 
     def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
         self._validate_mutable('update')
-        return self.…
+        return self._first_tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
 
     def delete(self) -> UpdateStatus:
         self._validate_mutable('delete')
-        if not self.…
+        if not self._first_tbl.is_insertable():
             raise excs.Error(f'Cannot delete from view')
-        return self.…
+        return self._first_tbl.tbl_version.delete(where=self.where_clause)
 
     def _validate_mutable(self, op_name: str) -> None:
         """Tests whether this `DataFrame` can be mutated (such as by an update operation)."""
@@ -624,10 +765,12 @@ class DataFrame:
         Returns:
             Dictionary representing this dataframe.
         """
-        tbl_versions = self.tbl.get_tbl_versions()
         d = {
             '_classname': 'DataFrame',
-            '…
+            'from_clause': {
+                'tbls': [tbl.as_dict() for tbl in self._from_clause.tbls],
+                'join_clauses': [dataclasses.asdict(clause) for clause in self._from_clause.join_clauses]
+            },
             'select_list':
                 [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
             'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
@@ -642,7 +785,9 @@ class DataFrame:
 
     @classmethod
     def from_dict(cls, d: dict[str, Any]) -> 'DataFrame':
-        …
+        tbls = [catalog.TableVersionPath.from_dict(tbl_dict) for tbl_dict in d['from_clause']['tbls']]
+        join_clauses = [plan.JoinClause(**clause_dict) for clause_dict in d['from_clause']['join_clauses']]
+        from_clause = plan.FromClause(tbls=tbls, join_clauses=join_clauses)
         select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
             if d['select_list'] is not None else None
         where_clause = exprs.Expr.from_dict(d['where_clause']) \
@@ -655,15 +800,18 @@ class DataFrame:
             if d['order_by_clause'] is not None else None
         limit_val = d['limit_val']
         return DataFrame(
-            …
-            grouping_tbl=grouping_tbl, order_by_clause=order_by_clause,
+            from_clause=from_clause, select_list=select_list, where_clause=where_clause,
+            group_by_clause=group_by_clause, grouping_tbl=grouping_tbl, order_by_clause=order_by_clause,
+            limit=limit_val)
 
     def _hash_result_set(self) -> str:
         """Return a hash that changes when the result set changes."""
         d = self.as_dict()
         # add list of referenced table versions (the actual versions, not the effective ones) in order to force cache
         # invalidation when any of the referenced tables changes
-        d['tbl_versions'] = […
+        d['tbl_versions'] = [
+            tbl_version.version for tbl in self._from_clause.tbls for tbl_version in tbl.get_tbl_versions()
+        ]
         summary_string = json.dumps(d)
         return hashlib.sha256(summary_string.encode()).hexdigest()
 
@@ -732,7 +880,7 @@ class DataFrame:
         Env.get().require_package('torch')
         Env.get().require_package('torchvision')
 
-        from pixeltable.io…
+        from pixeltable.io import export_parquet
         from pixeltable.utils.pytorch import PixeltablePytorchDataset
 
         cache_key = self._hash_result_set()
@@ -741,6 +889,6 @@ class DataFrame:
         if dest_path.exists():  # fast path: use cache
             assert dest_path.is_dir()
         else:
-            …
+            export_parquet(self, dest_path, inline_images=True)
 
         return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
pixeltable/exec/__init__.py
CHANGED
@@ -7,4 +7,4 @@ from .exec_node import ExecNode
 from .expr_eval_node import ExprEvalNode
 from .in_memory_data_node import InMemoryDataNode
 from .row_update_node import RowUpdateNode
-from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
+from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode, SqlJoinNode
pixeltable/exec/exec_node.py
CHANGED
@@ -1,15 +1,13 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Iterable, Iterator, Optional
+from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar
 
 import pixeltable.exprs as exprs
 
 from .data_row_batch import DataRowBatch
 from .exec_context import ExecContext
 
-if TYPE_CHECKING:
-    from pixeltable import exec
 
 class ExecNode(abc.ABC):
     """Base class of all execution nodes"""
@@ -77,12 +75,13 @@ class ExecNode(abc.ABC):
     def _close(self) -> None:
         pass
 
-    …
-    …
-    …
+    T = TypeVar('T', bound='ExecNode')
+
+    def get_node(self, node_class: type[T]) -> Optional[T]:
+        if isinstance(self, node_class):
             return self
         if self.input is not None:
-            return self.input.…
+            return self.input.get_node(node_class)
         return None
 
     def set_limit(self, limit: int) -> None: