pixeltable-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (119)
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
pixeltable/plan.py ADDED
@@ -0,0 +1,620 @@
+ from typing import Tuple, Optional, List, Set, Any, Dict
+ from uuid import UUID
+
+ import sqlalchemy as sql
+
+ import pixeltable.exec as exec
+ import pixeltable.func as func
+ from pixeltable import catalog
+ from pixeltable import exceptions as excs
+ from pixeltable import exprs
+
+
+ def _is_agg_fn_call(e: exprs.Expr) -> bool:
+     return isinstance(e, exprs.FunctionCall) and e.is_agg_fn_call and not e.is_window_fn_call
+
+ def _get_combined_ordering(
+         o1: List[Tuple[exprs.Expr, bool]], o2: List[Tuple[exprs.Expr, bool]]
+ ) -> List[Tuple[exprs.Expr, bool]]:
+     """Returns an ordering that's compatible with both o1 and o2, or an empty list if no such ordering exists"""
+     result: List[Tuple[exprs.Expr, bool]] = []
+     # determine combined ordering
+     for (e1, asc1), (e2, asc2) in zip(o1, o2):
+         if e1.id != e2.id:
+             return []
+         if asc1 is not None and asc2 is not None and asc1 != asc2:
+             return []
+         asc = asc1 if asc1 is not None else asc2
+         result.append((e1, asc))
+
+     # add remaining ordering of the longer list
+     prefix_len = min(len(o1), len(o2))
+     if len(o1) > prefix_len:
+         result.extend(o1[prefix_len:])
+     elif len(o2) > prefix_len:
+         result.extend(o2[prefix_len:])
+     return result
+
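To make the prefix-merging rule concrete, here is a minimal standalone mimic over `(expr_id, asc)` pairs, where `asc=None` means the direction is not yet pinned down (plain ints stand in for expression ids; this is an illustration, not pixeltable code):

# Hypothetical mimic of _get_combined_ordering over (expr_id, asc) pairs.
def combine(o1, o2):
    result = []
    for (id1, asc1), (id2, asc2) in zip(o1, o2):
        if id1 != id2:
            return []                      # different exprs at same position: incompatible
        if asc1 is not None and asc2 is not None and asc1 != asc2:
            return []                      # contradictory directions: incompatible
        result.append((id1, asc1 if asc1 is not None else asc2))
    longer = o1 if len(o1) > len(o2) else o2
    return result + longer[len(result):]   # the longer list's tail carries over

assert combine([(1, True)], [(1, None), (2, False)]) == [(1, True), (2, False)]
assert combine([(1, True)], [(1, False)]) == []   # same expr, opposite directions
assert combine([(1, True)], [(2, True)]) == []    # different leading exprs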
+ class Analyzer:
+     """Class to perform semantic analysis of a query and to store the analysis state"""
+
+     def __init__(
+             self, tbl: catalog.TableVersionPath, select_list: List[exprs.Expr],
+             where_clause: Optional[exprs.Predicate] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
+             order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None):
+         if group_by_clause is None:
+             group_by_clause = []
+         if order_by_clause is None:
+             order_by_clause = []
+         self.tbl = tbl
+
+         # remove references to unstored computed cols
+         self.select_list = [e.resolve_computed_cols() for e in select_list]
+         if where_clause is not None:
+             where_clause = where_clause.resolve_computed_cols()
+         self.group_by_clause = [e.resolve_computed_cols() for e in group_by_clause]
+         self.order_by_clause = [(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]
+
+         # Where clause of the Select stmt of the SQL scan
+         self.sql_where_clause: Optional[exprs.Expr] = None
+         # filter predicate applied to output rows of the SQL scan
+         self.filter: Optional[exprs.Predicate] = None
+         # not executable
+         #self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
+         if where_clause is not None:
+             where_clause_conjuncts, self.filter = where_clause.split_conjuncts(lambda e: e.sql_expr() is not None)
+             self.sql_where_clause = exprs.CompoundPredicate.make_conjunction(where_clause_conjuncts)
+
+         # all exprs that are evaluated in Python; not executable
+         self.all_exprs = self.select_list.copy()
+         self.all_exprs.extend(self.group_by_clause)
+         self.all_exprs.extend([e for e, _ in self.order_by_clause])
+         if self.filter is not None:
+             self.all_exprs.append(self.filter)
+
+         # sql_exprs: exprs that can be expressed via SQL and are retrieved directly from the store
+         # (we don't want to materialize literals via SQL, so we remove them here)
+         self.sql_exprs = list(exprs.Expr.list_subexprs(
+             self.all_exprs, filter=lambda e: e.sql_expr() is not None, traverse_matches=False))
+         self.sql_exprs = [e for e in self.sql_exprs if not isinstance(e, exprs.Literal)]
+
+         self.agg_fn_calls: List[exprs.FunctionCall] = []
+         self.agg_order_by: List[exprs.Expr] = []
+         self._analyze_agg()
+
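The `split_conjuncts` call above is what decides which parts of the WHERE clause run inside the database and which run as a Python-side filter over the scan output. A standalone analogue of that partitioning (toy classes and return shape; not the pixeltable API, whose `split_conjuncts` is a method on `Predicate`):

from typing import Callable, List, Optional, Tuple

# Toy predicate: a name plus a flag saying whether it has a SQL translation.
class Pred:
    def __init__(self, name: str, has_sql: bool):
        self.name, self.has_sql = name, has_sql

def split_conjuncts(
        conjuncts: List[Pred], is_sql: Callable[[Pred], bool]
) -> Tuple[List[Pred], Optional[List[Pred]]]:
    """Partition into (SQL-pushable conjuncts, residual Python filter)."""
    pushed = [p for p in conjuncts if is_sql(p)]
    rest = [p for p in conjuncts if not is_sql(p)] or None
    return pushed, rest

pushed, rest = split_conjuncts(
    [Pred('col > 0', True), Pred('udf(col)', False)], lambda p: p.has_sql)
assert [p.name for p in pushed] == ['col > 0']   # executed by the database
assert [p.name for p in rest] == ['udf(col)']    # evaluated over scan output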
+     def _analyze_agg(self) -> None:
+         """Check semantic correctness of aggregation and fill in agg-specific fields of Analyzer"""
+         self.agg_fn_calls = [e for e in self.all_exprs if _is_agg_fn_call(e)]
+         if len(self.agg_fn_calls) == 0:
+             # nothing to do
+             return
+
+         # check that the select list only contains aggregate output
+         grouping_expr_ids = {e.id for e in self.group_by_clause}
+         is_agg_output = [self._determine_agg_status(e, grouping_expr_ids)[0] for e in self.select_list]
+         if is_agg_output.count(False) > 0:
+             raise excs.Error(
+                 f'Invalid non-aggregate expression in aggregate query: {self.select_list[is_agg_output.index(False)]}')
+
+         # check that the filter doesn't contain aggregates
+         if self.filter is not None:
+             agg_fn_calls = [e for e in self.filter.subexprs(filter=lambda e: _is_agg_fn_call(e))]
+             if len(agg_fn_calls) > 0:
+                 raise excs.Error(f'Filter cannot contain aggregate functions: {self.filter}')
+
+         # check that grouping exprs don't contain aggregates and can be expressed in SQL (we perform sort-based
+         # aggregation and rely on the SqlScanNode returning data in the correct order)
+         for e in self.group_by_clause:
+             if e.sql_expr() is None:
+                 raise excs.Error(f'Invalid grouping expression, needs to be expressible in SQL: {e}')
+             if e.contains(filter=lambda e: _is_agg_fn_call(e)):
+                 raise excs.Error(f'Grouping expression contains aggregate function: {e}')
+
+         # check that agg fn calls don't have contradictory ordering requirements
+         order_by: List[exprs.Expr] = []
+         order_by_origin: Optional[exprs.Expr] = None  # the expr that determines the ordering
+         for agg_fn_call in self.agg_fn_calls:
+             fn_call_order_by = agg_fn_call.get_agg_order_by()
+             if len(fn_call_order_by) == 0:
+                 continue
+             if len(order_by) == 0:
+                 order_by = fn_call_order_by
+                 order_by_origin = agg_fn_call
+             else:
+                 combined = _get_combined_ordering(
+                     [(e, True) for e in order_by], [(e, True) for e in fn_call_order_by])
+                 if len(combined) == 0:
+                     raise excs.Error((
+                         f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                         f"'{agg_fn_call}':\n"
+                         f"{exprs.Expr.print_list(order_by)} vs {exprs.Expr.print_list(fn_call_order_by)}"
+                     ))
+         self.agg_order_by = order_by
+
+     def _determine_agg_status(self, e: exprs.Expr, grouping_expr_ids: Set[int]) -> Tuple[bool, bool]:
+         """Determine whether expr is the input to or output of an aggregate function.
+         Returns:
+             (<is output>, <is input>)
+         """
+         if e.id in grouping_expr_ids:
+             return True, True
+         elif _is_agg_fn_call(e):
+             for c in e.components:
+                 _, is_input = self._determine_agg_status(c, grouping_expr_ids)
+                 if not is_input:
+                     raise excs.Error(f'Invalid nested aggregates: {e}')
+             return True, False
+         elif isinstance(e, exprs.Literal):
+             return True, True
+         elif isinstance(e, (exprs.ColumnRef, exprs.RowidRef)):
+             # we already know that this isn't a grouping expr
+             return False, True
+         else:
+             # an expression such as <grouping expr 1> + <grouping expr 2> can be both the output and input of agg
+             assert len(e.components) > 0
+             component_is_output, component_is_input = zip(
+                 *[self._determine_agg_status(c, grouping_expr_ids) for c in e.components])
+             is_output = component_is_output.count(True) == len(e.components)
+             is_input = component_is_input.count(True) == len(e.components)
+             if not is_output and not is_input:
+                 raise excs.Error(f'Invalid expression, mixes aggregate with non-aggregate: {e}')
+             return is_output, is_input
+
+
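To see the `(is output, is input)` recursion in action, here is a toy rendition over a simplified expression tree; the `Node` class and its kinds are illustrative stand-ins, not pixeltable's expression classes (in particular, the nested-aggregate check is omitted):

from typing import Set, Tuple

class Node:
    def __init__(self, kind: str, *children: 'Node', name: str = ''):
        self.kind, self.children, self.name = kind, children, name

def agg_status(e: Node, grouping: Set[str]) -> Tuple[bool, bool]:
    """Returns (is_agg_output, is_agg_input) for a toy tree."""
    if e.kind == 'col' and e.name in grouping:
        return True, True                  # grouping exprs are valid on both sides
    if e.kind == 'agg':
        return True, False                 # agg output, never agg input
    if e.kind == 'col':
        return False, True                 # non-grouping column: input only
    stats = [agg_status(c, grouping) for c in e.children]
    return all(o for o, _ in stats), all(i for _, i in stats)

g, x = Node('col', name='g'), Node('col', name='x')
sum_x = Node('agg', x)
# g + sum(x): pure aggregate output (g is a grouping expr, sum(x) is agg output)
assert agg_status(Node('+', g, sum_x), {'g'}) == (True, False)
# x + sum(x): neither pure output nor pure input -> would be rejected
assert agg_status(Node('+', x, sum_x), {'g'}) == (False, False)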
+     def finalize(self, row_builder: exprs.RowBuilder) -> None:
+         """Make all exprs executable
+         TODO: add EvalCtx for each expr list?
+         """
+         # maintain original composition of select list
+         row_builder.substitute_exprs(self.select_list, remove_duplicates=False)
+         row_builder.substitute_exprs(self.group_by_clause)
+         order_by_exprs = [e for e, _ in self.order_by_clause]
+         row_builder.substitute_exprs(order_by_exprs)
+         self.order_by_clause = [(e, asc) for e, (_, asc) in zip(order_by_exprs, self.order_by_clause)]
+         row_builder.substitute_exprs(self.all_exprs)
+         row_builder.substitute_exprs(self.sql_exprs)
+         if self.filter is not None:
+             self.filter = row_builder.unique_exprs[self.filter]
+         row_builder.substitute_exprs(self.agg_fn_calls)
+         row_builder.substitute_exprs(self.agg_order_by)
+
+
+ class Planner:
+     # TODO: create an exec.CountNode and change this to create_count_plan()
+     @classmethod
+     def create_count_stmt(
+             cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Predicate] = None
+     ) -> sql.Select:
+         stmt = sql.select(sql.func.count('*'))
+         refd_tbl_ids: Set[UUID] = set()
+         if where_clause is not None:
+             analyzer = cls.analyze(tbl, where_clause)
+             if analyzer.filter is not None:
+                 raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
+             clause_element = analyzer.sql_where_clause.sql_expr()
+             assert clause_element is not None
+             stmt = stmt.where(clause_element)
+             refd_tbl_ids = where_clause.tbl_ids()
+         stmt = exec.SqlScanNode.create_from_clause(tbl, stmt, refd_tbl_ids)
+         return stmt
+
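For readers less familiar with the SQLAlchemy calls used here, a minimal self-contained analogue of the statement being built (table and column names are invented; `select_from` stands in for what `SqlScanNode.create_from_clause` contributes):

import sqlalchemy as sql

metadata = sql.MetaData()
# hypothetical store table, standing in for a pixeltable-managed table
t = sql.Table('tbl', metadata, sql.Column('c1', sql.Integer))

stmt = sql.select(sql.func.count('*'))   # SELECT count('*')
stmt = stmt.where(t.c.c1 > 0)            # push the SQL-expressible predicate
stmt = stmt.select_from(t)               # attach the FROM clause
print(stmt)  # roughly: SELECT count(:count_1) AS count_1 FROM tbl WHERE tbl.c1 > :c1_1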
+     @classmethod
+     def create_insert_plan(
+             cls, tbl: catalog.TableVersion, rows: List[Dict[str, Any]], ignore_errors: bool
+     ) -> exec.ExecNode:
+         """Creates a plan for TableVersion.insert()"""
+         assert not tbl.is_view()
+         # stored_cols: all cols we need to store, incl computed cols (and indices)
+         stored_cols = [c for c in tbl.cols if c.is_stored]
+         assert len(stored_cols) > 0
+
+         row_builder = exprs.RowBuilder([], stored_cols, [])
+
+         # create InMemoryDataNode for 'rows'
+         stored_col_info = row_builder.output_slot_idxs()
+         stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
+         input_col_info = [info for info in stored_col_info if not info.col.is_computed]
+         plan = exec.InMemoryDataNode(tbl, rows, row_builder, tbl.next_rowid)
+
+         media_input_cols = [info for info in input_col_info if info.col.col_type.is_media_type()]
+
+         # prefetch external files for all input column refs, for validation
+         plan = exec.CachePrefetchNode(tbl.id, media_input_cols, plan)
+         plan = exec.MediaValidationNode(row_builder, media_input_cols, input=plan)
+
+         computed_exprs = row_builder.default_eval_ctx.target_exprs
+         if len(computed_exprs) > 0:
+             # add an ExprEvalNode when there are exprs to compute
+             plan = exec.ExprEvalNode(row_builder, computed_exprs, [], input=plan)
+
+         plan.set_stored_img_cols(stored_img_col_info)
+         plan.set_ctx(
+             exec.ExecContext(
+                 row_builder, batch_size=0, show_pbar=True, num_computed_exprs=len(computed_exprs),
+                 ignore_errors=ignore_errors))
+         return plan
+
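The insert plan is assembled by wrapping each node around its input. A stripped-down sketch of that pull-based composition pattern (hypothetical node classes, not pixeltable's `ExecNode` API, which operates on row batches rather than single dicts):

from typing import Dict, Iterator, List, Optional

class Node:
    """Pull-based pipeline node: iterates over the rows of its input node."""
    def __init__(self, input: Optional['Node'] = None):
        self.input = input
    def __iter__(self) -> Iterator[Dict]:
        raise NotImplementedError

class InMemorySource(Node):
    """Stands in for InMemoryDataNode: yields the caller-provided rows."""
    def __init__(self, rows: List[Dict]):
        super().__init__()
        self.rows = rows
    def __iter__(self):
        yield from self.rows

class Validate(Node):
    """Stands in for MediaValidationNode: checks inputs, passes rows through."""
    def __iter__(self):
        for row in self.input:
            assert 'c1' in row
            yield row

class Compute(Node):
    """Stands in for ExprEvalNode: fills in a computed column."""
    def __iter__(self):
        for row in self.input:
            row['c2'] = row['c1'] * 2
            yield row

plan = Compute(Validate(InMemorySource([{'c1': 1}, {'c1': 2}])))
assert list(plan) == [{'c1': 1, 'c2': 2}, {'c1': 2, 'c2': 4}]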
+     @classmethod
+     def create_update_plan(
+             cls, tbl: catalog.TableVersionPath,
+             update_targets: Dict[catalog.Column, exprs.Expr],
+             recompute_targets: List[catalog.Column],
+             where_clause: Optional[exprs.Predicate], cascade: bool
+     ) -> Tuple[exec.ExecNode, List[str], List[catalog.Column]]:
+         """Creates a plan to materialize updated rows.
+         The plan:
+         - retrieves rows that are visible at the current version of the table
+         - materializes all stored columns and the update targets
+         - if cascade is True, recomputes all computed columns that transitively depend on the updated columns
+           and copies the values of all other stored columns
+         - if cascade is False, copies all columns that aren't update targets from the original rows
+         Returns:
+         - root node of the plan
+         - list of qualified column names that are getting updated
+         - list of columns that are being recomputed
+         """
+         # retrieve all stored cols and all target exprs
+         assert isinstance(tbl, catalog.TableVersionPath)
+         target = tbl.tbl_version  # the one we need to update
+         updated_cols = list(update_targets.keys())
+         if len(recompute_targets) > 0:
+             recomputed_cols = set(recompute_targets)
+         else:
+             recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
+         # we only need to recompute stored columns (unstored ones are substituted away)
+         recomputed_cols = {c for c in recomputed_cols if c.is_stored}
+         recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
+         copied_cols = [
+             col for col in target.cols
+             if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
+         ]
+         select_list = [exprs.ColumnRef(col) for col in copied_cols]
+         select_list.extend(update_targets.values())
+
+         recomputed_exprs = \
+             [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols]
+         # recomputed cols reference the new values of the updated cols
+         for col, e in update_targets.items():
+             exprs.Expr.list_substitute(recomputed_exprs, exprs.ColumnRef(col), e)
+         select_list.extend(recomputed_exprs)
+
+         # we need to retrieve the PK columns of the existing rows
+         plan = cls.create_query_plan(tbl, select_list, where_clause=where_clause, with_pk=True, ignore_errors=True)
+         all_base_cols = copied_cols + updated_cols + list(recomputed_base_cols)  # same order as select_list
+         # update the row builder with column information
+         for i, col in enumerate(all_base_cols):
+             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
+         return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + list(recomputed_cols)], list(recomputed_cols)
+
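The `list_substitute` step is what makes recomputed columns observe the new values of the updated columns. A tiny standalone analogue, with tuples standing in for expression trees (illustration only; not the pixeltable `Expr` API):

# Toy substitution: replace references to column 'a' with the update expression.
def substitute(expr, old, new):
    if expr == old:
        return new
    if isinstance(expr, tuple):            # ('op', operand, operand)
        op, *args = expr
        return (op, *(substitute(a, old, new) for a in args))
    return expr

# computed column c = a + 1; the update sets a = a * 2
recomputed = substitute(('add', 'a', 1), 'a', ('mul', 'a', 2))
assert recomputed == ('add', ('mul', 'a', 2), 1)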
+     @classmethod
+     def create_view_update_plan(
+             cls, view: catalog.TableVersionPath, recompute_targets: List[catalog.Column]
+     ) -> exec.ExecNode:
+         """Creates a plan to materialize updated rows for a view, given that the base table has been updated.
+         The plan:
+         - retrieves rows that are visible at the current version of the base and satisfy the view predicate
+         - copies all stored columns that aren't being recomputed
+         - recomputes the columns in recompute_targets
+
+         TODO: unify with create_view_load_plan()
+
+         Returns:
+             root node of the plan
+         """
+         assert isinstance(view, catalog.TableVersionPath)
+         assert view.is_view()
+         target = view.tbl_version  # the one we need to update
+         # retrieve all stored cols and all target exprs
+         recomputed_cols = set(recompute_targets)
+         copied_cols = [col for col in target.cols if col.is_stored and col not in recomputed_cols]
+         select_list = [exprs.ColumnRef(col) for col in copied_cols]
+         # resolve recomputed exprs to stored columns in the base
+         recomputed_exprs = \
+             [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_cols) for c in recomputed_cols]
+         select_list.extend(recomputed_exprs)
+
+         # we need to retrieve the PK columns of the existing rows
+         plan = cls.create_query_plan(
+             view, select_list, where_clause=target.predicate, with_pk=True, ignore_errors=True,
+             exact_version_only=view.get_bases())
+         for i, col in enumerate(copied_cols + list(recomputed_cols)):  # same order as select_list
+             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
+         # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
+         stored_img_col_info = \
+             [info for info in plan.row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
+         plan.set_stored_img_cols(stored_img_col_info)
+         return plan
+
+     @classmethod
+     def create_view_load_plan(
+             cls, view: catalog.TableVersionPath, propagates_insert: bool = False
+     ) -> Tuple[exec.ExecNode, int]:
+         """Creates a query plan for populating a view.
+
+         Args:
+             view: the view to populate
+             propagates_insert: if True, we're propagating a base update to this view
+
+         Returns:
+             - root node of the plan
+             - number of materialized values per row
+         """
+         assert isinstance(view, catalog.TableVersionPath)
+         assert view.is_view()
+         # things we need to materialize as DataRows:
+         # 1. stored computed cols
+         #    - iterator columns are effectively computed, just not with a value_expr
+         #    - we can ignore stored non-computed columns because they have a default value that is supplied
+         #      directly by the store
+         target = view.tbl_version  # the one we need to populate
+         stored_cols = [c for c in target.cols if c.is_stored and (c.is_computed or target.is_iterator_column(c))]
+         # 2. for component views: iterator args
+         iterator_args = [target.iterator_args] if target.iterator_args is not None else []
+
+         row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
+
+         # execution plan:
+         # 1. materialize exprs computed from the base that are needed for stored view columns
+         # 2. if it's an iterator view, expand the base rows into component rows
+         # 3. materialize stored view columns that haven't been produced by step 1
+         base_output_exprs = [e for e in row_builder.default_eval_ctx.exprs if e.is_bound_by(view.base)]
+         view_output_exprs = [
+             e for e in row_builder.default_eval_ctx.target_exprs
+             if e.is_bound_by(view) and not e.is_bound_by(view.base)
+         ]
+         # if we're propagating an insert, we only want to see those base rows that were created for the
+         # current version
+         base_analyzer = Analyzer(view, base_output_exprs, where_clause=target.predicate)
+         plan = cls._create_query_plan(
+             view.base, row_builder=row_builder, analyzer=base_analyzer, with_pk=True,
+             exact_version_only=view.get_bases() if propagates_insert else [])
+         exec_ctx = plan.ctx
+         if target.is_component_view():
+             plan = exec.ComponentIterationNode(target, plan)
+         if len(view_output_exprs) > 0:
+             plan = exec.ExprEvalNode(
+                 row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan)
+
+         stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
+         plan.set_stored_img_cols(stored_img_col_info)
+         exec_ctx.ignore_errors = True
+         plan.set_ctx(exec_ctx)
+         return plan, len(row_builder.default_eval_ctx.target_exprs)
+
+     @classmethod
+     def _determine_ordering(cls, analyzer: Analyzer) -> List[Tuple[exprs.Expr, bool]]:
+         """Returns the exprs for the ORDER BY clause of the SqlScanNode"""
+         order_by_items: List[Tuple[exprs.Expr, Optional[bool]]] = []
+         order_by_origin: Optional[exprs.Expr] = None  # the expr that determines the ordering
+
+         # window functions require ordering by the group_by/order_by clauses
+         window_fn_calls = [
+             e for e in analyzer.all_exprs if isinstance(e, exprs.FunctionCall) and e.is_window_fn_call
+         ]
+         if len(window_fn_calls) > 0:
+             for fn_call in window_fn_calls:
+                 gb, ob = fn_call.get_window_sort_exprs()
+                 # for now, the ordering is implicitly ascending
+                 fn_call_ordering = [(e, None) for e in gb] + [(e, True) for e in ob]
+                 if len(order_by_items) == 0:
+                     order_by_items = fn_call_ordering
+                     order_by_origin = fn_call
+                 else:
+                     # check for compatibility
+                     other_order_by_clauses = fn_call_ordering
+                     combined = _get_combined_ordering(order_by_items, other_order_by_clauses)
+                     if len(combined) == 0:
+                         raise excs.Error((
+                             f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                             f"'{fn_call}':\n"
+                             f"{exprs.Expr.print_list(order_by_items)} vs "
+                             f"{exprs.Expr.print_list(other_order_by_clauses)}"
+                         ))
+                     order_by_items = combined
+
+         if len(analyzer.group_by_clause) > 0:
+             agg_ordering = [(e, None) for e in analyzer.group_by_clause] + [(e, True) for e in analyzer.agg_order_by]
+             if len(order_by_items) > 0:
+                 # check for compatibility
+                 combined = _get_combined_ordering(order_by_items, agg_ordering)
+                 if len(combined) == 0:
+                     raise excs.Error((
+                         f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                         f"grouping expressions:\n"
+                         f"{exprs.Expr.print_list([e for e, _ in order_by_items])} vs "
+                         f"{exprs.Expr.print_list([e for e, _ in agg_ordering])}"
+                     ))
+                 order_by_items = combined
+             else:
+                 order_by_items = agg_ordering
+
+         if len(analyzer.order_by_clause) > 0:
+             if len(order_by_items) > 0:
+                 # check for compatibility
+                 combined = _get_combined_ordering(order_by_items, analyzer.order_by_clause)
+                 if len(combined) == 0:
+                     raise excs.Error((
+                         f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                         f"order-by expressions:\n"
+                         f"{exprs.Expr.print_list([e for e, _ in order_by_items])} vs "
+                         f"{exprs.Expr.print_list([e for e, _ in analyzer.order_by_clause])}"
+                     ))
+                 order_by_items = combined
+             else:
+                 order_by_items = analyzer.order_by_clause
+
+         # TODO: can this be unified with the same logic in RowBuilder
+         def refs_unstored_iter_col(e: exprs.Expr) -> bool:
+             if not isinstance(e, exprs.ColumnRef):
+                 return False
+             tbl = e.col.tbl
+             return tbl.is_component_view() and tbl.is_iterator_column(e.col) and not e.col.is_stored
+         unstored_iter_col_refs = list(exprs.Expr.list_subexprs(analyzer.all_exprs, filter=refs_unstored_iter_col))
+         if len(unstored_iter_col_refs) > 0 and len(order_by_items) == 0:
+             # we don't already have a user-requested ordering and we access unstored iterator columns:
+             # order by the primary key of the component view, which minimizes the number of iterator instantiations
+             component_views = {e.col.tbl for e in unstored_iter_col_refs}
+             # TODO: generalize this to multi-level iteration
+             assert len(component_views) == 1
+             component_view = list(component_views)[0]
+             order_by_items = [
+                 (exprs.RowidRef(component_view, idx), None)
+                 for idx in range(len(component_view.store_tbl.rowid_columns()))
+             ]
+             order_by_origin = unstored_iter_col_refs[0]
+
+         for e, _ in order_by_items:
+             if e.sql_expr() is None:
+                 raise excs.Error(f'order_by element cannot be expressed in SQL: {e}')
+         # we do ascending ordering by default, if not specified otherwise
+         order_by_items = [(e, True) if asc is None else (e, asc) for e, asc in order_by_items]
+         return order_by_items
+
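A concrete illustration of the ordering tuples this method builds, with plain strings standing in for exprs (this mirrors the logic above; it is not a pixeltable API):

# sum(x) OVER (PARTITION BY g ORDER BY t) requires scan output ordered by (g, t):
gb, ob = ['g'], ['t']
fn_call_ordering = [(e, None) for e in gb] + [(e, True) for e in ob]
assert fn_call_ordering == [('g', None), ('t', True)]

# unpinned directions (None) default to ascending at the end:
final = [(e, True) if asc is None else (e, asc) for e, asc in fn_call_ordering]
assert final == [('g', True), ('t', True)]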
+     @classmethod
+     def _is_contained_in(cls, l1: List[exprs.Expr], l2: List[exprs.Expr]) -> bool:
+         """Returns True if l1 is contained in l2"""
+         s1, s2 = {e.id for e in l1}, {e.id for e in l2}
+         return s1 <= s2
+
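Containment is decided purely on expression ids; with plain ints standing in for those ids:

# if every needed expr is already produced upstream, no ExprEvalNode is added
assert {1, 2} <= {1, 2, 3}      # contained: reuse upstream slots
assert not {1, 4} <= {1, 2, 3}  # not contained: evaluate the remainder in Python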
+     @classmethod
+     def _insert_prefetch_node(
+             cls, tbl_id: UUID, output_exprs: List[exprs.Expr], row_builder: exprs.RowBuilder, input: exec.ExecNode
+     ) -> exec.ExecNode:
+         """Inserts a CachePrefetchNode into the plan if needed, otherwise returns input unchanged"""
+         # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
+         # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
+         # aren't explicitly captured as dependencies
+         media_col_refs = [
+             e for e in list(row_builder.unique_exprs) if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()
+         ]
+         if len(media_col_refs) == 0:
+             return input
+         # we need to prefetch external files for media column types
+         file_col_info = [exprs.ColumnSlotIdx(e.col, e.slot_idx) for e in media_col_refs]
+         prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input)
+         return prefetch_node
+
+     @classmethod
+     def create_query_plan(
+             cls, tbl: catalog.TableVersionPath, select_list: Optional[List[exprs.Expr]] = None,
+             where_clause: Optional[exprs.Predicate] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
+             order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, limit: Optional[int] = None,
+             with_pk: bool = False, ignore_errors: bool = False,
+             exact_version_only: Optional[List[catalog.TableVersion]] = None
+     ) -> exec.ExecNode:
+         """Returns a plan for executing a query.
+         Updates 'select_list' in place to make it executable.
+         TODO: make exact_version_only a flag and use the versions from tbl
+         """
+         if select_list is None:
+             select_list = []
+         if group_by_clause is None:
+             group_by_clause = []
+         if order_by_clause is None:
+             order_by_clause = []
+         if exact_version_only is None:
+             exact_version_only = []
+         analyzer = Analyzer(
+             tbl, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
+             order_by_clause=order_by_clause)
+         row_builder = exprs.RowBuilder(analyzer.all_exprs, [], analyzer.sql_exprs)
+
+         analyzer.finalize(row_builder)
+         # select_list: we need to materialize everything that's been collected
+         # with_pk: for now, we always retrieve the PK, because we need it for the file cache
+         plan = cls._create_query_plan(
+             tbl, row_builder, analyzer=analyzer, limit=limit, with_pk=True, exact_version_only=exact_version_only)
+         plan.ctx.ignore_errors = ignore_errors
+         select_list.clear()
+         select_list.extend(analyzer.select_list)
+         return plan
+
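A hypothetical call site (names assumed, not taken from pixeltable's DataFrame code) showing the in-place contract:

# 'tbl' is a catalog.TableVersionPath; 'col_ref' is an exprs.Expr built by the caller
select_list = [col_ref]
plan = Planner.create_query_plan(tbl, select_list, limit=10)
# select_list now holds the substituted, executable exprs; the caller must read
# results through these (e.g. via their slot_idx), not through the original objects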
+     @classmethod
+     def _create_query_plan(
+             cls, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder, analyzer: Analyzer,
+             limit: Optional[int] = None, with_pk: bool = False,
+             exact_version_only: Optional[List[catalog.TableVersion]] = None
+     ) -> exec.ExecNode:
+         """Returns the root node of a query plan over 'tbl'.
+
+         Args:
+             tbl: the table (or view chain) to scan
+             row_builder: assigns slots for all exprs materialized by the plan
+             analyzer: finalized analysis state for the query
+             limit: if not None, pushed into the SQL scan for non-aggregate queries
+             with_pk: if True, the SQL scan also retrieves the rows' primary keys
+         TODO: make exact_version_only a flag and use the versions from tbl
+         """
+         if exact_version_only is None:
+             exact_version_only = []
+         assert isinstance(tbl, catalog.TableVersionPath)
+         is_agg_query = len(analyzer.group_by_clause) > 0 or len(analyzer.agg_fn_calls) > 0
+         ctx = exec.ExecContext(row_builder)
+
+         order_by_items = cls._determine_ordering(analyzer)
+         sql_limit = 0 if is_agg_query else limit  # if we're aggregating, the limit applies to the agg output
+         sql_select_list = analyzer.sql_exprs.copy()
+         plan = exec.SqlScanNode(
+             tbl, row_builder, select_list=sql_select_list, where_clause=analyzer.sql_where_clause,
+             filter=analyzer.filter, order_by_items=order_by_items,
+             limit=sql_limit, set_pk=with_pk, exact_version_only=exact_version_only)
+         plan = cls._insert_prefetch_node(tbl.tbl_version.id, analyzer.select_list, row_builder, plan)
+
+         if is_agg_query:
+             # we're doing aggregation; the input of the AggregationNode is the grouping exprs plus the
+             # args of the agg fn calls
+             agg_input = exprs.ExprSet(analyzer.group_by_clause.copy())
+             for fn_call in analyzer.agg_fn_calls:
+                 agg_input.extend(fn_call.components)
+             if not cls._is_contained_in(agg_input, analyzer.sql_exprs):
+                 # we need an ExprEvalNode
+                 plan = exec.ExprEvalNode(row_builder, agg_input, analyzer.sql_exprs, input=plan)
+
+             # batch size for aggregation input: this could be the entire table, so we need to divide it into
+             # smaller batches; at the same time, we need to make the batches large enough to amortize the
+             # function call overhead
+             # TODO: increase this if we have NOS calls in order to reduce the cost of switching models, but take
+             # into account the amount of memory needed for intermediate images
+             ctx.batch_size = 16
+
+             plan = exec.AggregationNode(
+                 tbl.tbl_version, row_builder, analyzer.group_by_clause, analyzer.agg_fn_calls, agg_input, input=plan)
+             agg_output = analyzer.group_by_clause + analyzer.agg_fn_calls
+             if not cls._is_contained_in(analyzer.select_list, agg_output):
+                 # we need an ExprEvalNode to evaluate the remaining output exprs
+                 plan = exec.ExprEvalNode(
+                     row_builder, analyzer.select_list, agg_output, input=plan)
+         else:
+             if not cls._is_contained_in(analyzer.select_list, analyzer.sql_exprs):
+                 # we need an ExprEvalNode to evaluate the remaining output exprs
+                 plan = exec.ExprEvalNode(row_builder, analyzer.select_list, analyzer.sql_exprs, input=plan)
+             # we're returning everything to the user, so we might as well do it in a single batch
+             ctx.batch_size = 0
+
+         plan.set_ctx(ctx)
+         return plan
+
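Putting the two branches together, the plans this method produces have roughly these shapes (informal sketch; bracketed nodes are added only when needed):

# non-aggregate query (single output batch, ctx.batch_size = 0):
#   SqlScanNode -> [CachePrefetchNode] -> [ExprEvalNode]
# aggregate query (input batched at 16 rows):
#   SqlScanNode -> [CachePrefetchNode] -> [ExprEvalNode] -> AggregationNode -> [ExprEvalNode]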
+     @classmethod
+     def analyze(cls, tbl: catalog.TableVersionPath, where_clause: exprs.Predicate) -> Analyzer:
+         return Analyzer(tbl, [], where_clause=where_clause)
+
+     @classmethod
+     def create_add_column_plan(
+             cls, tbl: catalog.TableVersionPath, col: catalog.Column
+     ) -> Tuple[exec.ExecNode, Optional[int]]:
+         """Creates a plan for InsertableTable.add_column()
+         Returns:
+             plan: the plan to execute
+             value_expr slot idx for the plan output (for computed cols)
+         """
+         assert isinstance(tbl, catalog.TableVersionPath)
+         index_info: List[Tuple[catalog.Column, func.Function]] = []
+         row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
+         analyzer = Analyzer(tbl, row_builder.default_eval_ctx.target_exprs)
+         plan = cls._create_query_plan(tbl, row_builder=row_builder, analyzer=analyzer, with_pk=True)
+         plan.ctx.batch_size = 16
+         plan.ctx.show_pbar = True
+         plan.ctx.ignore_errors = True
+
+         # we want to flush images
+         if col.is_computed and col.is_stored and col.col_type.is_image_type():
+             plan.set_stored_img_cols(row_builder.output_slot_idxs())
+         value_expr_slot_idx = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
+         return plan, value_expr_slot_idx