pixeltable 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. The original page links to more details.

Files changed (67)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +5 -0
  4. pixeltable/catalog/globals.py +8 -0
  5. pixeltable/catalog/insertable_table.py +2 -2
  6. pixeltable/catalog/table.py +27 -9
  7. pixeltable/catalog/table_version.py +41 -68
  8. pixeltable/catalog/view.py +3 -3
  9. pixeltable/dataframe.py +7 -6
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/expr_eval_node.py +8 -1
  12. pixeltable/exec/row_update_node.py +61 -0
  13. pixeltable/exec/{sql_scan_node.py → sql_node.py} +120 -56
  14. pixeltable/exprs/__init__.py +1 -2
  15. pixeltable/exprs/comparison.py +5 -5
  16. pixeltable/exprs/compound_predicate.py +12 -12
  17. pixeltable/exprs/expr.py +67 -22
  18. pixeltable/exprs/function_call.py +60 -29
  19. pixeltable/exprs/globals.py +2 -0
  20. pixeltable/exprs/in_predicate.py +3 -3
  21. pixeltable/exprs/inline_array.py +18 -11
  22. pixeltable/exprs/is_null.py +5 -5
  23. pixeltable/exprs/method_ref.py +63 -0
  24. pixeltable/ext/__init__.py +9 -0
  25. pixeltable/ext/functions/__init__.py +8 -0
  26. pixeltable/ext/functions/whisperx.py +45 -5
  27. pixeltable/ext/functions/yolox.py +60 -14
  28. pixeltable/func/aggregate_function.py +10 -4
  29. pixeltable/func/callable_function.py +16 -4
  30. pixeltable/func/expr_template_function.py +1 -1
  31. pixeltable/func/function.py +12 -2
  32. pixeltable/func/function_registry.py +26 -9
  33. pixeltable/func/udf.py +32 -4
  34. pixeltable/functions/__init__.py +1 -1
  35. pixeltable/functions/fireworks.py +33 -0
  36. pixeltable/functions/globals.py +36 -1
  37. pixeltable/functions/huggingface.py +155 -7
  38. pixeltable/functions/image.py +242 -40
  39. pixeltable/functions/openai.py +214 -0
  40. pixeltable/functions/string.py +600 -8
  41. pixeltable/functions/timestamp.py +210 -0
  42. pixeltable/functions/together.py +106 -0
  43. pixeltable/functions/video.py +28 -10
  44. pixeltable/functions/whisper.py +32 -0
  45. pixeltable/globals.py +3 -3
  46. pixeltable/io/__init__.py +1 -1
  47. pixeltable/io/globals.py +186 -5
  48. pixeltable/io/label_studio.py +42 -2
  49. pixeltable/io/pandas.py +70 -34
  50. pixeltable/metadata/__init__.py +1 -1
  51. pixeltable/metadata/converters/convert_18.py +39 -0
  52. pixeltable/metadata/notes.py +10 -0
  53. pixeltable/plan.py +82 -7
  54. pixeltable/tool/create_test_db_dump.py +4 -5
  55. pixeltable/tool/doc_plugins/griffe.py +81 -0
  56. pixeltable/tool/doc_plugins/mkdocstrings.py +6 -0
  57. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +135 -0
  58. pixeltable/type_system.py +15 -14
  59. pixeltable/utils/s3.py +1 -1
  60. pixeltable-0.2.14.dist-info/METADATA +206 -0
  61. {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/RECORD +64 -56
  62. pixeltable-0.2.14.dist-info/entry_points.txt +3 -0
  63. pixeltable/exprs/image_member_access.py +0 -96
  64. pixeltable/exprs/predicate.py +0 -44
  65. pixeltable-0.2.12.dist-info/METADATA +0 -137
  66. {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/LICENSE +0 -0
  67. {pixeltable-0.2.12.dist-info → pixeltable-0.2.14.dist-info}/WHEEL +0 -0
pixeltable/__init__.py CHANGED
@@ -21,7 +21,7 @@ from .type_system import (
21
21
  )
22
22
  from .utils.help import help
23
23
 
24
- from . import functions, io, iterators
24
+ from . import ext, functions, io, iterators
25
25
  from .__version__ import __version__, __version_tuple__
26
26
 
27
27
  # This is the safest / most maintainable way to do this: start with the default and "blacklist" stuff that
pixeltable/__version__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # These version placeholders will be replaced during build.
2
- __version__ = "0.2.12"
3
- __version_tuple__ = (0, 2, 12)
2
+ __version__ = "0.2.14"
3
+ __version_tuple__ = (0, 2, 14)
pixeltable/catalog/column.py CHANGED
@@ -152,6 +152,11 @@ class Column:
152
152
  return self._records_errors
153
153
  return self.is_stored and (self.is_computed or self.col_type.is_media_type())
154
154
 
155
+ @property
156
+ def qualified_name(self) -> str:
157
+ assert self.tbl is not None
158
+ return f'{self.tbl.name}.{self.name}'
159
+
155
160
  def source(self) -> None:
156
161
  """
157
162
  If this is a computed col and the top-level expr is a function call, print the source, if possible.
pixeltable/catalog/globals.py CHANGED
@@ -19,6 +19,14 @@ class UpdateStatus:
19
19
  updated_cols: List[str] = dataclasses.field(default_factory=list)
20
20
  cols_with_excs: List[str] = dataclasses.field(default_factory=list)
21
21
 
22
+ def __iadd__(self, other: 'UpdateStatus') -> 'UpdateStatus':
23
+ self.num_rows += other.num_rows
24
+ self.num_computed_values += other.num_computed_values
25
+ self.num_excs += other.num_excs
26
+ self.updated_cols = list(dict.fromkeys(self.updated_cols + other.updated_cols))
27
+ self.cols_with_excs = list(dict.fromkeys(self.cols_with_excs + other.cols_with_excs))
28
+ return self
29
+
22
30
  def is_valid_identifier(name: str) -> bool:
23
31
  return name.isidentifier() and not name.startswith('_')
24
32
 
pixeltable/catalog/insertable_table.py CHANGED
@@ -129,11 +129,11 @@ class InsertableTable(Table):
129
129
  msg = str(e)
130
130
  raise excs.Error(f'Error in column {col.name}: {msg[0].lower() + msg[1:]}\nRow: {row}')
131
131
 
132
- def delete(self, where: Optional['pixeltable.exprs.Predicate'] = None) -> UpdateStatus:
132
+ def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
133
133
  """Delete rows in this table.
134
134
 
135
135
  Args:
136
- where: a Predicate to filter rows to delete.
136
+ where: a predicate to filter rows to delete.
137
137
 
138
138
  Examples:
139
139
  Delete all rows in a table:
pixeltable/catalog/table.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import json
4
4
  import logging
5
5
  from pathlib import Path
6
- from typing import Union, Any, Optional, Callable, Set, Tuple, Iterable, overload, Type
6
+ from typing import Union, Any, Optional, Callable, Set, Tuple, Iterable, overload, Type, Literal
7
7
  from uuid import UUID
8
8
  import abc
9
9
 
@@ -113,7 +113,7 @@ class Table(SchemaObject):
113
113
  from pixeltable.dataframe import DataFrame
114
114
  return DataFrame(self._tbl_version_path).select(*items, **named_items)
115
115
 
116
- def where(self, pred: 'exprs.Predicate') -> 'pixeltable.dataframe.DataFrame':
116
+ def where(self, pred: 'exprs.Expr') -> 'pixeltable.dataframe.DataFrame':
117
117
  """Return a DataFrame for this table.
118
118
  """
119
119
  # local import: avoid circular imports
@@ -716,13 +716,13 @@ class Table(SchemaObject):
716
716
  raise NotImplementedError
717
717
 
718
718
  def update(
719
- self, value_spec: dict[str, Any], where: Optional['pixeltable.exprs.Predicate'] = None, cascade: bool = True
719
+ self, value_spec: dict[str, Any], where: Optional['pixeltable.exprs.Expr'] = None, cascade: bool = True
720
720
  ) -> UpdateStatus:
721
721
  """Update rows in this table.
722
722
 
723
723
  Args:
724
724
  value_spec: a dictionary mapping column names to literal values or Pixeltable expressions.
725
- where: a Predicate to filter rows to update.
725
+ where: a predicate to filter rows to update.
726
726
  cascade: if True, also update all computed columns that transitively depend on the updated columns.
727
727
 
728
728
  Examples:
@@ -745,18 +745,34 @@ class Table(SchemaObject):
745
745
  self._check_is_dropped()
746
746
  return self._tbl_version.update(value_spec, where, cascade)
747
747
 
748
- def batch_update(self, rows: Iterable[dict[str, Any]], cascade: bool = True) -> UpdateStatus:
748
+ def batch_update(
749
+ self, rows: Iterable[dict[str, Any]], cascade: bool = True,
750
+ if_not_exists: Literal['error', 'ignore', 'insert'] = 'error'
751
+ ) -> UpdateStatus:
749
752
  """Update rows in this table.
750
753
 
751
754
  Args:
752
755
  rows: an Iterable of dictionaries containing values for the updated columns plus values for the primary key
753
756
  columns.
754
757
  cascade: if True, also update all computed columns that transitively depend on the updated columns.
758
+ if_not_exists: Specifies the behavior if a row to update does not exist:
759
+
760
+ - `'error'`: Raise an error.
761
+ - `'ignore'`: Skip the row silently.
762
+ - `'insert'`: Insert the row.
755
763
 
756
764
  Examples:
757
- Update the 'name' and 'age' columns for the rows with ids 1 and 2 (assuming 'id' is the primary key):
765
+ Update the `name` and `age` columns for the rows with ids 1 and 2 (assuming `id` is the primary key).
766
+ If either row does not exist, this raises an error:
758
767
 
759
768
  >>> tbl.update([{'id': 1, 'name': 'Alice', 'age': 30}, {'id': 2, 'name': 'Bob', 'age': 40}])
769
+
770
+ Update the `name` and `age` columns for the row with `id` 1 (assuming `id` is the primary key) and insert
771
+ the row with new `id` 3 (assuming this key does not exist):
772
+
773
+ >>> tbl.update(
774
+ [{'id': 1, 'name': 'Alice', 'age': 30}, {'id': 3, 'name': 'Bob', 'age': 40}],
775
+ if_not_exists='insert')
760
776
  """
761
777
  if self._tbl_version_path.is_snapshot():
762
778
  raise excs.Error('Cannot update a snapshot')
@@ -784,13 +800,15 @@ class Table(SchemaObject):
784
800
  missing_cols = pk_col_names - set(col.name for col in col_vals.keys())
785
801
  raise excs.Error(f'Primary key columns ({", ".join(missing_cols)}) missing in {row_spec}')
786
802
  row_updates.append(col_vals)
787
- return self._tbl_version.batch_update(row_updates, rowids, cascade)
803
+ return self._tbl_version.batch_update(
804
+ row_updates, rowids, error_if_not_exists=if_not_exists == 'error',
805
+ insert_if_not_exists=if_not_exists == 'insert', cascade=cascade)
788
806
 
789
- def delete(self, where: Optional['pixeltable.exprs.Predicate'] = None) -> UpdateStatus:
807
+ def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
790
808
  """Delete rows in this table.
791
809
 
792
810
  Args:
793
- where: a Predicate to filter rows to delete.
811
+ where: a predicate to filter rows to delete.
794
812
 
795
813
  Examples:
796
814
  Delete all rows in a table:
pixeltable/catalog/table_version.py CHANGED
@@ -678,12 +678,12 @@ class TableVersion:
678
678
  return result
679
679
 
680
680
  def update(
681
- self, value_spec: dict[str, Any], where: Optional['exprs.Predicate'] = None, cascade: bool = True
681
+ self, value_spec: dict[str, Any], where: Optional['exprs.Expr'] = None, cascade: bool = True
682
682
  ) -> UpdateStatus:
683
683
  """Update rows in this TableVersionPath.
684
684
  Args:
685
685
  value_spec: a list of (column, value) pairs specifying the columns to update and their new values.
686
- where: a Predicate to filter rows to update.
686
+ where: a predicate to filter rows to update.
687
687
  cascade: if True, also update all computed columns that transitively depend on the updated columns,
688
688
  including within views.
689
689
  """
@@ -694,18 +694,26 @@ class TableVersion:
694
694
 
695
695
  update_spec = self._validate_update_spec(value_spec, allow_pk=False, allow_exprs=True)
696
696
  if where is not None:
697
- if not isinstance(where, exprs.Predicate):
698
- raise excs.Error(f"'where' argument must be a Predicate, got {type(where)}")
697
+ if not isinstance(where, exprs.Expr):
698
+ raise excs.Error(f"'where' argument must be a predicate, got {type(where)}")
699
699
  analysis_info = Planner.analyze(self.path, where)
700
700
  # for now we require that the updated rows can be identified via SQL, rather than via a Python filter
701
701
  if analysis_info.filter is not None:
702
702
  raise excs.Error(f'Filter {analysis_info.filter} not expressible in SQL')
703
703
 
704
704
  with Env.get().engine.begin() as conn:
705
- return self._update(conn, update_spec, where, cascade)
705
+ plan, updated_cols, recomputed_cols = (
706
+ Planner.create_update_plan(self.path, update_spec, [], where, cascade)
707
+ )
708
+ result = self.propagate_update(
709
+ plan, where.sql_expr() if where is not None else None, recomputed_cols,
710
+ base_versions=[], conn=conn, timestamp=time.time(), cascade=cascade, show_progress=True)
711
+ result.updated_cols = updated_cols
712
+ return result
706
713
 
707
714
  def batch_update(
708
- self, batch: list[dict[Column, 'exprs.Expr']], rowids: list[tuple[int, ...]], cascade: bool = True
715
+ self, batch: list[dict[Column, 'exprs.Expr']], rowids: list[tuple[int, ...]], insert_if_not_exists: bool,
716
+ error_if_not_exists: bool, cascade: bool = True,
709
717
  ) -> UpdateStatus:
710
718
  """Update rows in batch.
711
719
  Args:
@@ -714,62 +722,26 @@ class TableVersion:
714
722
  """
715
723
  # if we do lookups of rowids, we must have one for each row in the batch
716
724
  assert len(rowids) == 0 or len(rowids) == len(batch)
717
- result_status = UpdateStatus()
718
725
  cols_with_excs: set[str] = set()
719
- updated_cols: set[str] = set()
720
- pk_cols = self.primary_key_columns()
721
- use_rowids = len(rowids) > 0
722
726
 
723
727
  with Env.get().engine.begin() as conn:
724
- for i, row in enumerate(batch):
725
- where_clause: Optional[exprs.Expr] = None
726
- if use_rowids:
727
- # construct Where clause to match rowid
728
- num_rowid_cols = len(self.store_tbl.rowid_columns())
729
- for col_idx in range(num_rowid_cols):
730
- assert len(rowids[i]) == num_rowid_cols, f'len({rowids[i]}) != {num_rowid_cols}'
731
- clause = exprs.RowidRef(self, col_idx) == rowids[i][col_idx]
732
- if where_clause is None:
733
- where_clause = clause
734
- else:
735
- where_clause = where_clause & clause
736
- else:
737
- # construct Where clause for primary key columns
738
- for col in pk_cols:
739
- assert col in row
740
- clause = exprs.ColumnRef(col) == row[col]
741
- if where_clause is None:
742
- where_clause = clause
743
- else:
744
- where_clause = where_clause & clause
745
-
746
- update_targets = {col: row[col] for col in row if col not in pk_cols}
747
- status = self._update(conn, update_targets, where_clause, cascade, show_progress=False)
748
- result_status.num_rows += status.num_rows
749
- result_status.num_excs += status.num_excs
750
- result_status.num_computed_values += status.num_computed_values
751
- cols_with_excs.update(status.cols_with_excs)
752
- updated_cols.update(status.updated_cols)
753
-
754
- result_status.cols_with_excs = list(cols_with_excs)
755
- result_status.updated_cols = list(updated_cols)
756
- return result_status
757
-
758
- def _update(
759
- self, conn: sql.engine.Connection, update_targets: dict[Column, 'pixeltable.exprs.Expr'],
760
- where_clause: Optional['pixeltable.exprs.Predicate'] = None, cascade: bool = True,
761
- show_progress: bool = True
762
- ) -> UpdateStatus:
763
- from pixeltable.plan import Planner
728
+ from pixeltable.plan import Planner
764
729
 
765
- plan, updated_cols, recomputed_cols = (
766
- Planner.create_update_plan(self.path, update_targets, [], where_clause, cascade)
767
- )
768
- result = self.propagate_update(
769
- plan, where_clause.sql_expr() if where_clause is not None else None, recomputed_cols,
770
- base_versions=[], conn=conn, timestamp=time.time(), cascade=cascade, show_progress=show_progress)
771
- result.updated_cols = updated_cols
772
- return result
730
+ plan, row_update_node, delete_where_clause, updated_cols, recomputed_cols = \
731
+ Planner.create_batch_update_plan(self.path, batch, rowids, cascade=cascade)
732
+ result = self.propagate_update(
733
+ plan, delete_where_clause, recomputed_cols, base_versions=[], conn=conn, timestamp=time.time(),
734
+ cascade=cascade)
735
+ result.updated_cols = [c.qualified_name for c in updated_cols]
736
+
737
+ unmatched_rows = row_update_node.unmatched_rows()
738
+ if len(unmatched_rows) > 0:
739
+ if error_if_not_exists:
740
+ raise excs.Error(f'batch_update(): {len(unmatched_rows)} row(s) not found')
741
+ if insert_if_not_exists:
742
+ insert_status = self.insert(unmatched_rows, print_stats=False, fail_on_exception=False)
743
+ result += insert_status
744
+ return result
773
745
 
774
746
  def _validate_update_spec(
775
747
  self, value_spec: dict[str, Any], allow_pk: bool, allow_exprs: bool
@@ -779,7 +751,10 @@ class TableVersion:
779
751
  if not isinstance(col_name, str):
780
752
  raise excs.Error(f'Update specification: dict key must be column name, got {col_name!r}')
781
753
  if col_name == _ROWID_COLUMN_NAME:
782
- # ignore pseudo-column _rowid
754
+ # a valid rowid is a list of ints, one per rowid column
755
+ assert len(val) == len(self.store_tbl.rowid_columns())
756
+ for el in val:
757
+ assert isinstance(el, int)
783
758
  continue
784
759
  col = self.path.get_column(col_name, include_bases=False)
785
760
  if col is None:
@@ -789,8 +764,6 @@ class TableVersion:
789
764
  raise excs.Error(f'Column {col_name} is computed and cannot be updated')
790
765
  if col.is_pk and not allow_pk:
791
766
  raise excs.Error(f'Column {col_name} is a primary key column and cannot be updated')
792
- if col.col_type.is_media_type():
793
- raise excs.Error(f'Column {col_name} has type image/video/audio/document and cannot be updated')
794
767
 
795
768
  # make sure that the value is compatible with the column type
796
769
  try:
@@ -848,17 +821,17 @@ class TableVersion:
848
821
  result.cols_with_excs = list(dict.fromkeys(result.cols_with_excs).keys()) # remove duplicates
849
822
  return result
850
823
 
851
- def delete(self, where: Optional['exprs.Predicate'] = None) -> UpdateStatus:
824
+ def delete(self, where: Optional['exprs.Expr'] = None) -> UpdateStatus:
852
825
  """Delete rows in this table.
853
826
  Args:
854
- where: a Predicate to filter rows to delete.
827
+ where: a predicate to filter rows to delete.
855
828
  """
856
829
  assert self.is_insertable()
857
- from pixeltable.exprs import Predicate
830
+ from pixeltable.exprs import Expr
858
831
  from pixeltable.plan import Planner
859
832
  if where is not None:
860
- if not isinstance(where, Predicate):
861
- raise excs.Error(f"'where' argument must be a Predicate, got {type(where)}")
833
+ if not isinstance(where, Expr):
834
+ raise excs.Error(f"'where' argument must be a predicate, got {type(where)}")
862
835
  analysis_info = Planner.analyze(self.path, where)
863
836
  # for now we require that the updated rows can be identified via SQL, rather than via a Python filter
864
837
  if analysis_info.filter is not None:
@@ -872,11 +845,11 @@ class TableVersion:
872
845
  return status
873
846
 
874
847
  def propagate_delete(
875
- self, where: Optional['exprs.Predicate'], base_versions: List[Optional[int]],
848
+ self, where: Optional['exprs.Expr'], base_versions: List[Optional[int]],
876
849
  conn: sql.engine.Connection, timestamp: float) -> int:
877
850
  """Delete rows in this table and propagate to views.
878
851
  Args:
879
- where: a Predicate to filter rows to delete.
852
+ where: a predicate to filter rows to delete.
880
853
  Returns:
881
854
  number of deleted rows
882
855
  """
pixeltable/catalog/view.py CHANGED
@@ -51,7 +51,7 @@ class View(Table):
51
51
  @classmethod
52
52
  def create(
53
53
  cls, dir_id: UUID, name: str, base: TableVersionPath, schema: Dict[str, Any],
54
- predicate: 'pxt.exprs.Predicate', is_snapshot: bool, num_retained_versions: int, comment: str,
54
+ predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
55
55
  iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
56
56
  ) -> View:
57
57
  columns = cls._create_columns(schema)
@@ -92,7 +92,7 @@ class View(Table):
92
92
  ]
93
93
  sig = func.Signature(InvalidType(), params)
94
94
  from pixeltable.exprs import FunctionCall
95
- FunctionCall.check_args(sig, bound_args)
95
+ FunctionCall.normalize_args(sig, bound_args)
96
96
  except TypeError as e:
97
97
  raise Error(f'Cannot instantiate iterator with given arguments: {e}')
98
98
 
@@ -213,5 +213,5 @@ class View(Table):
213
213
  ) -> UpdateStatus:
214
214
  raise excs.Error(f'{self.display_name()} {self._name!r}: cannot insert into view')
215
215
 
216
- def delete(self, where: Optional['pixeltable.exprs.Predicate'] = None) -> UpdateStatus:
216
+ def delete(self, where: Optional['pixeltable.exprs.Expr'] = None) -> UpdateStatus:
217
217
  raise excs.Error(f'{self.display_name()} {self._name!r}: cannot delete from view')
pixeltable/dataframe.py CHANGED
@@ -153,7 +153,7 @@ class DataFrame:
153
153
  self,
154
154
  tbl: catalog.TableVersionPath,
155
155
  select_list: Optional[List[Tuple[exprs.Expr, Optional[str]]]] = None,
156
- where_clause: Optional[exprs.Predicate] = None,
156
+ where_clause: Optional[exprs.Expr] = None,
157
157
  group_by_clause: Optional[List[exprs.Expr]] = None,
158
158
  grouping_tbl: Optional[catalog.TableVersion] = None,
159
159
  order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, # List[(expr, asc)]
@@ -530,7 +530,11 @@ class DataFrame:
530
530
  limit=self.limit_val,
531
531
  )
