pixeltable 0.2.25__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the registry's release page for more details.

Files changed (97)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +421 -231
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/catalog/view.py +5 -7
  12. pixeltable/dataframe.py +439 -105
  13. pixeltable/env.py +19 -5
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/exec_node.py +6 -7
  16. pixeltable/exec/expr_eval_node.py +1 -1
  17. pixeltable/exec/sql_node.py +92 -45
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/arithmetic_expr.py +1 -1
  20. pixeltable/exprs/array_slice.py +1 -1
  21. pixeltable/exprs/column_property_ref.py +1 -1
  22. pixeltable/exprs/column_ref.py +29 -2
  23. pixeltable/exprs/comparison.py +1 -1
  24. pixeltable/exprs/compound_predicate.py +1 -1
  25. pixeltable/exprs/expr.py +12 -5
  26. pixeltable/exprs/expr_set.py +8 -0
  27. pixeltable/exprs/function_call.py +147 -39
  28. pixeltable/exprs/in_predicate.py +1 -1
  29. pixeltable/exprs/inline_expr.py +25 -5
  30. pixeltable/exprs/is_null.py +1 -1
  31. pixeltable/exprs/json_mapper.py +1 -1
  32. pixeltable/exprs/json_path.py +1 -1
  33. pixeltable/exprs/method_ref.py +1 -1
  34. pixeltable/exprs/row_builder.py +1 -1
  35. pixeltable/exprs/rowid_ref.py +1 -1
  36. pixeltable/exprs/similarity_expr.py +14 -7
  37. pixeltable/exprs/sql_element_cache.py +4 -0
  38. pixeltable/exprs/type_cast.py +2 -2
  39. pixeltable/exprs/variable.py +3 -0
  40. pixeltable/func/__init__.py +5 -4
  41. pixeltable/func/aggregate_function.py +151 -68
  42. pixeltable/func/callable_function.py +48 -16
  43. pixeltable/func/expr_template_function.py +64 -23
  44. pixeltable/func/function.py +195 -27
  45. pixeltable/func/function_registry.py +2 -1
  46. pixeltable/func/query_template_function.py +51 -9
  47. pixeltable/func/signature.py +64 -7
  48. pixeltable/func/tools.py +153 -0
  49. pixeltable/func/udf.py +57 -35
  50. pixeltable/functions/__init__.py +2 -2
  51. pixeltable/functions/anthropic.py +51 -4
  52. pixeltable/functions/gemini.py +85 -0
  53. pixeltable/functions/globals.py +54 -34
  54. pixeltable/functions/huggingface.py +10 -28
  55. pixeltable/functions/json.py +3 -8
  56. pixeltable/functions/math.py +67 -0
  57. pixeltable/functions/ollama.py +8 -8
  58. pixeltable/functions/openai.py +51 -4
  59. pixeltable/functions/timestamp.py +1 -1
  60. pixeltable/functions/video.py +3 -9
  61. pixeltable/functions/vision.py +1 -1
  62. pixeltable/globals.py +354 -80
  63. pixeltable/index/embedding_index.py +106 -34
  64. pixeltable/io/__init__.py +1 -1
  65. pixeltable/io/label_studio.py +1 -1
  66. pixeltable/io/parquet.py +39 -19
  67. pixeltable/iterators/document.py +12 -0
  68. pixeltable/metadata/__init__.py +1 -1
  69. pixeltable/metadata/converters/convert_16.py +2 -1
  70. pixeltable/metadata/converters/convert_17.py +2 -1
  71. pixeltable/metadata/converters/convert_22.py +17 -0
  72. pixeltable/metadata/converters/convert_23.py +35 -0
  73. pixeltable/metadata/converters/convert_24.py +56 -0
  74. pixeltable/metadata/converters/convert_25.py +19 -0
  75. pixeltable/metadata/converters/util.py +4 -2
  76. pixeltable/metadata/notes.py +4 -0
  77. pixeltable/metadata/schema.py +1 -0
  78. pixeltable/plan.py +128 -50
  79. pixeltable/store.py +1 -1
  80. pixeltable/type_system.py +196 -54
  81. pixeltable/utils/arrow.py +8 -3
  82. pixeltable/utils/description_helper.py +89 -0
  83. pixeltable/utils/documents.py +14 -0
  84. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/METADATA +30 -20
  85. pixeltable-0.3.0.dist-info/RECORD +155 -0
  86. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
  87. pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
  88. pixeltable/tool/create_test_db_dump.py +0 -311
  89. pixeltable/tool/create_test_video.py +0 -81
  90. pixeltable/tool/doc_plugins/griffe.py +0 -50
  91. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  92. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  93. pixeltable/tool/embed_udf.py +0 -9
  94. pixeltable/tool/mypy_plugin.py +0 -55
  95. pixeltable-0.2.25.dist-info/RECORD +0 -154
  96. pixeltable-0.2.25.dist-info/entry_points.txt +0 -3
  97. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
pixeltable/dataframe.py CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  import builtins
4
4
  import copy
5
+ import dataclasses
5
6
  import hashlib
6
7
  import json
7
8
  import logging
8
- import mimetypes
9
9
  import traceback
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union
11
+ from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, Optional, Sequence, Union, Literal
12
12
 
13
13
  import pandas as pd
14
14
  import pandas.io.formats.style
@@ -17,14 +17,15 @@ import sqlalchemy as sql
17
17
  import pixeltable.catalog as catalog
