pixeltable 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/table.py +118 -44
  3. pixeltable/catalog/view.py +2 -2
  4. pixeltable/dataframe.py +240 -92
  5. pixeltable/exec/__init__.py +1 -1
  6. pixeltable/exec/exec_node.py +6 -7
  7. pixeltable/exec/sql_node.py +91 -44
  8. pixeltable/exprs/__init__.py +1 -0
  9. pixeltable/exprs/arithmetic_expr.py +1 -1
  10. pixeltable/exprs/array_slice.py +1 -1
  11. pixeltable/exprs/column_property_ref.py +1 -1
  12. pixeltable/exprs/column_ref.py +29 -2
  13. pixeltable/exprs/comparison.py +1 -1
  14. pixeltable/exprs/compound_predicate.py +1 -1
  15. pixeltable/exprs/expr.py +11 -5
  16. pixeltable/exprs/expr_set.py +8 -0
  17. pixeltable/exprs/function_call.py +14 -11
  18. pixeltable/exprs/in_predicate.py +1 -1
  19. pixeltable/exprs/inline_expr.py +3 -3
  20. pixeltable/exprs/is_null.py +1 -1
  21. pixeltable/exprs/json_mapper.py +1 -1
  22. pixeltable/exprs/json_path.py +1 -1
  23. pixeltable/exprs/method_ref.py +1 -1
  24. pixeltable/exprs/rowid_ref.py +1 -1
  25. pixeltable/exprs/similarity_expr.py +1 -1
  26. pixeltable/exprs/sql_element_cache.py +4 -0
  27. pixeltable/exprs/type_cast.py +2 -2
  28. pixeltable/exprs/variable.py +3 -0
  29. pixeltable/func/expr_template_function.py +3 -0
  30. pixeltable/functions/ollama.py +4 -4
  31. pixeltable/globals.py +4 -1
  32. pixeltable/io/__init__.py +1 -1
  33. pixeltable/io/parquet.py +39 -19
  34. pixeltable/metadata/__init__.py +1 -1
  35. pixeltable/metadata/converters/convert_22.py +17 -0
  36. pixeltable/metadata/notes.py +1 -0
  37. pixeltable/plan.py +128 -50
  38. pixeltable/store.py +1 -1
  39. pixeltable/type_system.py +1 -1
  40. pixeltable/utils/arrow.py +8 -3
  41. pixeltable/utils/description_helper.py +89 -0
  42. {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/METADATA +26 -10
  43. {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/RECORD +46 -44
  44. {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/WHEEL +1 -1
  45. {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/LICENSE +0 -0
  46. {pixeltable-0.2.25.dist-info → pixeltable-0.2.26.dist-info}/entry_points.txt +0 -0
pixeltable/dataframe.py CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  import builtins
4
4
  import copy
5
+ import dataclasses
5
6
  import hashlib
6
7
  import json
7
8
  import logging
8
- import mimetypes
9
9
  import traceback
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union
11
+ from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union, Literal
12
12
 
13
13
  import pandas as pd
14
14
  import pandas.io.formats.style
@@ -17,14 +17,15 @@ import sqlalchemy as sql
17
17
  import pixeltable.catalog as catalog
18
18
  import pixeltable.exceptions as excs
19
19
  import pixeltable.exprs as exprs
20
+ import pixeltable.type_system as ts
20
21
  from pixeltable import exec
22
+ from pixeltable import plan
21
23
  from pixeltable.catalog import is_valid_identifier
22
24
  from pixeltable.catalog.globals import UpdateStatus
23
25
  from pixeltable.env import Env
24
- from pixeltable.plan import Planner
25
26
  from pixeltable.type_system import ColumnType
27
+ from pixeltable.utils.description_helper import DescriptionHelper
26
28
  from pixeltable.utils.formatter import Formatter
27
- from pixeltable.utils.http_server import get_file_uri
28
29
 
29
30
  if TYPE_CHECKING:
30
31
  import torch
@@ -131,9 +132,19 @@ class DataFrameResultSet:
131
132
 
132
133
 
133
134
  class DataFrame:
135
+ _from_clause: plan.FromClause
136
+ _select_list_exprs: list[exprs.Expr]
137
+ _schema: dict[str, ts.ColumnType]
138
+ select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]]
139
+ where_clause: Optional[exprs.Expr]
140
+ group_by_clause: Optional[list[exprs.Expr]]
141
+ grouping_tbl: Optional[catalog.TableVersion]
142
+ order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
143
+ limit_val: Optional[int]
144
+
134
145
  def __init__(
135
146
  self,
136
- tbl: catalog.TableVersionPath,
147
+ from_clause: Optional[plan.FromClause] = None,
137
148
  select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]] = None,
138
149
  where_clause: Optional[exprs.Expr] = None,
139
150
  group_by_clause: Optional[list[exprs.Expr]] = None,
@@ -141,14 +152,11 @@ class DataFrame:
141
152
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None, # list[(expr, asc)]
142
153
  limit: Optional[int] = None,
143
154
  ):
