pixeltable 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (88)
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/globals.py +3 -0
  5. pixeltable/catalog/insertable_table.py +9 -7
  6. pixeltable/catalog/table.py +220 -143
  7. pixeltable/catalog/table_version.py +36 -18
  8. pixeltable/catalog/table_version_path.py +0 -8
  9. pixeltable/catalog/view.py +3 -3
  10. pixeltable/dataframe.py +9 -24
  11. pixeltable/env.py +107 -36
  12. pixeltable/exceptions.py +7 -4
  13. pixeltable/exec/__init__.py +1 -1
  14. pixeltable/exec/aggregation_node.py +22 -15
  15. pixeltable/exec/component_iteration_node.py +62 -41
  16. pixeltable/exec/data_row_batch.py +7 -7
  17. pixeltable/exec/exec_node.py +35 -7
  18. pixeltable/exec/expr_eval_node.py +2 -1
  19. pixeltable/exec/in_memory_data_node.py +9 -9
  20. pixeltable/exec/sql_node.py +265 -136
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/data_row.py +30 -19
  23. pixeltable/exprs/expr.py +15 -14
  24. pixeltable/exprs/expr_dict.py +55 -0
  25. pixeltable/exprs/expr_set.py +21 -15
  26. pixeltable/exprs/function_call.py +21 -8
  27. pixeltable/exprs/json_path.py +3 -6
  28. pixeltable/exprs/rowid_ref.py +2 -2
  29. pixeltable/exprs/sql_element_cache.py +5 -1
  30. pixeltable/ext/functions/whisperx.py +7 -2
  31. pixeltable/func/callable_function.py +2 -2
  32. pixeltable/func/function_registry.py +6 -7
  33. pixeltable/func/query_template_function.py +11 -12
  34. pixeltable/func/signature.py +17 -15
  35. pixeltable/func/udf.py +0 -4
  36. pixeltable/functions/__init__.py +1 -1
  37. pixeltable/functions/audio.py +4 -6
  38. pixeltable/functions/globals.py +86 -42
  39. pixeltable/functions/huggingface.py +12 -14
  40. pixeltable/functions/image.py +59 -45
  41. pixeltable/functions/json.py +0 -1
  42. pixeltable/functions/mistralai.py +2 -2
  43. pixeltable/functions/openai.py +22 -25
  44. pixeltable/functions/string.py +50 -50
  45. pixeltable/functions/timestamp.py +20 -20
  46. pixeltable/functions/together.py +26 -12
  47. pixeltable/functions/video.py +11 -20
  48. pixeltable/functions/whisper.py +2 -20
  49. pixeltable/globals.py +57 -56
  50. pixeltable/index/base.py +2 -2
  51. pixeltable/index/btree.py +7 -7
  52. pixeltable/index/embedding_index.py +8 -10
  53. pixeltable/io/external_store.py +11 -5
  54. pixeltable/io/globals.py +3 -1
  55. pixeltable/io/hf_datasets.py +4 -4
  56. pixeltable/io/label_studio.py +6 -6
  57. pixeltable/io/parquet.py +14 -13
  58. pixeltable/iterators/document.py +10 -8
  59. pixeltable/iterators/video.py +10 -1
  60. pixeltable/metadata/__init__.py +3 -2
  61. pixeltable/metadata/converters/convert_14.py +4 -2
  62. pixeltable/metadata/converters/convert_15.py +1 -1
  63. pixeltable/metadata/converters/convert_19.py +1 -0
  64. pixeltable/metadata/converters/convert_20.py +1 -1
  65. pixeltable/metadata/converters/util.py +9 -8
  66. pixeltable/metadata/schema.py +32 -21
  67. pixeltable/plan.py +136 -154
  68. pixeltable/store.py +51 -36
  69. pixeltable/tool/create_test_db_dump.py +7 -7
  70. pixeltable/tool/doc_plugins/griffe.py +3 -34
  71. pixeltable/tool/mypy_plugin.py +32 -0
  72. pixeltable/type_system.py +243 -60
  73. pixeltable/utils/arrow.py +10 -9
  74. pixeltable/utils/coco.py +4 -4
  75. pixeltable/utils/documents.py +1 -1
  76. pixeltable/utils/filecache.py +131 -84
  77. pixeltable/utils/formatter.py +1 -1
  78. pixeltable/utils/http_server.py +2 -5
  79. pixeltable/utils/media_store.py +6 -6
  80. pixeltable/utils/pytorch.py +10 -11
  81. pixeltable/utils/sql.py +2 -1
  82. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/METADATA +16 -7
  83. pixeltable-0.2.21.dist-info/RECORD +148 -0
  84. pixeltable/utils/help.py +0 -11
  85. pixeltable-0.2.19.dist-info/RECORD +0 -147
  86. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
  87. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
  88. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
pixeltable/plan.py CHANGED
@@ -1,5 +1,4 @@
- import itertools
- from typing import Any, Iterable, Optional, Sequence
+ from typing import Any, Iterable, Optional, Sequence, cast
  from uuid import UUID

  import sqlalchemy as sql
@@ -9,6 +8,7 @@ import pixeltable.exec as exec
  from pixeltable import catalog
  from pixeltable import exceptions as excs
  from pixeltable import exprs
+ from pixeltable.exec.sql_node import OrderByItem, OrderByClause, combine_order_by_clauses, print_order_by_clause


  def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -46,11 +46,9 @@ class Analyzer:
  tbl: catalog.TableVersionPath
  all_exprs: list[exprs.Expr]
  select_list: list[exprs.Expr]
- group_by_clause: list[exprs.Expr]
- order_by_clause: list[tuple[exprs.Expr, bool]]
-
- # exprs that can be expressed in SQL and are retrieved directly from the store
- #sql_exprs: list[exprs.Expr]
+ group_by_clause: Optional[list[exprs.Expr]] # None for non-aggregate queries; [] for agg query w/o grouping
+ grouping_exprs: list[exprs.Expr] # [] for non-aggregate queries or agg query w/o grouping
+ order_by_clause: OrderByClause

  sql_elements: exprs.SqlElementCache

@@ -60,15 +58,14 @@ class Analyzer:
  # filter predicate applied to output rows of the SQL scan
  filter: Optional[exprs.Expr]

- agg_fn_calls: list[exprs.FunctionCall]
+ agg_fn_calls: list[exprs.FunctionCall] # grouping aggregation (ie, not window functions)
+ window_fn_calls: list[exprs.FunctionCall]
  agg_order_by: list[exprs.Expr]

  def __init__(
  self, tbl: catalog.TableVersionPath, select_list: Sequence[exprs.Expr],
  where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[list[exprs.Expr]] = None,
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None):
- if group_by_clause is None:
- group_by_clause = []
  if order_by_clause is None:
  order_by_clause = []
  self.tbl = tbl
@@ -78,8 +75,10 @@ class Analyzer:
  self.select_list = [e.resolve_computed_cols() for e in select_list]
  if where_clause is not None:
  where_clause = where_clause.resolve_computed_cols()
- self.group_by_clause = [e.resolve_computed_cols() for e in group_by_clause]
- self.order_by_clause = [(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]
+ self.group_by_clause = (
+ [e.resolve_computed_cols() for e in group_by_clause] if group_by_clause is not None else None
+ )
+ self.order_by_clause = [OrderByItem(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]

  self.sql_where_clause = None
  self.filter = None
@@ -89,20 +88,36 @@ class Analyzer:

  # all exprs that are evaluated in Python; not executable
  self.all_exprs = self.select_list.copy()
- self.all_exprs.extend(self.group_by_clause)
+ if self.group_by_clause is not None:
+ self.all_exprs.extend(self.group_by_clause)
  self.all_exprs.extend(e for e, _ in self.order_by_clause)
  if self.filter is not None:
  self.all_exprs.append(self.filter)

  self.agg_order_by = []
+ self.agg_fn_calls = []
+ self.window_fn_calls = []
  self._analyze_agg()
+ self.grouping_exprs = self.group_by_clause if self.group_by_clause is not None else []

  def _analyze_agg(self) -> None:
  """Check semantic correctness of aggregation and fill in agg-specific fields of Analyzer"""
- self.agg_fn_calls = [e for e in self.all_exprs if isinstance(e, exprs.FunctionCall) and _is_agg_fn_call(e)]
+ candidates = self.select_list
+ agg_fn_calls = exprs.ExprSet(
+ exprs.Expr.list_subexprs(
+ candidates, expr_class=exprs.FunctionCall,
+ filter=lambda e: bool(e.is_agg_fn_call and not e.is_window_fn_call)))
+ self.agg_fn_calls = list(agg_fn_calls)
+ window_fn_calls = exprs.ExprSet(
+ exprs.Expr.list_subexprs(
+ candidates, expr_class=exprs.FunctionCall, filter=lambda e: bool(e.is_window_fn_call)))
+ self.window_fn_calls = list(window_fn_calls)
  if len(self.agg_fn_calls) == 0:
  # nothing to do
  return
+ # if we're doing grouping aggregation and don't have an explicit Group By clause, we're creating a single group
+ if self.group_by_clause is None:
+ self.group_by_clause = []

  # check that select list only contains aggregate output
  grouping_expr_ids = {e.id for e in self.group_by_clause}
@@ -113,8 +128,7 @@ class Analyzer:

  # check that filter doesn't contain aggregates
  if self.filter is not None:
- agg_fn_calls = [e for e in self.filter.subexprs(expr_class=exprs.FunctionCall, filter=lambda e: _is_agg_fn_call(e))]
- if len(agg_fn_calls) > 0:
+ if any(_is_agg_fn_call(e) for e in self.filter.subexprs(expr_class=exprs.FunctionCall)):
  raise excs.Error(f'Filter cannot contain aggregate functions: {self.filter}')

  # check that grouping exprs don't contain aggregates and can be expressed as SQL (we perform sort-based
@@ -125,27 +139,6 @@ class Analyzer:
  if e._contains(filter=lambda e: _is_agg_fn_call(e)):
  raise excs.Error(f'Grouping expression contains aggregate function: {e}')

- # check that agg fn calls don't have contradicting ordering requirements
- order_by: list[exprs.Expr] = []
- order_by_origin: Optional[exprs.Expr] = None # the expr that determines the ordering
- for agg_fn_call in self.agg_fn_calls:
- fn_call_order_by = agg_fn_call.get_agg_order_by()
- if len(fn_call_order_by) == 0:
- continue
- if len(order_by) == 0:
- order_by = fn_call_order_by
- order_by_origin = agg_fn_call
- else:
- combined = _get_combined_ordering(
- [(e, True) for e in order_by], [(e, True) for e in fn_call_order_by])
- if len(combined) == 0:
- raise excs.Error((
- f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
- f"'{agg_fn_call}':\n"
- f"{exprs.Expr.print_list(order_by)} vs {exprs.Expr.print_list(fn_call_order_by)}"
- ))
- self.agg_order_by = order_by
-
  def _determine_agg_status(self, e: exprs.Expr, grouping_expr_ids: set[int]) -> tuple[bool, bool]:
  """Determine whether expr is the input to or output of an aggregate function.
  Returns:
@@ -175,14 +168,14 @@ class Analyzer:
  raise excs.Error(f'Invalid expression, mixes aggregate with non-aggregate: {e}')
  return is_output, is_input

-
  def finalize(self, row_builder: exprs.RowBuilder) -> None:
  """Make all exprs executable
  TODO: add EvalCtx for each expr list?
  """
  # maintain original composition of select list
  row_builder.set_slot_idxs(self.select_list, remove_duplicates=False)
- row_builder.set_slot_idxs(self.group_by_clause)
+ if self.group_by_clause is not None:
+ row_builder.set_slot_idxs(self.group_by_clause)
  order_by_exprs = [e for e, _ in self.order_by_clause]
  row_builder.set_slot_idxs(order_by_exprs)
  row_builder.set_slot_idxs(self.all_exprs)
@@ -191,6 +184,19 @@ class Analyzer:
  row_builder.set_slot_idxs(self.agg_fn_calls)
  row_builder.set_slot_idxs(self.agg_order_by)

+ def get_window_fn_ob_clause(self) -> Optional[OrderByClause]:
+ clause: list[OrderByClause] = []
+ for fn_call in self.window_fn_calls:
+ # window functions require ordering by the group_by/order_by clauses
+ group_by_exprs, order_by_exprs = fn_call.get_window_sort_exprs()
+ clause.append(
+ [OrderByItem(e, None) for e in group_by_exprs] + [OrderByItem(e, True) for e in order_by_exprs])
+ return combine_order_by_clauses(clause)
+
+ def has_agg(self) -> bool:
+ """True if there is any kind of aggregation in the query"""
+ return self.group_by_clause is not None or len(self.agg_fn_calls) > 0 or len(self.window_fn_calls) > 0
+

  class Planner:
  # TODO: create an exec.CountNode and change this to create_count_plan()
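Note: the new Analyzer members above (and _verify_ordering further down) lean on OrderByItem, OrderByClause, combine_order_by_clauses and print_order_by_clause, imported from pixeltable/exec/sql_node.py; that module's diff is not part of this section. A minimal sketch of the shapes these names appear to have, inferred only from their usage in plan.py (an assumption, not the actual sql_node.py definitions):

# Hedged sketch: plausible shapes for the ordering types used above, inferred from usage;
# not the real pixeltable.exec.sql_node definitions.
from typing import Any, NamedTuple, Optional

class OrderByItem(NamedTuple):
    expr: Any               # an exprs.Expr in the real code
    asc: Optional[bool]     # None = direction unspecified, True = ascending, False = descending

OrderByClause = list[OrderByItem]   # one ordering requirement, e.g. from a single window function

Read this way, get_window_fn_ob_clause() collects one clause per window function and asks combine_order_by_clauses() for a single clause that satisfies all of them, returning None when they cannot be reconciled.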
@@ -507,93 +513,35 @@ class Planner:
  return plan, len(row_builder.default_eval_ctx.target_exprs)

  @classmethod
- def _determine_ordering(cls, analyzer: Analyzer) -> list[tuple[exprs.Expr, bool]]:
- """Returns the exprs for the ORDER BY clause of the SqlScanNode"""
- order_by_items: list[tuple[exprs.Expr, Optional[bool]]] = []
- order_by_origin: Optional[exprs.Expr] = None # the expr that determines the ordering
-
-
- # window functions require ordering by the group_by/order_by clauses
- window_fn_calls = [
- e for e in analyzer.all_exprs if isinstance(e, exprs.FunctionCall) and e.is_window_fn_call
- ]
- if len(window_fn_calls) > 0:
- for fn_call in window_fn_calls:
+ def _verify_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> None:
+ """Verify that the various ordering requirements don't conflict"""
+ ob_clauses: list[OrderByClause] = [analyzer.order_by_clause.copy()]
+
+ if verify_agg:
+ ordering: OrderByClause
+ for fn_call in analyzer.window_fn_calls:
+ # window functions require ordering by the group_by/order_by clauses
  gb, ob = fn_call.get_window_sort_exprs()
- # for now, the ordering is implicitly ascending
- fn_call_ordering = [(e, None) for e in gb] + [(e, True) for e in ob]
- if len(order_by_items) == 0:
- order_by_items = fn_call_ordering
- order_by_origin = fn_call
- else:
- # check for compatibility
- other_order_by_clauses = fn_call_ordering
- combined = _get_combined_ordering(order_by_items, other_order_by_clauses)
- if len(combined) == 0:
- raise excs.Error((
- f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
- f"'{fn_call}':\n"
- f"{exprs.Expr.print_list(order_by_items)} vs {exprs.Expr.print_list(other_order_by_clauses)}"
- ))
- order_by_items = combined
-
- if len(analyzer.group_by_clause) > 0:
- agg_ordering = [(e, None) for e in analyzer.group_by_clause] + [(e, True) for e in analyzer.agg_order_by]
- if len(order_by_items) > 0:
- # check for compatibility
- combined = _get_combined_ordering(order_by_items, agg_ordering)
- if len(combined) == 0:
- raise excs.Error((
- f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
- f"grouping expressions:\n"
- f"{exprs.Expr.print_list([e for e, _ in order_by_items])} vs "
- f"{exprs.Expr.print_list([e for e, _ in agg_ordering])}"
- ))
- order_by_items = combined
- else:
- order_by_items = agg_ordering
+ ordering = [OrderByItem(e, None) for e in gb] + [OrderByItem(e, True) for e in ob]
+ ob_clauses.append(ordering)
+ for fn_call in analyzer.agg_fn_calls:
+ # agg functions with an ordering requirement are implicitly ascending
+ ordering = (
+ [OrderByItem(e, None) for e in analyzer.group_by_clause]
+ + [OrderByItem(e, True) for e in fn_call.get_agg_order_by()]
+ )
+ ob_clauses.append(ordering)
+ if len(ob_clauses) <= 1:
+ return

- if len(analyzer.order_by_clause) > 0:
- if len(order_by_items) > 0:
- # check for compatibility
- combined = _get_combined_ordering(order_by_items, analyzer.order_by_clause)
- if len(combined) == 0:
- raise excs.Error((
- f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
- f"order-by expressions:\n"
- f"{exprs.Expr.print_list([e for e, _ in order_by_items])} vs "
- f"{exprs.Expr.print_list([e for e, _ in analyzer.order_by_clause])}"
- ))
- order_by_items = combined
- else:
- order_by_items = analyzer.order_by_clause
-
- # TODO: can this be unified with the same logic in RowBuilder
- def refs_unstored_iter_col(e: exprs.Expr) -> bool:
- if not isinstance(e, exprs.ColumnRef):
- return False
- tbl = e.col.tbl
- return tbl.is_component_view() and tbl.is_iterator_column(e.col) and not e.col.is_stored
- unstored_iter_col_refs = list(exprs.Expr.list_subexprs(analyzer.all_exprs, expr_class=exprs.ColumnRef, filter=refs_unstored_iter_col))
- if len(unstored_iter_col_refs) > 0 and len(order_by_items) == 0:
- # we don't already have a user-requested ordering and we access unstored iterator columns:
- # order by the primary key of the component view, which minimizes the number of iterator instantiations
- component_views = {e.col.tbl for e in unstored_iter_col_refs}
- # TODO: generalize this to multi-level iteration
- assert len(component_views) == 1
- component_view = list(component_views)[0]
- order_by_items = [
- (exprs.RowidRef(component_view, idx), None)
- for idx in range(len(component_view.store_tbl.rowid_columns()))
- ]
- order_by_origin = unstored_iter_col_refs[0]
-
- for e in [e for e, _ in order_by_items]:
- if not analyzer.sql_elements.contains(e):
- raise excs.Error(f'order_by element cannot be expressed in SQL: {e}')
- # we do ascending ordering by default, if not specified otherwise
- order_by_items = [(e, True) if asc is None else (e, asc) for e, asc in order_by_items]
- return order_by_items
+ combined_ordering = ob_clauses[0]
+ for ordering in ob_clauses[1:]:
+ combined = combine_order_by_clauses([combined_ordering, ordering])
+ if combined is None:
+ raise excs.Error(
+ f'Incompatible ordering requirements: '
+ f'{print_order_by_clause(combined_ordering)} vs {print_order_by_clause(ordering)}')
+ combined_ordering = combined

  @classmethod
  def _is_contained_in(cls, l1: Iterable[exprs.Expr], l2: Iterable[exprs.Expr]) -> bool:
@@ -632,8 +580,6 @@ class Planner:
  """
  if select_list is None:
  select_list = []
- if group_by_clause is None:
- group_by_clause = []
  if order_by_clause is None:
  order_by_clause = []
  if exact_version_only is None:
@@ -641,16 +587,12 @@ class Planner:
  analyzer = Analyzer(
  tbl, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
  order_by_clause=order_by_clause)
- input_exprs = exprs.ExprSet(exprs.Expr.list_subexprs(
- analyzer.all_exprs, filter=analyzer.sql_elements.contains, traverse_matches=False))
- # remove Literals from sql_exprs, we don't want to materialize them via a Select
- input_exprs = exprs.ExprSet(e for e in input_exprs if not isinstance(e, exprs.Literal))
- row_builder = exprs.RowBuilder(analyzer.all_exprs, [], input_exprs)
+ row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])

  analyzer.finalize(row_builder)
  # select_list: we need to materialize everything that's been collected
  # with_pk: for now, we always retrieve the PK, because we need it for the file cache
- eval_ctx = row_builder.create_eval_ctx(analyzer.all_exprs)
+ eval_ctx = row_builder.create_eval_ctx(analyzer.select_list)
  plan = cls._create_query_plan(
  tbl, row_builder, analyzer=analyzer, eval_ctx=eval_ctx, limit=limit, with_pk=True,
  exact_version_only=exact_version_only)
@@ -677,48 +619,88 @@ class Planner:
  if exact_version_only is None:
  exact_version_only = []
  assert isinstance(tbl, catalog.TableVersionPath)
- is_agg_query = len(analyzer.group_by_clause) > 0 or len(analyzer.agg_fn_calls) > 0
+ sql_elements = analyzer.sql_elements
+ is_python_agg = (
+ not sql_elements.contains(analyzer.agg_fn_calls) or not sql_elements.contains(analyzer.window_fn_calls)
+ )
  ctx = exec.ExecContext(row_builder)
+ cls._verify_ordering(analyzer, verify_agg=is_python_agg)
+
+ # materialized with SQL scan:
+ # - select list subexprs that aren't aggregates
+ # - Where clause conjuncts that can't be run in SQL
+ # - all grouping exprs, if any aggregate function call can't be run in SQL (in that case, they all have to be
+ # run in Python)
+ candidates = list(exprs.Expr.list_subexprs(
+ analyzer.select_list,
+ filter=lambda e: (
+ sql_elements.contains(e)
+ and not e._contains(cls=exprs.FunctionCall, filter=lambda e: bool(e.is_agg_fn_call))
+ ),
+ traverse_matches=False))
+ if analyzer.filter is not None:
+ candidates.extend(exprs.Expr.subexprs(
+ analyzer.filter, filter=lambda e: sql_elements.contains(e), traverse_matches=False))
+ if is_python_agg and analyzer.group_by_clause is not None:
+ candidates.extend(exprs.Expr.list_subexprs(
+ analyzer.group_by_clause, filter=lambda e: sql_elements.contains(e), traverse_matches=False))
+ # not isinstance(...): we don't want to materialize Literals via a Select
+ sql_scan_exprs = exprs.ExprSet(e for e in candidates if not isinstance(e, exprs.Literal))

- order_by_items = cls._determine_ordering(analyzer)
- sql_limit = 0 if is_agg_query else limit # if we're aggregating, the limit applies to the agg output
- sql_exprs = [
- e for e in eval_ctx.exprs if analyzer.sql_elements.contains(e) and not isinstance(e, exprs.Literal)
- ]
  plan = exec.SqlScanNode(
- tbl, row_builder, select_list=sql_exprs, where_clause=analyzer.sql_where_clause,
- filter=analyzer.filter, order_by_items=order_by_items,
- limit=sql_limit, set_pk=with_pk, exact_version_only=exact_version_only)
+ tbl, row_builder, select_list=sql_scan_exprs, where_clause=analyzer.sql_where_clause,
+ filter=analyzer.filter, set_pk=with_pk, exact_version_only=exact_version_only)
+ if len(analyzer.window_fn_calls) > 0:
+ # we need to order the input for window functions
+ plan.add_order_by(analyzer.get_window_fn_ob_clause())
  plan = cls._insert_prefetch_node(tbl.tbl_version.id, analyzer.select_list, row_builder, plan)

- if len(analyzer.group_by_clause) > 0 or len(analyzer.agg_fn_calls) > 0:
- # we're doing aggregation; the input of the AggregateNode are the grouping exprs plus the
+ if analyzer.group_by_clause is not None:
+ # we're doing grouping aggregation; the input of the AggregateNode are the grouping exprs plus the
  # args of the agg fn calls
- agg_input = exprs.ExprSet(analyzer.group_by_clause.copy())
+ agg_input = exprs.ExprSet(analyzer.grouping_exprs.copy())
  for fn_call in analyzer.agg_fn_calls:
  agg_input.update(fn_call.components)
- if not exprs.ExprSet(sql_exprs).issuperset(agg_input):
+ if not sql_scan_exprs.issuperset(agg_input):
  # we need an ExprEvalNode
- plan = exec.ExprEvalNode(row_builder, agg_input, sql_exprs, input=plan)
+ plan = exec.ExprEvalNode(row_builder, agg_input, sql_scan_exprs, input=plan)

  # batch size for aggregation input: this could be the entire table, so we need to divide it into
  # smaller batches; at the same time, we need to make the batches large enough to amortize the
  # function call overhead
  ctx.batch_size = 16

- plan = exec.AggregationNode(
- tbl.tbl_version, row_builder, analyzer.group_by_clause, analyzer.agg_fn_calls, agg_input, input=plan)
- agg_output = exprs.ExprSet(itertools.chain(analyzer.group_by_clause, analyzer.agg_fn_calls))
- if not agg_output.issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
- # we need an ExprEvalNode to evaluate the remaining output exprs
- plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
+ # do aggregation in SQL if all agg exprs can be translated
+ if (sql_elements.contains(analyzer.select_list)
+ and sql_elements.contains(analyzer.grouping_exprs)
+ and isinstance(plan, exec.SqlNode)
+ and plan.to_cte() is not None):
+ plan = exec.SqlAggregationNode(
+ row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause)
+ else:
+ plan = exec.AggregationNode(
+ tbl.tbl_version, row_builder, analyzer.group_by_clause,
+ analyzer.agg_fn_calls + analyzer.window_fn_calls, agg_input, input=plan)
+ typecheck_dummy = analyzer.grouping_exprs + analyzer.agg_fn_calls + analyzer.window_fn_calls
+ agg_output = exprs.ExprSet(typecheck_dummy)
+ if not agg_output.issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
+ # we need an ExprEvalNode to evaluate the remaining output exprs
+ plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
  else:
- if not exprs.ExprSet(sql_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
+ if not exprs.ExprSet(sql_scan_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
  # we need an ExprEvalNode to evaluate the remaining output exprs
- plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_exprs, input=plan)
+ plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_scan_exprs, input=plan)
  # we're returning everything to the user, so we might as well do it in a single batch
  ctx.batch_size = 0

+ sql_node = plan.get_sql_node()
+ assert sql_node is not None
+ if len(analyzer.order_by_clause) > 0:
+ sql_node.add_order_by(analyzer.order_by_clause)
+
+ if limit is not None:
+ plan.set_limit(limit)
+
  plan.set_ctx(ctx)
  return plan

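Note: with this change, _create_query_plan() no longer computes a single SQL ORDER BY up front via _determine_ordering(); _verify_ordering() only checks that the user's order_by clause, the window functions' ordering requirements and the agg functions' ordering requirements can coexist, and ORDER BY / LIMIT are pushed into the underlying SQL node afterwards. A small self-contained illustration of what "compatible" plausibly means here, merging two requirements by common prefix (names and exact semantics are assumptions, not the combine_order_by_clauses implementation):

# Hedged illustration of prefix-compatible merging of ordering requirements, in the spirit of
# combine_order_by_clauses() used above; this is not the pixeltable implementation.
from typing import Optional

Clause = list[tuple[str, Optional[bool]]]   # (sort key, asc); asc=None means "no explicit direction"

def combine(a: Clause, b: Clause) -> Optional[Clause]:
    """Merge two ordering requirements; return None if they conflict."""
    result: Clause = []
    for (e1, asc1), (e2, asc2) in zip(a, b):
        if e1 != e2:
            return None                     # different sort keys at the same position
        if asc1 is not None and asc2 is not None and asc1 != asc2:
            return None                     # explicit asc vs. desc conflict
        result.append((e1, asc1 if asc1 is not None else asc2))
    result.extend(a[len(result):] or b[len(result):])  # the longer clause keeps its tail
    return result

# a window function needs (user, ts asc); an explicit order_by asks for (user asc): compatible
print(combine([('user', None), ('ts', True)], [('user', True)]))   # [('user', True), ('ts', True)]
# requirements that disagree on the leading sort key cannot be merged
print(combine([('user', None)], [('ts', True)]))                   # None

When no such merge exists, the planner raises the "Incompatible ordering requirements" error shown in _verify_ordering().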
pixeltable/store.py CHANGED
@@ -7,18 +7,19 @@ import sys
  import urllib.parse
  import urllib.request
  import warnings
- from typing import Optional, Dict, Any, List, Tuple, Set, Union
+ from typing import Any, Iterator, Literal, Optional, Union

  import sqlalchemy as sql
- from tqdm import tqdm, TqdmWarning
+ from tqdm import TqdmWarning, tqdm

  import pixeltable.catalog as catalog
  import pixeltable.env as env
+ import pixeltable.exceptions as excs
  from pixeltable import exprs
  from pixeltable.exec import ExecNode
  from pixeltable.metadata import schema
  from pixeltable.utils.media_store import MediaStore
- from pixeltable.utils.sql import log_stmt, log_explain
+ from pixeltable.utils.sql import log_explain, log_stmt

  _logger = logging.getLogger('pixeltable')

@@ -31,35 +32,42 @@ class StoreBase:
  - v_min: version at which the row was created
  - v_max: version at which the row was deleted (or MAX_VERSION if it's still live)
  """
+ tbl_version: catalog.TableVersion
+ sa_md: sql.MetaData
+ sa_tbl: Optional[sql.Table]
+ _pk_cols: list[sql.Column]
+ v_min_col: sql.Column
+ v_max_col: sql.Column
+ base: Optional[StoreBase]

  __INSERT_BATCH_SIZE = 1000

  def __init__(self, tbl_version: catalog.TableVersion):
  self.tbl_version = tbl_version
  self.sa_md = sql.MetaData()
- self.sa_tbl: Optional[sql.Table] = None
+ self.sa_tbl = None
  # We need to declare a `base` variable here, even though it's only defined for instances of `StoreView`,
  # since it's referenced by various methods of `StoreBase`
  self.base = None if tbl_version.base is None else tbl_version.base.store_tbl
  self.create_sa_tbl()

- def pk_columns(self) -> List[sql.Column]:
- return self._pk_columns
+ def pk_columns(self) -> list[sql.Column]:
+ return self._pk_cols

- def rowid_columns(self) -> List[sql.Column]:
- return self._pk_columns[:-1]
+ def rowid_columns(self) -> list[sql.Column]:
+ return self._pk_cols[:-1]

  @abc.abstractmethod
- def _create_rowid_columns(self) -> List[sql.Column]:
+ def _create_rowid_columns(self) -> list[sql.Column]:
  """Create and return rowid columns"""

- def _create_system_columns(self) -> List[sql.Column]:
+ def _create_system_columns(self) -> list[sql.Column]:
  """Create and return system columns"""
  rowid_cols = self._create_rowid_columns()
  self.v_min_col = sql.Column('v_min', sql.BigInteger, nullable=False)
  self.v_max_col = \
  sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(schema.Table.MAX_VERSION))
- self._pk_columns = [*rowid_cols, self.v_min_col]
+ self._pk_cols = [*rowid_cols, self.v_min_col]
  return [*rowid_cols, self.v_min_col, self.v_max_col]

  def create_sa_tbl(self) -> None:
@@ -79,7 +87,7 @@ class StoreBase:
  # if we're called in response to a schema change, we need to remove the old table first
  self.sa_md.remove(self.sa_tbl)

- idxs: List[sql.Index] = []
+ idxs: list[sql.Index] = []
  # index for all system columns:
  # - base x view joins can be executed as merge joins
  # - speeds up ORDER BY rowid DESC
@@ -126,7 +134,7 @@ class StoreBase:
  return new_file_url

  def _move_tmp_media_files(
- self, table_rows: List[Dict[str, Any]], media_cols: List[catalog.Column], v_min: int
+ self, table_rows: list[dict[str, Any]], media_cols: list[catalog.Column], v_min: int
  ) -> None:
  """Move tmp media files that we generated to a permanent location"""
  for c in media_cols:
@@ -135,23 +143,17 @@ class StoreBase:
  table_row[c.store_name()] = self._move_tmp_media_file(file_url, c, v_min)

  def _create_table_row(
- self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
- exc_col_ids: Set[int], v_min: int
- ) -> Tuple[Dict[str, Any], int]:
+ self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, exc_col_ids: set[int], pk: tuple[int, ...]
+ ) -> tuple[dict[str, Any], int]:
  """Return Tuple[complete table row, # of exceptions] for insert()
  Creates a row that includes the PK columns, with the values from input_row.pk.
  Returns:
  Tuple[complete table row, # of exceptions]
  """
  table_row, num_excs = row_builder.create_table_row(input_row, exc_col_ids)
-
- assert input_row.pk is not None and len(input_row.pk) == len(self._pk_columns)
- for pk_col, pk_val in zip(self._pk_columns, input_row.pk):
- if pk_col == self.v_min_col:
- table_row[pk_col.name] = v_min
- else:
- table_row[pk_col.name] = pk_val
-
+ assert len(pk) == len(self._pk_cols)
+ for pk_col, pk_val in zip(self._pk_cols, pk):
+ table_row[pk_col.name] = pk_val
  return table_row, num_excs

  def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
@@ -212,14 +214,20 @@ class StoreBase:
  conn.execute(sql.text(stmt))

  def load_column(
- self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, conn: sql.engine.Connection
+ self,
+ col: catalog.Column,
+ exec_plan: ExecNode,
+ value_expr_slot_idx: int,
+ conn: sql.engine.Connection,
+ on_error: Literal['abort', 'ignore']
  ) -> int:
  """Update store column of a computed column with values produced by an execution plan

  Returns:
  number of rows with exceptions
  Raises:
- sql.exc.DBAPIError if there was an error during SQL execution
+ sql.exc.DBAPIError if there was a SQL error during execution
+ excs.Error if on_error='abort' and there was an exception during row evaluation
  """
  num_excs = 0
  num_rows = 0
@@ -253,6 +261,10 @@ class StoreBase:
  if result_row.has_exc(value_expr_slot_idx):
  num_excs += 1
  value_exc = result_row.get_exc(value_expr_slot_idx)
+ if on_error == 'abort':
+ raise excs.Error(
+ f'Error while evaluating computed column `{col.name}`:\n{value_exc}'
+ ) from value_exc
  # we store a NULL value and record the exception/exc type
  error_type = type(value_exc).__name__
  error_msg = str(value_exc)
@@ -291,8 +303,8 @@ class StoreBase:

  def insert_rows(
  self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None,
- show_progress: bool = True
- ) -> Tuple[int, int, Set[int]]:
+ show_progress: bool = True, rowids: Optional[Iterator[int]] = None
+ ) -> tuple[int, int, set[int]]:
  """Insert rows into the store table and update the catalog table's md
  Returns:
  number of inserted rows, number of exceptions, set of column ids that have exceptions
@@ -302,7 +314,7 @@ class StoreBase:
  # TODO: total?
  num_excs = 0
  num_rows = 0
- cols_with_excs: Set[int] = set()
+ cols_with_excs: set[int] = set()
  progress_bar: Optional[tqdm] = None # create this only after we started executing
  row_builder = exec_plan.row_builder
  media_cols = [info.col for info in row_builder.table_columns if info.col.col_type.is_media_type()]
@@ -312,13 +324,16 @@ class StoreBase:
  num_rows += len(row_batch)
  for batch_start_idx in range(0, len(row_batch), self.__INSERT_BATCH_SIZE):
  # compute batch of rows and convert them into table rows
- table_rows: List[Dict[str, Any]] = []
+ table_rows: list[dict[str, Any]] = []
  for row_idx in range(batch_start_idx, min(batch_start_idx + self.__INSERT_BATCH_SIZE, len(row_batch))):
  row = row_batch[row_idx]
- table_row, num_row_exc = \
- self._create_table_row(row, row_builder, media_cols, cols_with_excs, v_min=v_min)
+
+ rowid = (next(rowids),) if rowids is not None else row.pk[:-1]
+ pk = rowid + (v_min,)
+ table_row, num_row_exc = self._create_table_row(row, row_builder, cols_with_excs, pk=pk)
  num_excs += num_row_exc
  table_rows.append(table_row)
+
  if show_progress:
  if progress_bar is None:
  warnings.simplefilter("ignore", category=TqdmWarning)
@@ -353,7 +368,7 @@ class StoreBase:
  return sql.and_(clause, self.base._versions_clause(versions[1:], match_on_vmin))

  def delete_rows(
- self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
+ self, current_version: int, base_versions: list[Optional[int]], match_on_vmin: bool,
  where_clause: Optional[sql.ColumnElement[bool]], conn: sql.engine.Connection) -> int:
  """Mark rows as deleted that are live and were created prior to current_version.
  Also: populate the undo columns
@@ -397,7 +412,7 @@ class StoreTable(StoreBase):
  assert not tbl_version.is_view()
  super().__init__(tbl_version)

- def _create_rowid_columns(self) -> List[sql.Column]:
+ def _create_rowid_columns(self) -> list[sql.Column]:
  self.rowid_col = sql.Column('rowid', sql.BigInteger, nullable=False)
  return [self.rowid_col]

@@ -413,7 +428,7 @@ class StoreView(StoreBase):
  assert catalog_view.is_view()
  super().__init__(catalog_view)

- def _create_rowid_columns(self) -> List[sql.Column]:
+ def _create_rowid_columns(self) -> list[sql.Column]:
  # a view row corresponds directly to a single base row, which means it needs to duplicate its rowid columns
  self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
  return self.rowid_cols
@@ -439,7 +454,7 @@ class StoreComponentView(StoreView):
  def __init__(self, catalog_view: catalog.TableVersion):
  super().__init__(catalog_view)

- def _create_rowid_columns(self) -> List[sql.Column]:
+ def _create_rowid_columns(self) -> list[sql.Column]:
  # each base row is expanded into n view rows
  self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
  # name of pos column: avoid collisions with bases' pos columns
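Note: insert_rows() now assembles each row's primary key before calling _create_table_row(): the rowid part comes either from the optional rowids iterator or from the input row's existing pk with its trailing version component dropped, and v_min is appended as the last element (matching _pk_cols = rowid columns + v_min). A tiny standalone sketch of that assembly, using plain tuples as stand-ins for pixeltable's row objects:

# Hedged sketch of the pk assembly done in insert_rows() above; plain-Python stand-ins only.
from typing import Iterator, Optional

def make_pk(row_pk: tuple[int, ...], v_min: int, rowids: Optional[Iterator[int]] = None) -> tuple[int, ...]:
    # rowid: next value from an explicit rowid iterator if one was passed in,
    # otherwise the input row's rowid columns (its pk without the trailing version)
    rowid = (next(rowids),) if rowids is not None else row_pk[:-1]
    return rowid + (v_min,)   # pk = rowid columns + v_min

# a row whose existing pk is (rowid=7, version=2), re-inserted at table version 3
print(make_pk(row_pk=(7, 2), v_min=3))                  # (7, 3)
# fresh rowids supplied by the caller
print(make_pk(row_pk=(), v_min=3, rowids=iter([42])))   # (42, 3)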