18
18
  import pixeltable.exceptions as excs
19
19
  import pixeltable.exprs as exprs
20
+ import pixeltable.type_system as ts
20
21
  from pixeltable import exec
22
+ from pixeltable import plan
21
23
  from pixeltable.catalog import is_valid_identifier
22
24
  from pixeltable.catalog.globals import UpdateStatus
23
25
  from pixeltable.env import Env
24
- from pixeltable.plan import Planner
25
26
  from pixeltable.type_system import ColumnType
27
+ from pixeltable.utils.description_helper import DescriptionHelper
26
28
  from pixeltable.utils.formatter import Formatter
27
- from pixeltable.utils.http_server import get_file_uri
28
29
 
29
30
  if TYPE_CHECKING:
30
31
  import torch
@@ -131,9 +132,19 @@ class DataFrameResultSet:
131
132
 
132
133
 
133
134
  class DataFrame:
135
+ _from_clause: plan.FromClause
136
+ _select_list_exprs: list[exprs.Expr]
137
+ _schema: dict[str, ts.ColumnType]
138
+ select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]]
139
+ where_clause: Optional[exprs.Expr]
140
+ group_by_clause: Optional[list[exprs.Expr]]
141
+ grouping_tbl: Optional[catalog.TableVersion]
142
+ order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
143
+ limit_val: Optional[int]
144
+
134
145
  def __init__(
135
146
  self,
136
- tbl: catalog.TableVersionPath,
147
+ from_clause: Optional[plan.FromClause] = None,
137
148
  select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]] = None,
138
149
  where_clause: Optional[exprs.Expr] = None,
139
150
  group_by_clause: Optional[list[exprs.Expr]] = None,
@@ -141,14 +152,11 @@ class DataFrame:
141
152
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None, # list[(expr, asc)]
142
153
  limit: Optional[int] = None,
143
154
  ):
144
- self.tbl = tbl
155
+ self._from_clause = from_clause
145
156
 
146
- # select list logic
147
- DataFrame._select_list_check_rep(select_list) # check select list without expansion
148
157
  # exprs contain execution state and therefore cannot be shared
149
158
  select_list = copy.deepcopy(select_list)
150
- select_list_exprs, column_names = DataFrame._normalize_select_list(tbl, select_list)
151
- DataFrame._select_list_check_rep(list(zip(select_list_exprs, column_names)))
159
+ select_list_exprs, column_names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
152
160
  # check select list after expansion to catch early
153
161
  # the following two lists are always non empty, even if select list is None.
154
162
  assert len(column_names) == len(select_list_exprs)
@@ -163,28 +171,10 @@ class DataFrame:
163
171
  self.order_by_clause = copy.deepcopy(order_by_clause)
164
172
  self.limit_val = limit
165
173
 
166
- @classmethod
167
- def _select_list_check_rep(
168
- cls,
169
- select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
170
- ) -> None:
171
- """Validate basic select list types."""
172
- if select_list is None: # basic check for valid select list
173
- return
174
-
175
- assert len(select_list) > 0
176
- for ent in select_list:
177
- assert isinstance(ent, tuple)
178
- assert len(ent) == 2
179
- assert isinstance(ent[0], exprs.Expr)
180
- assert ent[1] is None or isinstance(ent[1], str)
181
- if isinstance(ent[1], str):
182
- assert is_valid_identifier(ent[1])
183
-
184
174
  @classmethod
185
175
  def _normalize_select_list(
186
176
  cls,
187
- tbl: catalog.TableVersionPath,
177
+ tbls: list[catalog.TableVersionPath],
188
178
  select_list: Optional[list[tuple[exprs.Expr, Optional[str]]]],
189
179
  ) -> tuple[list[exprs.Expr], list[str]]:
190
180
  """
@@ -193,7 +183,7 @@ class DataFrame:
193
183
  a pair composed of the list of expressions and the list of corresponding names
194
184
  """
195
185
  if select_list is None:
196
- select_list = [(exprs.ColumnRef(col), None) for col in tbl.columns()]
186
+ select_list = [(exprs.ColumnRef(col), None) for tbl in tbls for col in tbl.columns()]
197
187
 
198
188
  out_exprs: list[exprs.Expr] = []
199
189
  out_names: list[str] = [] # keep track of order
@@ -222,6 +212,11 @@ class DataFrame:
222
212
  assert set(out_names) == seen_out_names
223
213
  return out_exprs, out_names
224
214
 
215
+ @property
216
+ def _first_tbl(self) -> catalog.TableVersionPath:
217
+ assert len(self._from_clause.tbls) == 1
218
+ return self._from_clause.tbls[0]
219
+
225
220
  def _vars(self) -> dict[str, exprs.Variable]:
226
221
  """
227
222
  Return a dict mapping variable name to Variable for all Variables contained in any component of the DataFrame
@@ -280,16 +275,16 @@ class DataFrame:
280
275
  assert self.group_by_clause is None
281
276
  num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
282
277
  # the grouping table must be a base of self.tbl
283
- assert num_rowid_cols <= len(self.tbl.tbl_version.store_tbl.rowid_columns())
284
- group_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
278
+ assert num_rowid_cols <= len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
279
+ group_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
285
280
  elif self.group_by_clause is not None:
286
281
  group_by_clause = self.group_by_clause
287
282
 
288
283
  for item in self._select_list_exprs:
289
284
  item.bind_rel_paths(None)
290
285
 
291
- return Planner.create_query_plan(
292
- self.tbl,
286
+ return plan.Planner.create_query_plan(
287
+ self._from_clause,
293
288
  self._select_list_exprs,
294
289
  where_clause=self.where_clause,
295
290
  group_by_clause=group_by_clause,
@@ -297,23 +292,57 @@ class DataFrame:
297
292
  limit=self.limit_val
298
293
  )
299
294
 
295
+ def _has_joins(self) -> bool:
296
+ return len(self._from_clause.join_clauses) > 0
300
297
 
301
298
  def show(self, n: int = 20) -> DataFrameResultSet:
302
299
  assert n is not None
303
300
  return self.limit(n).collect()
304
301
 
305
302
  def head(self, n: int = 10) -> DataFrameResultSet:
303
+ """Return the first n rows of the DataFrame, in insertion order of the underlying Table.
304
+
305
+ head() is not supported for joins.
306
+
307
+ Args:
308
+ n: Number of rows to select. Default is 10.
309
+
310
+ Returns:
311
+ A DataFrameResultSet with the first n rows of the DataFrame.
312
+
313
+ Raises:
314
+ Error: If the DataFrame is the result of a join or
315
+ if the DataFrame has an order_by clause.
316
+ """
306
317
  if self.order_by_clause is not None:
307
318
  raise excs.Error(f'head() cannot be used with order_by()')
308
- num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
309
- order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
319
+ if self._has_joins():
320
+ raise excs.Error(f'head() not supported for joins')
321
+ num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
322
+ order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
310
323
  return self.order_by(*order_by_clause, asc=True).limit(n).collect()
311
324
 
312
325
  def tail(self, n: int = 10) -> DataFrameResultSet:
326
+ """Return the last n rows of the DataFrame, in insertion order of the underlying Table.
327
+
328
+ tail() is not supported for joins.
329
+
330
+ Args:
331
+ n: Number of rows to select. Default is 10.
332
+
333
+ Returns:
334
+ A DataFrameResultSet with the last n rows of the DataFrame.
335
+
336
+ Raises:
337
+ Error: If the DataFrame is the result of a join or
338
+ if the DataFrame has an order_by clause.
339
+ """
313
340
  if self.order_by_clause is not None:
314
341
  raise excs.Error(f'tail() cannot be used with order_by()')
315
- num_rowid_cols = len(self.tbl.tbl_version.store_tbl.rowid_columns())
316
- order_by_clause = [exprs.RowidRef(self.tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
342
+ if self._has_joins():
343
+ raise excs.Error(f'tail() not supported for joins')
344
+ num_rowid_cols = len(self._first_tbl.tbl_version.store_tbl.rowid_columns())
345
+ order_by_clause = [exprs.RowidRef(self._first_tbl.tbl_version, idx) for idx in range(num_rowid_cols)]
317
346
  result = self.order_by(*order_by_clause, asc=False).limit(n).collect()
318
347
  result._reverse()
319
348
  return result
@@ -359,7 +388,7 @@ class DataFrame:
359
388
  ]
360
389
 
361
390
  return DataFrame(
362
- self.tbl, select_list=select_list, where_clause=where_clause,
391
+ from_clause=self._from_clause, select_list=select_list, where_clause=where_clause,
363
392
  group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
364
393
  order_by_clause=order_by_clause, limit=self.limit_val)
365
394
 
@@ -393,30 +422,49 @@ class DataFrame:
393
422
  return DataFrameResultSet(list(self._output_row_iterator(conn)), self.schema)
394
423
 
395
424
  def count(self) -> int:
425
+ """Return the number of rows in the DataFrame.
426
+
427
+ Returns:
428
+ The number of rows in the DataFrame.
429
+ """
396
430
  from pixeltable.plan import Planner
397
431
 
398
- stmt = Planner.create_count_stmt(self.tbl, self.where_clause)
432
+ stmt = Planner.create_count_stmt(self._first_tbl, self.where_clause)
399
433
  with Env.get().engine.connect() as conn:
400
434
  result: int = conn.execute(stmt).scalar_one()
401
435
  assert isinstance(result, int)
402
436
  return result
403
437
 
404
- def _description(self) -> pd.DataFrame:
405
- """see DataFrame.describe()"""
438
+ def _descriptors(self) -> DescriptionHelper:
439
+ helper = DescriptionHelper()
440
+ helper.append(self._col_descriptor())
441
+ qd = self._query_descriptor()
442
+ if not qd.empty:
443
+ helper.append(qd, show_index=True, show_header=False)
444
+ return helper
445
+
446
+ def _col_descriptor(self) -> pd.DataFrame:
447
+ return pd.DataFrame([
448
+ {
449
+ 'Name': name,
450
+ 'Type': expr.col_type._to_str(as_schema=True),
451
+ 'Expression': expr.display_str(inline=False),
452
+ }
453
+ for name, expr in zip(self.schema.keys(), self._select_list_exprs)
454
+ ])
455
+
456
+ def _query_descriptor(self) -> pd.DataFrame:
406
457
  heading_vals: list[str] = []
407
458
  info_vals: list[str] = []
408
- if self.select_list is not None:
409
- assert len(self.select_list) > 0
410
- heading_vals.append('Select')
411
- heading_vals.extend([''] * (len(self.select_list) - 1))
412
- info_vals.extend(self.schema.keys())
459
+ heading_vals.append('From')
460
+ info_vals.extend(tbl.tbl_name() for tbl in self._from_clause.tbls)
413
461
  if self.where_clause is not None:
414
462
  heading_vals.append('Where')
415
463
  info_vals.append(self.where_clause.display_str(inline=False))
416
464
  if self.group_by_clause is not None:
417
465
  heading_vals.append('Group By')
418
466
  heading_vals.extend([''] * (len(self.group_by_clause) - 1))
419
- info_vals.extend([e.display_str(inline=False) for e in self.group_by_clause])
467
+ info_vals.extend(e.display_str(inline=False) for e in self.group_by_clause)
420
468
  if self.order_by_clause is not None:
421
469
  heading_vals.append('Order By')
422
470
  heading_vals.extend([''] * (len(self.order_by_clause) - 1))
@@ -426,22 +474,8 @@ class DataFrame:
426
474
  if self.limit_val is not None:
427
475
  heading_vals.append('Limit')
428
476
  info_vals.append(str(self.limit_val))
429
- assert len(heading_vals) > 0
430
- assert len(info_vals) > 0
431
477
  assert len(heading_vals) == len(info_vals)
432
- return pd.DataFrame({'Heading': heading_vals, 'Info': info_vals})
433
-
434
- def _description_html(self) -> pandas.io.formats.style.Styler:
435
- """Return the description in an ipython-friendly manner."""
436
- pd_df = self._description()
437
- # white-space: pre-wrap: print \n as newline
438
- # th: center-align headings
439
- return (
440
- pd_df.style.set_properties(None, **{'white-space': 'pre-wrap', 'text-align': 'left'})
441
- .set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
442
- .hide(axis='index')
443
- .hide(axis='columns')
444
- )
478
+ return pd.DataFrame(info_vals, index=heading_vals)
445
479
 
446
480
  def describe(self) -> None:
447
481
  """