144
- self.tbl = tbl
155
+ self._from_clause = from_clause
145
156
 
146
- # select list logic
147
- DataFrame._select_list_check_rep(select_list) # check select list without expansion
148
157
  # exprs contain execution state and therefore cannot be shared
149
158
  select_list = copy.deepcopy(select_list)
150
- select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
151
- DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))
159
+ select_list_exprs, column_names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
152
160
  # check select list after expansion to catch early
153
161
  # the following two lists are always non empty, even if select list is None.
154
162
  assert len(column_names) == len(select_list_exprs)
@@ -163,28 +171,10 @@ class DataFrame:
163
171
  self.order_by_clause = copy.deepcopy(order_by_clause)
164
172
  self.limit_val = limit
165
173
 
166
- @classmethod
167
- def _select_list_check_rep(
168
- cls,
169
- select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
170
- ) -> None:
171
- """Validate basic select list types."""
172
- if select_list is None: # basic check for valid select list
173
- return
174
-
175
- assert len(select_list) > 0
176
- for ent in select_list:
177
- assert isinstance(ent, tuple)
178
- assert len(ent) == 2
179
- assert isinstance(ent[0], exprs.Expr)
180
- assert ent[1] is None or isinstance(ent[1], str)
181
- if isinstance(ent[1], str):
182
- assert is_valid_identifier(ent[1])
183
-
184
174
  @classmethod
185
175
  def _normalize_select_list(
186
176
  cls,
187
- tbl: catalog.TableVersionPath,
177
+ tbls: list[catalog.TableVersionPath],
188
178
  select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
189
179
  ) -> tuple[list[exprs.Expr], list[str]]:
190
180
  """
@@ -193,7 +183,7 @@ class DataFrame:
193
183
  a pair composed of the list of expressions and the list of corresponding names
194
184
  """
195
185
  if select_list is None:
196
- select_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
186
+ select_list = [(exprs.ColumnRef(col), None) for tbl in tbls for col in tbl.columns()]
197
187
 
198
188
  out_exprs: list[exprs.Expr] = []
199
189
  out_names: list[str] = [] # keep track of order
@@ -222,6 +212,11 @@ class DataFrame:
222
212
  assert set(out_names) == seen_out_names
223
213
  return out_exprs, out_names
224
214
 
215
+ @property
216
+ def _first_tbl(self) -> catalog.TableVersionPath:
217
+ assert len(self._from_clause.tbls) == 1
218
+ return self._from_clause.tbls[0]
219
+
225
220
  def _vars(self) -> dict[str, exprs.Variable]:
226
221
  """
227
222
  Return a dict mapping variable name to Variable for all Variables contained in any component of the DataFrame
@@ -280,16 +275,16 @@ class DataFrame:
280
275
  assert self.group_by_clause is None
281
276
  num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
282
277
  # the grouping table must be a base of self.tbl
283
- assert num_rowid_cols <= len(self.tbl.tbl_version.store_tbl.rowid_columns())
284
- group_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
278
+ assert num_rowid_cols <= len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
279
+ group_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
285
280
  elif self.group_by_clause is not None:
286
281
  group_by_clause = self.group_by_clause
287
282
 
288
283
  for item in self._select_list_exprs:
289
284
  item.bind_rel_paths(None)
290
285
 
291
- return Planner.create_query_plan(
292
- self.tbl,
286
+ return plan.Planner.create_query_plan(
287
+ self._from_clause,
293
288
  self._select_list_exprs,
294
289
  where_clause=self.where_clause,
295
290
  group_by_clause=group_by_clause,
@@ -297,6 +292,8 @@ class DataFrame:
297
292
  limit=self.limit_val
298
293
  )
299
294
 
295
+ def _has_joins(self) -> bool:
296
+ return len(self._from_clause.join_clauses) > 0
300
297
 
301
298
  def show(self, n: int = 20) -> DataFrameResultSet:
302
299
  assert n is not None
@@ -305,15 +302,19 @@ class DataFrame:
305
302
  def head(self, n: int = 10) -> DataFrameResultSet:
306
303
  if self.order_by_clause is not None:
307
304
  raise excs.Error(f'head() cannot be used with order_by()')
308
- num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
309
- order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
305
+ if self._has_joins():
306
+ raise excs.Error(f'head() not supported for joins')
307
+ num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
308
+ order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
310
309
  return self.order_by(*order_by_clause, asc=True).limit(n).collect()
311
310
 
312
311
  def tail(self, n: int = 10) -> DataFrameResultSet:
313
312
  if self.order_by_clause is not None:
314
313
  raise excs.Error(f'tail() cannot be used with order_by()')
315
- num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
316
- order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
314
+ if self._has_joins():
315
+ raise excs.Error(f'tail() not supported for joins')
316
+ num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
317
+ order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
317
318
  result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
318
319
  result._reverse()
319
320
  return result
@@ -359,7 +360,7 @@ class DataFrame:
359
360
  ]
360
361
 
361
362
  return DataFrame(
362
- self.tbl, select_list=select_list, where_clause=where_clause,
363
+ from_clause=self._from_clause, select_list=select_list, where_clause=where_clause,
363
364
  group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
364
365
  order_by_clause=order_by_clause, limit=self.limit_val)
365
366
 
@@ -395,28 +396,42 @@ class DataFrame:
395
396
  def count(self) -> int:
396
397
  from pixeltable.plan import Planner
397
398
 
398
- stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
399
+ stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
399
400
  with Env.get().engine.connect() as conn:
400
401
  result: int = conn.execute(stmt).scalar_one()
401
402
  assert isinstance(result, int)
402
403
  return result
403
404
 
404
- def _description(self) -> pd.DataFrame:
405
- """see DataFrame.describe()"""
405
+ def _descriptors(self) -> DescriptionHelper:
406
+ helper = DescriptionHelper()
407
+ helper.append(self._col_descriptor())
408
+ qd = self._query_descriptor()
409
+ if not qd.empty:
410
+ helper.append(qd, show_index=True, show_header=False)
411
+ return helper
412
+
413
+ def _col_descriptor(self) -> pd.DataFrame:
414
+ return pd.DataFrame([
415
+ {
416
+ 'Name': name,
417
+ 'Type': expr.col_type._to_str(as_schema=True),
418
+ 'Expression': expr.display_str(inline=False),
419
+ }
420
+ for name, expr in zip(self.schema.keys(), self._select_list_exprs)
421
+ ])
422
+
423
+ def _query_descriptor(self) -> pd.DataFrame:
406
424
  heading_vals: list[str] = []
407
425
  info_vals: list[str] = []
408
- if self.select_list is not None:
409
- assert len(self.select_list) > 0
410
- heading_vals.append('Select')
411
- heading_vals.extend([''] * (len(self.select_list) - 1))
412
- info_vals.extend(self.schema.keys())
426
+ heading_vals.append('From')
427
+ info_vals.extend(tbl.tbl_name() for tbl in self._from_clause.tbls)
413
428
  if self.where_clause is not None:
414
429
  heading_vals.append('Where')
415
430
  info_vals.append(self.where_clause.display_str(inline=False))
416
431
  if self.group_by_clause is not None:
417
432
  heading_vals.append('Group By')
418
433
  heading_vals.extend([''] * (len(self.group_by_clause) - 1))
419
- info_vals.extend([e.display_str(inline=False) for e in self.group_by_clause])
434
+ info_vals.extend(e.display_str(inline=False) for e in self.group_by_clause)
420
435
  if self.order_by_clause is not None:
421
436
  heading_vals.append('Order By')
422
437
  heading_vals.extend([''] * (len(self.order_by_clause) - 1))
@@ -426,22 +441,8 @@ class DataFrame:
426
441
  if self.limit_val is not None:
427
442
  heading_vals.append('Limit')
428
443
  info_vals.append(str(self.limit_val))
429
- assert len(heading_vals) > 0
430
- assert len(info_vals) > 0
431
444
  assert len(heading_vals) == len(info_vals)
432
- return pd.DataFrame({'Heading': heading_vals, 'Info': info_vals})
433
-
434
- def _description_html(self) -> pandas.io.formats.style.Styler:
435
- """Return the description in an ipython-friendly manner."""
436
- pd_df = self._description()
437
- # white-space: pre-wrap: print \n as newline
438
- # th: center-align headings
439
- return (
440
- pd_df.style.set_properties(None, **{'white-space': 'pre-wrap', 'text-align': 'left'})
441
- .set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
442
- .hide(axis='index')
443
- .hide(axis='columns')
444
- )
445
+ return pd.DataFrame(info_vals, index=heading_vals)
445
446
 
446
447
  def describe(self) -> None:
447
448
  """