532
532
 
533
- def where(self, pred: exprs.Predicate) -> DataFrame:
533
+ def where(self, pred: exprs.Expr) -> DataFrame:
534
+ if not isinstance(pred, exprs.Expr):
535
+ raise excs.Error(f'Where() requires a Pixeltable expression, but instead got {type(pred)}')
536
+ if not pred.col_type.is_bool_type():
537
+ raise excs.Error(f'Where(): expression needs to return bool, but instead returns {pred.col_type}')
534
538
  return DataFrame(
535
539
  self.tbl,
536
540
  select_list=self.select_list,
@@ -628,12 +632,9 @@ class DataFrame:
628
632
  def __getitem__(self, index: object) -> DataFrame:
629
633
  """
630
634
  Allowed:
631
- - [<Predicate>]: filter operation
632
635
  - [List[Expr]]/[Tuple[Expr]]: setting the select list
633
636
  - [Expr]: setting a single-col select list
634
637
  """
635
- if isinstance(index, exprs.Predicate):
636
- return self.where(index)
637
638
  if isinstance(index, tuple):
638
639
  index = list(index)
639
640
  if isinstance(index, exprs.Expr):
@@ -668,7 +669,7 @@ class DataFrame:
668
669
  tbl = catalog.TableVersionPath.from_dict(d['tbl'])
669
670
  select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
670
671
  if d['select_list'] is not None else None
671
- where_clause = exprs.Predicate.from_dict(d['where_clause']) \
672
+ where_clause = exprs.Expr.from_dict(d['where_clause']) \
672
673
  if d['where_clause'] is not None else None
673
674
  group_by_clause = [exprs.Expr.from_dict(e) for e in d['group_by_clause']] \
674
675
  if d['group_by_clause'] is not None else None
pixeltable/exec/__init__.py CHANGED
@@ -5,6 +5,7 @@ from .exec_context import ExecContext
5
5
  from .exec_node import ExecNode
6
6
  from .expr_eval_node import ExprEvalNode
7
7
  from .in_memory_data_node import InMemoryDataNode
8
- from .sql_scan_node import SqlScanNode
8
+ from .sql_node import SqlScanNode, SqlLookupNode
9
+ from .row_update_node import RowUpdateNode
9
10
  from .media_validation_node import MediaValidationNode
10
11
  from .data_row_batch import DataRowBatch
pixeltable/exec/expr_eval_node.py CHANGED
@@ -50,7 +50,14 @@ class ExprEvalNode(ExecNode):
50
50
 
51
51
  def _open(self) -> None:
52
52
  warnings.simplefilter("ignore", category=TqdmWarning)
53
- if self.ctx.show_pbar:
53
+ # This is a temporary hack. When B-tree indices on string columns were implemented (via computed columns
54
+ # that invoke the `BtreeIndex.str_filter` udf), it resulted in frivolous progress bars appearing on every
55
+ # insertion. This special-cases the `str_filter` call to suppress the corresponding progress bar.
56
+ # TODO(aaron-siegel) Remove this hack once we clean up progress bars more generally.
57
+ is_str_filter_node = all(
58
+ isinstance(expr, exprs.FunctionCall) and expr.fn.name == 'str_filter' for expr in self.output_exprs
59
+ )
60
+ if self.ctx.show_pbar and not is_str_filter_node:
54
61
  self.pbar = tqdm(
55
62
  total=len(self.target_exprs) * self.ctx.num_rows,
56
63
  desc='Computing cells',
pixeltable/exec/row_update_node.py ADDED
@@ -0,0 +1,61 @@
1
+ import logging
2
+ from typing import Any
3
+
4
+ import pixeltable.catalog as catalog
5
+ import pixeltable.exprs as exprs
6
+ from pixeltable.utils.media_store import MediaStore
7
+ from .data_row_batch import DataRowBatch
8
+ from .exec_node import ExecNode
9
+
10
+ _logger = logging.getLogger('pixeltable')
11
+
12
+ class RowUpdateNode(ExecNode):
13
+ """
14
+ Update individual rows in the input batches, identified by key columns.
15
+
16
+ The updates for a row are provided as a dict of column names to new values.
17
+ The node assumes that all update dicts contain the same keys, and it populates the slots of the columns present in
18
+ the update list.
19
+ """
20
+ def __init__(
21
+ self, tbl: catalog.TableVersionPath, key_vals_batch: list[tuple], is_rowid_key: bool,
22
+ col_vals_batch: list[dict[catalog.Column, Any]], row_builder: exprs.RowBuilder, input: ExecNode,
23
+ ):
24
+ super().__init__(row_builder, [], [], input)
25
+ self.updates = {key_vals: col_vals for key_vals, col_vals in zip(key_vals_batch, col_vals_batch)}
26
+ self.is_rowid_key = is_rowid_key
27
+ # determine slot idxs of all columns we need to read or write
28
+ # retrieve ColumnRefs from the RowBuilder (has slot_idx set)
29
+ all_col_slot_idxs = {
30
+ col_ref.col: col_ref.slot_idx
31
+ for col_ref in row_builder.unique_exprs if isinstance(col_ref, exprs.ColumnRef)
32
+ }
33
+ self.col_slot_idxs = {col: all_col_slot_idxs[col] for col in col_vals_batch[0].keys()}
34
+ self.key_slot_idxs = {col: all_col_slot_idxs[col] for col in tbl.tbl_version.primary_key_columns()}
35
+ self.matched_key_vals: set[tuple] = set()
36
+
37
+ def __next__(self) -> DataRowBatch:
38
+ batch = next(self.input)
39
+ for row in batch:
40
+ key_vals = row.rowid if self.is_rowid_key else \
41
+ tuple(row[slot_idx] for slot_idx in self.key_slot_idxs.values())
42
+ if key_vals not in self.updates:
43
+ continue
44
+ self.matched_key_vals.add(key_vals)
45
+ col_vals = self.updates[key_vals]
46
+ for col, val in col_vals.items():
47
+ slot_idx = self.col_slot_idxs[col]
48
+ row[slot_idx] = val
49
+ return batch
50
+
51
+ def unmatched_rows(self) -> list[dict[str, Any]]:
52
+ """Return rows that didn't get used in the updates as a list of dicts compatible with TableVersion.insert()."""
53
+ result: list[dict[str, Any]] = []
54
+ key_cols = self.key_slot_idxs.keys()
55
+ for key_vals, col_vals in self.updates.items():
56
+ if key_vals in self.matched_key_vals:
57
+ continue
58
+ row = {col.name: val for col, val in zip(key_cols, key_vals)}
59
+ row.update({col.name: val for col, val in col_vals.items()})
60
+ result.append(row)
61
+ return result