@@ -451,17 +485,47 @@ class DataFrame:
451
485
  """
452
486
  if getattr(builtins, '__IPYTHON__', False):
453
487
  from IPython.display import display
454
- display(self._description_html())
488
+ display(self._repr_html_())
455
489
  else:
456
- print(self.__repr__())
490
+ print(repr(self))
457
491
 
458
492
  def __repr__(self) -> str:
459
- return self._description().to_string(header=False, index=False)
493
+ return self._descriptors().to_string()
460
494
 
461
495
  def _repr_html_(self) -> str:
462
- return self._description_html()._repr_html_() # type: ignore[attr-defined]
496
+ return self._descriptors().to_html()
463
497
 
464
498
  def select(self, *items: Any, **named_items: Any) -> DataFrame:
499
+ """ Select columns or expressions from the DataFrame.
500
+
501
+ Args:
502
+ items: expressions to be selected
503
+ named_items: named expressions to be selected
504
+
505
+ Returns:
506
+ A new DataFrame with the specified select list.
507
+
508
+ Raises:
509
+ Error: If the select list is already specified,
510
+ or if any of the specified expressions are invalid,
511
+ or refer to tables not in the DataFrame.
512
+
513
+ Examples:
514
+ Given the DataFrame person from a table t with all its columns and rows:
515
+
516
+ >>> person = t.select()
517
+
518
+ Select the columns 'name' and 'age' (referenced in table t) from the DataFrame person:
519
+
520
+ >>> df = person.select(t.name, t.age)
521
+
522
+ Select the columns 'name' (referenced in table t) from the DataFrame person,
523
+ and a named column 'is_adult' from the expression `age >= 18` where 'age' is
524
+ another column in table t:
525
+
526
+ >>> df = person.select(t.name, is_adult=(t.age >= 18))
527
+
528
+ """
465
529
  if self.select_list is not None:
466
530
  raise excs.Error(f'Select list already specified')
467
531
  for name, _ in named_items.items():
@@ -472,7 +536,7 @@ class DataFrame:
472
536
  return self
473
537
 
474
538
  # analyze select list; wrap literals with the corresponding expressions
475
- select_list = []
539
+ select_list: list[tuple[exprs.Expr, Optional[str]]] = []
476
540
  for raw_expr, name in base_list:
477
541
  if isinstance(raw_expr, exprs.Expr):
478
542
  select_list.append((raw_expr, name))
@@ -485,12 +549,14 @@ class DataFrame:
485
549
  expr = select_list[-1][0]
486
550
  if expr.col_type.is_invalid_type():
487
551
  raise excs.Error(f'Invalid type: {raw_expr}')
488
- # TODO: check that ColumnRefs in expr refer to self.tbl
552
+ if not expr.is_bound_by(self._from_clause.tbls):
553
+ raise excs.Error(
554
+ f"Expression '{expr}' cannot be evaluated in the context of this query's tables "
555
+ f"({','.join(tbl.tbl_name() for tbl in self._from_clause.tbls)})")
489
556
 
490
- # check user provided names do not conflict among themselves
491
- # or with auto-generated ones
557
+ # check user provided names do not conflict among themselves or with auto-generated ones
492
558
  seen: set[str] = set()
493
- _, names = DataFrame._normalize_select_list(self.tbl, select_list)
559
+ _, names = DataFrame._normalize_select_list(self._from_clause.tbls, select_list)
494
560
  for name in names:
495
561
  if name in seen:
496
562
  repeated_names = [j for j, x in enumerate(names) if x == name]
@@ -499,7 +565,7 @@ class DataFrame:
499
565
  seen.add(name)
500
566
 
501
567
  return DataFrame(
502
- self.tbl,
568
+ from_clause=self._from_clause,
503
569
  select_list=select_list,
504
570
  where_clause=self.where_clause,
505
571
  group_by_clause=self.group_by_clause,
@@ -509,12 +575,35 @@ class DataFrame:
509
575
  )
510
576
 
511
577
  def where(self, pred: exprs.Expr) -> DataFrame:
578
+ """Filter rows based on a predicate.
579
+
580
+ Args:
581
+ pred: the predicate to filter rows
582
+
583
+ Returns:
584
+ A new DataFrame with the specified predicates replacing the where-clause.
585
+
586
+ Raises:
587
+ Error: If the predicate is not a Pixeltable expression,
588
+ or if it does not return a boolean value,
589
+ or refers to tables not in the DataFrame.
590
+
591
+ Examples:
592
+ Given the DataFrame person from a table t with all its columns and rows:
593
+
594
+ >>> person = t.select()
595
+
596
+ Filter the above DataFrame person to only include rows where the column 'age'
597
+ (referenced in table t) is greater than 30:
598
+
599
+ >>> df = person.where(t.age > 30)
600
+ """
512
601
  if not isinstance(pred, exprs.Expr):
513
602
  raise excs.Error(f'Where() requires a Pixeltable expression, but instead got {type(pred)}')
514
603
  if not pred.col_type.is_bool_type():
515
604
  raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
516
605
  return DataFrame(
517
- self.tbl,
606
+ from_clause=self._from_clause,
518
607
  select_list=self.select_list,
519
608
  where_clause=pred,
520
609
  group_by_clause=self.group_by_clause,
@@ -523,11 +612,181 @@ class DataFrame:
523
612
  limit=self.limit_val,
524
613
  )
525
614
 
615
+ def _create_join_predicate(
616
+ self, other: catalog.TableVersionPath, on: Union[exprs.Expr, Sequence[exprs.ColumnRef]]
617
+ ) -> exprs.Expr:
618
+ """Verifies user-specified 'on' argument and converts it into a join predicate."""
619
+ col_refs: list[exprs.ColumnRef] = []
620
+ joined_tbls = self._from_clause.tbls + [other]
621
+
622
+ if isinstance(on, exprs.ColumnRef):
623
+ on = [on]
624
+ elif isinstance(on, exprs.Expr):
625
+ if not on.is_bound_by(joined_tbls):
626
+ raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {on}")
627
+ if not on.col_type.is_bool_type():
628
+ raise excs.Error(f"'on': boolean expression expected, but got {on.col_type}: {on}")
629
+ return on
630
+ else:
631
+ if not isinstance(on, Sequence) or len(on) == 0:
632
+ raise excs.Error(
633
+ f"'on': must be a sequence of column references or a boolean expression")
634
+
635
+ assert isinstance(on, Sequence)
636
+ for col_ref in on:
637
+ if not isinstance(col_ref, exprs.ColumnRef):
638
+ raise excs.Error(
639
+ f"'on': must be a sequence of column references or a boolean expression")
640
+ if not col_ref.is_bound_by(joined_tbls):
641
+ raise excs.Error(f"'on': expression cannot be evaluated in the context of the joined tables: {col_ref}")
642
+ col_refs.append(col_ref)
643
+
644
+ predicates: list[exprs.Expr] = []
645
+ # try to turn ColumnRefs into equality predicates
646
+ assert len(col_refs) > 0 and len(joined_tbls) >= 2
647
+ for col_ref in col_refs:
648
+ # identify the referenced column by name in 'other'
649
+ rhs_col = other.get_column(col_ref.col.name, include_bases=True)
650
+ if rhs_col is None:
651
+ raise excs.Error(f"'on': column {col_ref.col.name!r} not found in joined table")
652
+ rhs_col_ref = exprs.ColumnRef(rhs_col)
653
+
654
+ lhs_col_ref: Optional[exprs.ColumnRef] = None
655
+ if any(tbl.has_column(col_ref.col, include_bases=True) for tbl in self._from_clause.tbls):
656
+ # col_ref comes from the existing from_clause, we use that directly
657
+ lhs_col_ref = col_ref
658
+ else:
659
+ # col_ref comes from other, we need to look for a match in the existing from_clause by name
660
+ for tbl in self._from_clause.tbls:
661
+ col = tbl.get_column(col_ref.col.name, include_bases=True)
662
+ if col is None:
663
+ continue
664
+ if lhs_col_ref is not None:
665
+ raise excs.Error(f"'on': ambiguous column reference: {col_ref.col.name!r}")
666
+ lhs_col_ref = exprs.ColumnRef(col)
667
+ if lhs_col_ref is None:
668
+ tbl_names = [tbl.tbl_name() for tbl in self._from_clause.tbls]
669
+ raise excs.Error(
670
+ f"'on': column {col_ref.col.name!r} not found in any of: {' '.join(tbl_names)}")
671
+ pred = exprs.Comparison(exprs.ComparisonOperator.EQ, lhs_col_ref, rhs_col_ref)
672
+ predicates.append(pred)
673
+
674
+ assert len(predicates) > 0
675
+ if len(predicates) == 1:
676
+ return predicates[0]
677
+ else:
678
+ return exprs.CompoundPredicate(operator=exprs.LogicalOperator.AND, operands=predicates)
679
+
680
+ def join(
681
+ self, other: catalog.Table, on: Optional[Union[exprs.Expr, Sequence[exprs.ColumnRef]]] = None,
682
+ how: plan.JoinType.LiteralType = 'inner'
683
+ ) -> DataFrame:
684
+ """
685
+ Join this DataFrame with a table.
686
+
687
+ Args:
688
+ other: the table to join with
689
+ on: the join condition, which can be either a) references to one or more columns or b) a boolean
690
+ expression.
691
+
692
+ - column references: implies an equality predicate that matches columns in both this
693
+ DataFrame and `other` by name.
694
+
695
+ - column in `other`: A column with that same name must be present in this DataFrame, and **it must
696
+ be unique** (otherwise the join is ambiguous).
697
+ - column in this DataFrame: A column with that same name must be present in `other`.
698
+
699
+ - boolean expression: The expressions must be valid in the context of the joined tables.
700
+ how: the type of join to perform.
701
+
702
+ - `'inner'`: only keep rows that have a match in both
703
+ - `'left'`: keep all rows from this DataFrame and only matching rows from the other table
704
+ - `'right'`: keep all rows from the other table and only matching rows from this DataFrame
705
+ - `'full_outer'`: keep all rows from both this DataFrame and the other table
706
+ - `'cross'`: Cartesian product; no `on` condition allowed
707
+
708
+ Returns:
709
+ A new DataFrame.
710
+
711
+ Examples:
712
+ Perform an inner join between t1 and t2 on the column id:
713
+
714
+ >>> join1 = t1.join(t2, on=t2.id)
715
+
716
+ Perform a left outer join of join1 with t3, also on id (note that we can't specify `on=t3.id` here,
717
+ because that would be ambiguous, since both t1 and t2 have a column named id):
718
+
719
+ >>> join2 = join1.join(t3, on=t2.id, how='left')
720
+
721
+ Do the same, but now with an explicit join predicate:
722
+
723
+ >>> join2 = join1.join(t3, on=t2.id == t3.id, how='left')
724
+
725
+ Join t with d, which has a composite primary key (columns pk1 and pk2, with corresponding foreign
726
+ key columns d1 and d2 in t):
727
+
728
+ >>> df = t.join(d, on=(t.d1 == d.pk1) & (t.d2 == d.pk2), how='left')
729
+ """
730
+ join_pred: Optional[exprs.Expr]
731
+ if how == 'cross':
732
+ if on is not None:
733
+ raise excs.Error(f"'on' not allowed for cross join")
734
+ join_pred = None
735
+ else:
736
+ if on is None:
737
+ raise excs.Error(f"how={how!r} requires 'on'")
738
+ join_pred = self._create_join_predicate(other._tbl_version_path, on)
739
+ join_clause = plan.JoinClause(join_type=plan.JoinType.validated(how, "'how'"), join_predicate=join_pred)
740
+ from_clause = plan.FromClause(
741
+ tbls=[*self._from_clause.tbls, other._tbl_version_path],
742
+ join_clauses=[*self._from_clause.join_clauses, join_clause])
743
+ return DataFrame(
744
+ from_clause=from_clause,
745
+ select_list=self.select_list, where_clause=self.where_clause,
746
+ group_by_clause=self.group_by_clause, grouping_tbl=self.grouping_tbl,
747
+ order_by_clause=self.order_by_clause, limit=self.limit_val,
748
+ )
749
+
526
750
  def group_by(self, *grouping_items: Any) -> DataFrame:
527
- """Add a group-by clause to this DataFrame.
751
+ """ Add a group-by clause to this DataFrame.
752
+
528
753
  Variants:
529
754
  - group_by(<base table>): group a component view by their respective base table rows
530
755
  - group_by(<expr>, ...): group by the given expressions
756
+
757
+ Note, that grouping will be applied to the rows and take effect when
758
+ used with an aggregation function like sum(), count() etc.
759
+
760
+ Args:
761
+ grouping_items: expressions to group by
762
+
763
+ Returns:
764
+ A new DataFrame with the specified group-by clause.
765
+
766
+ Raises:
767
+ Error: If the group-by clause is already specified,
768
+ or if the specified expression is invalid,
769
+ or refer to tables not in the DataFrame,
770
+ or if the DataFrame is a result of a join.
771
+
772
+ Examples:
773
+ Given the DataFrame book from a table t with all its columns and rows:
774
+
775
+ >>> book = t.select()
776
+
777
+ Group the above DataFrame book by the 'genre' column (referenced in table t):
778
+
779
+ >>> df = book.group_by(t.genre)
780
+
781
+ Use the above DataFrame df grouped by genre to count the number of
782
+ books for each 'genre':
783
+
784
+ >>> df = book.group_by(t.genre).select(t.genre, count=count(t.genre)).show()
785
+
786
+ Use the above DataFrame df grouped by genre to the total price of
787
+ books for each 'genre':
788
+
789
+ >>> df = book.group_by(t.genre).select(t.genre, total=sum(t.price)).show()
531
790
  """
532
791
  if self.group_by_clause is not None:
533
792
  raise excs.Error(f'Group-by already specified')
@@ -537,10 +796,12 @@ class DataFrame:
537
796
  if isinstance(item, catalog.Table):
538
797
  if len(grouping_items) > 1:
539
798
  raise excs.Error(f'group_by(): only one table can be specified')
799
+ if len(self._from_clause.tbls) > 1:
800
+ raise excs.Error(f'group_by() with Table not supported for joins')
540
801
  # we need to make sure that the grouping table is a base of self.tbl
541
- base = self.tbl.find_tbl_version(item._tbl_version_path.tbl_id())
542
- if base is None or base.id == self.tbl.tbl_id():
543
- raise excs.Error(f'group_by(): {item._name} is not a base table of {self.tbl.tbl_name()}')
802
+ base = self._first_tbl.find_tbl_version(item._tbl_version_path.tbl_id())
803
+ if base is None or base.id == self._first_tbl.tbl_id():
804
+ raise excs.Error(f'group_by(): {item._name} is not a base table of {self._first_tbl.tbl_name()}')
544
805
  grouping_tbl = item._tbl_version_path.tbl_version
545
806
  break
546
807
  if not isinstance(item, exprs.Expr):
@@ -548,7 +809,7 @@ class DataFrame:
548
809
  if grouping_tbl is None:
549
810
  group_by_clause = list(grouping_items)
550
811
  return DataFrame(
551
- self.tbl,
812
+ from_clause=self._from_clause,
552
813
  select_list=self.select_list,
553
814
  where_clause=self.where_clause,
554
815
  group_by_clause=group_by_clause,
@@ -558,13 +819,42 @@ class DataFrame:
558
819
  )
559
820
 
560
821
  def order_by(self, *expr_list: exprs.Expr, asc: bool = True) -> DataFrame:
822
+ """ Add an order-by clause to this DataFrame.
823
+
824
+ Args:
825
+ expr_list: expressions to order by
826
+ asc: whether to order in ascending order (True) or descending order (False).
827
+ Default is True.
828
+
829
+ Returns:
830
+ A new DataFrame with the specified order-by clause.
831
+
832
+ Raises:
833
+ Error: If any of the specified expressions is invalid
834
+ (i.e., is not an Expr),
835
+ or refers to tables not in the DataFrame.
836
+
837
+ Examples:
838
+ Given the DataFrame book from a table t with all its columns and rows:
839
+
840
+ >>> book = t.select()
841
+
842
+ Order the above DataFrame book by two columns (price, pages) in descending order:
843
+
844
+ >>> df = book.order_by(t.price, t.pages, asc=False)
845
+
846
+ Order the above DataFrame book by price in descending order, but order the pages
847
+ in ascending order:
848
+
849
+ >>> df = book.order_by(t.price, asc=False).order_by(t.pages)
850
+ """
561
851
  for e in expr_list:
562
852
  if not isinstance(e, exprs.Expr):
563
853
  raise excs.Error(f'Invalid expression in order_by(): {e}')
564
854
  order_by_clause = self.order_by_clause if self.order_by_clause is not None else []
565
855
  order_by_clause.extend([(e.copy(), asc) for e in expr_list])
566
856
  return DataFrame(
567
- self.tbl,
857
+ from_clause=self._from_clause,
568
858
  select_list=self.select_list,
569
859
  where_clause=self.where_clause,
570
860
  group_by_clause=self.group_by_clause,
@@ -574,10 +864,18 @@ class DataFrame:
574
864
  )
575
865
 
576
866
  def limit(self, n: int) -> DataFrame:
867
+ """ Limit the number of rows in the DataFrame.
868
+
869
+ Args:
870
+ n: Number of rows to select.
871
+
872
+ Returns:
873
+ A new DataFrame with the specified limited rows.
874
+ """
577
875
  # TODO: allow n to be a Variable that can be substituted in bind()
578
876
  assert n is not None and isinstance(n, int)
579
877
  return DataFrame(
580
- self.tbl,
878
+ from_clause=self._from_clause,
581
879
  select_list=self.select_list,
582
880
  where_clause=self.where_clause,
583
881
  group_by_clause=self.group_by_clause,
@@ -587,17 +885,58 @@ class DataFrame:
587
885
  )
588
886
 
589
887
  def update(self, value_spec: dict[str, Any], cascade: bool = True) -> UpdateStatus:
888
+ """ Update rows in the underlying table of the DataFrame.
889
+
890
+ Update rows in the table with the specified value_spec.
891
+
892
+ Args:
893
+ value_spec: a dict of column names to update and the new value to update it to.
894
+ cascade: if True, also update all computed columns that transitively depend
895
+ on the updated columns, including within views. Default is True.
896
+
897
+ Returns:
898
+ UpdateStatus: the status of the update operation.
899
+
900
+ Example:
901
+ Given the DataFrame person from a table t with all its columns and rows:
902
+
903
+ >>> person = t.select()
904
+
905
+ Via the above DataFrame person, update the column 'city' to 'Oakland' and 'state' to 'CA' in the table t:
906
+
907
+ >>> df = person.update({'city': 'Oakland', 'state': 'CA'})
908
+
909
+ Via the above DataFrame person, update the column 'age' to 30 for any rows where 'year' is 2014 in the table t:
910
+
911
+ >>> df = person.where(t.year == 2014).update({'age': 30})
912
+ """
590
913
  self._validate_mutable('update')
591
- return self.tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
914
+ return self._first_tbl.tbl_version.update(value_spec, where=self.where_clause, cascade=cascade)
592
915
 
593
916
  def delete(self) -> UpdateStatus:
917
+ """ Delete rows from the underlying table of the DataFrame.
918
+
919
+ The delete operation is only allowed for DataFrames on base tables.
920
+
921
+ Returns:
922
+ UpdateStatus: the status of the delete operation.
923
+
924
+ Example:
925
+ Given the DataFrame person from a table t with all its columns and rows:
926
+
927
+ >>> person = t.select()
928
+
929
+ Via the above DataFrame person, delete all rows from the table t where the column 'age' is less than 18:
930
+
931
+ >>> df = person.where(t.age < 18).delete()
932
+ """
594
933
  self._validate_mutable('delete')
595
- if not self.tbl.is_insertable():
934
+ if not self._first_tbl.is_insertable():
596
935
  raise excs.Error(f'Cannot delete from view')
597
- return self.tbl.tbl_version.delete(where=self.where_clause)
936
+ return self._first_tbl.tbl_version.delete(where=self.where_clause)
598
937
 
599
938
  def _validate_mutable(self, op_name: str) -> None:
600
- """Tests whether this `DataFrame` can be mutated (such as by an update operation)."""
939
+ """Tests whether this DataFrame can be mutated (such as by an update operation)."""
601
940
  if self.group_by_clause is not None or self.grouping_tbl is not None:
602
941
  raise excs.Error(f'Cannot use `{op_name}` after `group_by`')
603
942
  if self.order_by_clause is not None:
@@ -607,27 +946,17 @@ class DataFrame:
607
946
  if self.limit_val is not None:
608
947
  raise excs.Error(f'Cannot use `{op_name}` after `limit`')
609
948
 
610
- def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
611
- """
612
- Allowed:
613
- - [list[Expr]]/[tuple[Expr]]: setting the select list
614
- - [Expr]: setting a single-col select list
615
- """
616
- if isinstance(index, exprs.Expr):
617
- return self.select(index)
618
- if isinstance(index, Sequence):
619
- return self.select(*index)
620
- raise TypeError(f'Invalid index type: {type(index)}')
621
-
622
949
  def as_dict(self) -> dict[str, Any]:
623
950
  """
624
951
  Returns:
625
952
  Dictionary representing this dataframe.
626
953
  """
627
- tbl_versions = self.tbl.get_tbl_versions()
628
954
  d = {
629
955
  '_classname': 'DataFrame',
630
- 'tbl': self.tbl.as_dict(),
956
+ 'from_clause': {
957
+ 'tbls': [tbl.as_dict() for tbl in self._from_clause.tbls],
958
+ 'join_clauses': [dataclasses.asdict(clause) for clause in self._from_clause.join_clauses]
959
+ },
631
960
  'select_list':
632
961
  [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
633
962
  'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
@@ -642,7 +971,9 @@ class DataFrame:
642
971
 
643
972
  @classmethod
644
973
  def from_dict(cls, d: dict[str, Any]) -> 'DataFrame':
645
- tbl = catalog.TableVersionPath.from_dict(d['tbl'])
974
+ tbls = [catalog.TableVersionPath.from_dict(tbl_dict) for tbl_dict in d['from_clause']['tbls']]
975
+ join_clauses = [plan.JoinClause(**clause_dict) for clause_dict in d['from_clause']['join_clauses']]
976
+ from_clause = plan.FromClause(tbls=tbls, join_clauses=join_clauses)
646
977
  select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
647
978
  if d['select_list'] is not None else None
648
979
  where_clause = exprs.Expr.from_dict(d['where_clause']) \
@@ -655,15 +986,18 @@ class DataFrame:
655
986
  if d['order_by_clause'] is not None else None
656
987
  limit_val = d['limit_val']
657
988
  return DataFrame(
658
- tbl, select_list=select_list, where_clause=where_clause, group_by_clause=group_by_clause,
659
- grouping_tbl=grouping_tbl, order_by_clause=order_by_clause, limit=limit_val)
989
+ from_clause=from_clause, select_list=select_list, where_clause=where_clause,
990
+ group_by_clause=group_by_clause, grouping_tbl=grouping_tbl, order_by_clause=order_by_clause,
991
+ limit=limit_val)
660
992
 
661
993
  def _hash_result_set(self) -> str:
662
994
  """Return a hash that changes when the result set changes."""
663
995
  d = self.as_dict()
664
996
  # add list of referenced table versions (the actual versions, not the effective ones) in order to force cache
665
997
  # invalidation when any of the referenced tables changes
666
- d['tbl_versions'] = [tbl_version.version for tbl_version in self.tbl.get_tbl_versions()]
998
+ d['tbl_versions'] = [
999
+ tbl_version.version for tbl in self._from_clause.tbls for tbl_version in tbl.get_tbl_versions()
1000
+ ]
667
1001
  summary_string = json.dumps(d)
668
1002
  return hashlib.sha256(summary_string.encode()).hexdigest()
669
1003
 
@@ -732,7 +1066,7 @@ class DataFrame:
732
1066
  Env.get().require_package('torch')
733
1067
  Env.get().require_package('torchvision')
734
1068
 
735
- from pixeltable.io.parquet import save_parquet
1069
+ from pixeltable.io import export_parquet
736
1070
  from pixeltable.utils.pytorch import PixeltablePytorchDataset
737
1071
 
738
1072
  cache_key = self._hash_result_set()
@@ -741,6 +1075,6 @@ class DataFrame:
741
1075
  if dest_path.exists(): # fast path: use cache
742
1076
  assert dest_path.is_dir()
743
1077
  else:
744
- save_parquet(self, dest_path)
1078
+ export_parquet(self, dest_path, inline_images=True)
745
1079
 
746
1080
  return PixeltablePytorchDataset(path=dest_path, image_format=image_format)