@@ -451,15 +452,15 @@ class DataFrame:
451
452
  """
452
453
  if getattr(builtins, '__IPYTHON__', False):
453
454
  from IPython.display import display
454
- display(self._description_html())
455
+ display(self._repr_html_())
455
456
  else:
456
- print(self.__repr__())
457
+ print(repr(self))
457
458
 
458
459
  def __repr__(self) -> str:
459
- return self._description().to_string(header=False, index=False)
460
+ return self._descriptors().to_string()
460
461
 
461
462
  def _repr_html_(self) -> str:
462
- return self._description_html()._repr_html_() # type: ignore[attr-defined]
463
+ return self._descriptors().to_html()
463
464
 
464
465
  def select(self, *items: Any, **named_items: Any) -> DataFrame:
465
466
  if self.select_list is not None:
@@ -472,7 +473,7 @@ class DataFrame:
472
473
  return self
473
474
 
474
475
  # analyze select list; wrap literals with the corresponding expressions
475
- select_list = []
476
+ select_list: list[tuple[exprs.Expr, Optional[str]]] = []
476
477
  for raw_expr, name in base_list:
477
478
  if isinstance(raw_expr, exprs.Expr):
478
479
  select_list.append((raw_expr, name))
@@ -485,12 +486,14 @@ class DataFrame:
485
486
  expr = select_list[-1][0]
486
487
  if expr.col_type.is_invalid_type():
487
488
  raise excs.Error(f'Invalid type: {raw_expr}')
488
- # TODO: check that ColumnRefs in expr refer to self.tbl
489
+ if not expr.is_bound_by(self._from_clause.tbls):
490
+ raise excs.Error(
491
+ f"Expression '{expr}' cannot be evaluated in the context of this query's tables "
492
+ f"({','.join(tbl.tbl_name() for tbl in self._from_clause.tbls)})")
489
493
 
490
- # check user provided names do not conflict among themselves
491
- # or with auto-generated ones
494
+ # check user provided names do not conflict among themselves or with auto-generated ones
492
495
  seen: set[str] = set()
493
- _, names = DataFrame._normalize_select_list(self.tbl, select_list)
496
+ _, names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
494
497
  for name in names:
495
498
  if name in seen:
496
499
  repeated_names = [j for j, x in enumerate(names) if x == name]
@@ -499,7 +502,7 @@ class DataFrame:
499
502
  seen.add(name)
500
503
 
501
504
  return DataFrame(
502
- self.tbl,
505
+ from_clause=self._from_clause,
503
506
  select_list=select_list,
504
507
  where_clause=self.where_clause,
505
508
  group_by_clause=self.group_by_clause,
@@ -514,7 +517,7 @@ class DataFrame:
514
517
  if not pred.col_type.is_bool_type():
515
518
  raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
516
519
  return DataFrame(
517
- self.tbl,
520
+ from_clause=self._from_clause,
518
521
  select_list=self.select_list,
519
522
  where_clause=pred,
520
523
  group_by_clause=self.group_by_clause,
@@ -523,8 +526,144 @@ class DataFrame:
523
526
  limit=self.limit_val,
524
527
  )
525
528
 
529
+ def _create_join_predicate(
530
+ self, other: catalog.TableVersionPath, on: Union[exprs.Expr, Sequence[exprs.ColumnRef]]
531
+ ) -> exprs.Expr:
532
+ """Verifies user-specified 'on' argument and converts it into a join predicate."""
533
+ col_refs: list[exprs.ColumnRef] = []
534
+ joined_tbls = self._from_clause.tbls + [other]
535
+
536
+ if isinstance(on, exprs.ColumnRef):
537
+ on = [on]
538
+ elif isinstance(on, exprs.Expr):
539
+ if not on.is_bound_by(joined_tbls):
540
+ raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {on}")
541
+ if not on.col_type.is_bool_type():
542
+ raise excs.Error(f"'on': boolean expression expected, but got {on.col_type}: {on}")
543
+ return on
544
+ else:
545
+ if not isinstance(on, Sequence) or len(on) == 0:
546
+ raise excs.Error(
547
+ f"'on': must be a sequence of column references or a boolean expression")
548
+
549
+ assert isinstance(on, Sequence)
550
+ for col_ref in on:
551
+ if not isinstance(col_ref, exprs.ColumnRef):
552
+ raise excs.Error(
553
+ f"'on': must be a sequence of column references or a boolean expression")
554
+ if not col_ref.is_bound_by(joined_tbls):
555
+ raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {col_ref}")
556
+ col_refs.append(col_ref)
557
+
558
+ predicates: list[exprs.Expr] = []
559
+ # try to turn ColumnRefs into equality predicates
560
+ assert len(col_refs) > 0 and len(joined_tbls) >= 2
561
+ for col_ref in col_refs:
562
+ # identify the referenced column by name in 'other'
563
+ rhs_col = other.get_column(col_ref.col.name, include_bases=True)
564
+ if rhs_col is None:
565
+ raise excs.Error(f"'on': column {col_ref.col.name!r} not found in joined table")
566
+ rhs_col_ref = exprs.ColumnRef(rhs_col)
567
+
568
+ lhs_col_ref: Optional[exprs.ColumnRef] = None
569
+ if any(tbl.has_column(col_ref.col, include_bases=True) for tbl in self._from_clause.tbls):
570
+ # col_ref comes from the existing from_clause, we use that directly
571
+ lhs_col_ref = col_ref
572
+ else:
573
+ # col_ref comes from other, we need to look for a match in the existing from_clause by name
574
+ for tbl in self._from_clause.tbls:
575
+ col = tbl.get_column(col_ref.col.name, include_bases=True)
576
+ if col is None:
577
+ continue
578
+ if lhs_col_ref is not None:
579
+ raise excs.Error(f"'on': ambiguous column reference: {col_ref.col.name!r}")
580
+ lhs_col_ref = exprs.ColumnRef(col)
581
+ if lhs_col_ref is None:
582
+ tbl_names = [tbl.tbl_name() for tbl in self._from_clause.tbls]
583
+ raise excs.Error(
584
+ f"'on': column {col_ref.col.name!r} not found in any of: {' '.join(tbl_names)}")
585
+ pred = exprs.Comparison(exprs.ComparisonOperator.EQ, lhs_col_ref, rhs_col_ref)
586
+ predicates.append(pred)
587
+
588
+ assert len(predicates) > 0
589
+ if len(predicates) == 1:
590
+ return predicates[0]
591
+ else:
592
+ return exprs.CompoundPredicate(operator=exprs.LogicalOperator.AND, operands=predicates)
593
+
594
+ def join(
595
+ self, other: catalog.Table, on: Optional[Union[exprs.Expr, Sequence[exprs.ColumnRef]]] = None,
596
+ how: plan.JoinType.LiteralType = 'inner'
597
+ ) -> DataFrame:
598
+ """
599
+ Join this DataFrame with a table.
600
+
601
+ Args:
602
+ other: the table to join with
603
+ on: the join condition, which can be either a) references to one or more columns or b) a boolean
604
+ expression.
605
+
606
+ - column references: implies an equality predicate that matches columns in both this
607
+ DataFrame and `other` by name.
608
+
609
+ - column in `other`: A column with that same name must be present in this DataFrame, and **it must
610
+ be unique** (otherwise the join is ambiguous).
611
+ - column in this DataFrame: A column with that same name must be present in `other`.
612
+
613
+ - boolean expression: The expressions must be valid in the context of the joined tables.
614
+ how: the type of join to perform.
615
+
616
+ - `'inner'`: only keep rows that have a match in both
617
+ - `'left'`: keep all rows from this DataFrame and only matching rows from the other table
618
+ - `'right'`: keep all rows from the other table and only matching rows from this DataFrame
619
+ - `'full_outer'`: keep all rows from both this DataFrame and the other table
620
+ - `'cross'`: Cartesian product; no `on` condition allowed
621
+
622
+ Returns:
623
+ A new DataFrame.
624
+
625
+ Examples:
626
+ Perform an inner join between t1 and t2 on the column id:
627
+
628
+ >>> join1 = t1.join(t2, on=t2.id)
629
+
630
+ Perform a left outer join of join1 with t3, also on id (note that we can't specify `on=t3.id` here,
631
+ because that would be ambiguous, since both t1 and t2 have a column named id):
632
+
633
+ >>> join2 = join1.join(t3, on=t2.id, how='left')
634
+
635
+ Do the same, but now with an explicit join predicate:
636
+
637
+ >>> join2 = join1.join(t3, on=t2.id == t3.id, how='left')
638
+
639
+ Join t with d, which has a composite primary key (columns pk1 and pk2, with corresponding foreign
640
+ key columns d1 and d2 in t):
641
+
642
+ >>> df = t.join(d, on=(t.d1 == d.pk1) & (t.d2 == d.pk2), how='left')
643
+ """
644
+ join_pred: Optional[exprs.Expr]
645
+ if how == 'cross':
646
+ if on is not None:
647
+ raise excs.Error(f"'on' not allowed for cross join")
648
+ join_pred = None
649
+ else:
650
+ if on is None:
651
+ raise excs.Error(f"how={how!r} requires 'on'")
652
+ join_pred = self._create_join_predicate(other._tbl_version_path, on)
653
+ join_clause = plan.JoinClause(join_type=plan.JoinType.validated(how, "'how'"), join_predicate=join_pred)
654
+ from_clause = plan.FromClause(
655
+ tbls=[*self._from_clause.tbls, other._tbl_version_path],
656
+ join_clauses=[*self._from_clause.join_clauses, join_clause])
657
+ return DataFrame(
658
+ from_clause=from_clause,
659
+ select_list=self.select_list, where_clause=self.where_clause,
660
+ group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl,
661
+ order_by_clause=self.order_by_clause, limit=self.limit_val,
662
+ )
663
+
526
664
  def group_by(self, *grouping_items: Any) -> DataFrame:
527
- """Add a group-by clause to this DataFrame.
665
+ """
666
+ Add a group-by clause to this DataFrame.
528
667
  Variants:
529
668
  - group_by(<base table>): group a component view by their respective base table rows
530
669
  - group_by(<expr>, ...): group by the given expressions
@@ -537,10 +676,12 @@ class DataFrame:
537
676
  if isinstance(item, catalog.Table):
538
677
  if len(grouping_items) > 1:
539
678
  raise excs.Error(f'group_by(): only one table can be specified')
679
+ if len(self._from_clause.tbls) > 1:
680
+ raise excs.Error(f'group_by() with Table not supported for joins')
540
681
  # we need to make sure that the grouping table is a base of self.tbl
541
- base = self.tbl.find_tbl_version(item._tbl_version_path.tbl_id())
542
- if base is None or base.id == self.tbl.tbl_id():
543
- raise excs.Error(f'group_by(): {item._name} is not a base table of {self.tbl.tbl_name()}')
682
+ base = self._first_tbl.find_tbl_version(item._tbl_version_path.tbl_id())
683
+ if base is None or base.id == self._first_tbl.tbl_id():
684
+ raise excs.Error(f'group_by(): {item._name} is not a base table of {self._first_tbl.tbl_name()}')
544
685
  grouping_tbl = item._tbl_version_path.tbl_version
545
686
  break
546
687
  if not isinstance(item, exprs.Expr):
@@ -548,7 +689,7 @@ class DataFrame:
548
689
  if grouping_tbl is None:
549
690
  group_by_clause = list(grouping_items)
550
691
  return DataFrame(
551
- self.tbl,
692
+ from_clause=self._from_clause,
552
693
  select_list=self.select_list,
553
694
  where_clause=self.where_clause,
554
695
  group_by_clause=group_by_clause,
@@ -564,7 +705,7 @@ class DataFrame:
564
705
  order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
565
706
  order_by_clause.extend([(e.copy(), asc) for e in expr_list])
566
707
  return DataFrame(
567
- self.tbl,
708
+ from_clause=self._from_clause,
568
709
  select_list=self.select_list,
569
710
  where_clause=self.where_clause,
570
711
  group_by_clause=self.group_by_clause,
@@ -577,7 +718,7 @@ class DataFrame:
577
718
  # TODO: allow n to be a Variable that can be substituted in bind()
578
719
  assert n is not None and isinstance(n, int)
579
720
  return DataFrame(
580
- self.tbl,
721
+ from_clause=self._from_clause,
581
722
  select_list=self.select_list,
582
723
  where_clause=self.where_clause,
583
724
  group_by_clause=self.group_by_clause,
@@ -588,13 +729,13 @@ class DataFrame:
588
729
 
589
730
  def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
590
731
  self._validate_mutable('update')
591
- return self.tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
732
+ return self._first_tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
592
733
 
593
734
  def delete(self) -> UpdateStatus:
594
735
  self._validate_mutable('delete')
595
- if not self.tbl.is_insertable():
736
+ if not self._first_tbl.is_insertable():
596
737
  raise excs.Error(f'Cannot delete from view')
597
- return self.tbl.tbl_version.delete(where=self.where_clause)
738
+ return self._first_tbl.tbl_version.delete(where=self.where_clause)
598
739
 
599
740
  def _validate_mutable(self, op_name: str) -> None:
600
741
  """Tests whether this `DataFrame` can be mutated (such as by an update operation)."""
@@ -624,10 +765,12 @@ class DataFrame:
624
765
  Returns:
625
766
  Dictionary representing this dataframe.
626
767
  """
627
- tbl_versions = self.tbl.get_tbl_versions()
628
768
  d = {
629
769
  '_classname': 'DataFrame',
630
- 'tbl': self.tbl.as_dict(),
770
+ 'from_clause': {
771
+ 'tbls': [tbl.as_dict() for tbl in self._from_clause.tbls],
772
+ 'join_clauses': [dataclasses.asdict(clause) for clause in self._from_clause.join_clauses]
773
+ },
631
774
  'select_list':
632
775
  [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
633
776
  'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
@@ -642,7 +785,9 @@ class DataFrame:
642
785
 
643
786
  @classmethod
644
787
  def from_dict(cls, d: dict[str, Any]) -> 'DataFrame':
645
- tbl = catalog.TableVersionPath.from_dict(d['tbl'])
788
+ tbls = [catalog.TableVersionPath.from_dict(tbl_dict) for tbl_dict in d['from_clause']['tbls']]
789
+ join_clauses = [plan.JoinClause(**clause_dict) for clause_dict in d['from_clause']['join_clauses']]
790
+ from_clause = plan.FromClause(tbls=tbls, join_clauses=join_clauses)
646
791
  select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
647
792
  if d['select_list'] is not None else None
648
793
  where_clause = exprs.Expr.from_dict(d['where_clause']) \
@@ -655,15 +800,18 @@ class DataFrame:
655
800
  if d['order_by_clause'] is not None else None
656
801
  limit_val = d['limit_val']
657
802
  return DataFrame(
658
- tbl, select_list=select_list, where_clause=where_clause, group_by_clause=group_by_clause,
659
- grouping_tbl=grouping_tbl, order_by_clause=order_by_clause, limit=limit_val)
803
+ from_clause=from_clause, select_list=select_list, where_clause=where_clause,
804
+ group_by_clause=group_by_clause, grouping_tbl=grouping_tbl, order_by_clause=order_by_clause,
805
+ limit=limit_val)
660
806
 
661
807
  def _hash_result_set(self) -> str:
662
808
  """Return a hash that changes when the result set changes."""
663
809
  d = self.as_dict()
664
810
  # add list of referenced table versions (the actual versions, not the effective ones) in order to force cache
665
811
  # invalidation when any of the referenced tables changes
666
- d['tbl_versions'] = [tbl_version.version for tbl_version in self.tbl.get_tbl_versions()]
812
+ d['tbl_versions'] = [
813
+ tbl_version.version for tbl in self._from_clause.tbls for tbl_version in tbl.get_tbl_versions()
814
+ ]
667
815
  summary_string = json.dumps(d)
668
816
  return hashlib.sha256(summary_string.encode()).hexdigest()
669
817
 
@@ -732,7 +880,7 @@ class DataFrame:
732
880
  Env.get().require_package('torch')
733
881
  Env.get().require_package('torchvision')
734
882
 
735
- from pixeltable.io.parquet import save_parquet
883
+ from pixeltable.io import export_parquet
736
884
  from pixeltable.utils.pytorch import PixeltablePytorchDataset
737
885
 
738
886
  cache_key = self._hash_result_set()
@@ -741,6 +889,6 @@ class DataFrame:
741
889
  if dest_path.exists(): # fast path: use cache
742
890
  assert dest_path.is_dir()
743
891
  else:
744
- save_parquet(self, dest_path)
892
+ export_parquet(self, dest_path, inline_images=True)
745
893
 
746
894
  return PixeltablePytorchDataset(path=dest_path, image_format=image_format)
pixeltable/exec/__init__.py CHANGED
@@ -7,4 +7,4 @@ from .exec_node import ExecNode
7
7
  from .expr_eval_node import ExprEvalNode
8
8
  from .in_memory_data_node import InMemoryDataNode
9
9
  from .row_update_node import RowUpdateNode
10
- from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
10
+ from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode, SqlJoinNode
pixeltable/exec/exec_node.py CHANGED
@@ -1,15 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import abc
4
- from typing import TYPE_CHECKING, Iterable, Iterator, Optional
4
+ from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar
5
5
 
6
6
  import pixeltable.exprs as exprs
7
7
 
8
8
  from .data_row_batch import DataRowBatch
9
9
  from .exec_context import ExecContext
10
10
 
11
- if TYPE_CHECKING:
12
- from pixeltable import exec
13
11
 
14
12
  class ExecNode(abc.ABC):
15
13
  """Base class of all execution nodes"""
@@ -77,12 +75,13 @@ class ExecNode(abc.ABC):
77
75
  def _close(self) -> None:
78
76
  pass
79
77
 
80
- def get_sql_node(self) -> Optional['exec.SqlNode']:
81
- from .sql_node import SqlNode
82
- if isinstance(self, SqlNode):
78
+ T = TypeVar('T', bound='ExecNode')
79
+
80
+ def get_node(self, node_class: type[T]) -> Optional[T]:
81
+ if isinstance(self, node_class):
83
82
  return self
84
83
  if self.input is not None:
85
- return self.input.get_sql_node()
84
+ return self.input.get_node(node_class)
86
85
  return None
87
86
 
88
87
  def set_limit(self, limit: int) -> None: