pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. The information is provided for informational purposes only.

Potentially problematic release.

Files changed (147)
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +590 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +359 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +195 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +34 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +256 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +122 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +418 -182
  88. pixeltable/tests/conftest.py +146 -88
  89. pixeltable/tests/functions/test_fireworks.py +42 -0
  90. pixeltable/tests/functions/test_functions.py +60 -0
  91. pixeltable/tests/functions/test_huggingface.py +158 -0
  92. pixeltable/tests/functions/test_openai.py +152 -0
  93. pixeltable/tests/functions/test_together.py +111 -0
  94. pixeltable/tests/test_audio.py +65 -0
  95. pixeltable/tests/test_catalog.py +27 -0
  96. pixeltable/tests/test_client.py +14 -14
  97. pixeltable/tests/test_component_view.py +370 -0
  98. pixeltable/tests/test_dataframe.py +439 -0
  99. pixeltable/tests/test_dirs.py +78 -62
  100. pixeltable/tests/test_document.py +120 -0
  101. pixeltable/tests/test_exprs.py +592 -135
  102. pixeltable/tests/test_function.py +297 -67
  103. pixeltable/tests/test_migration.py +43 -0
  104. pixeltable/tests/test_nos.py +54 -0
  105. pixeltable/tests/test_snapshot.py +208 -0
  106. pixeltable/tests/test_table.py +1195 -263
  107. pixeltable/tests/test_transactional_directory.py +42 -0
  108. pixeltable/tests/test_types.py +5 -11
  109. pixeltable/tests/test_video.py +151 -34
  110. pixeltable/tests/test_view.py +530 -0
  111. pixeltable/tests/utils.py +320 -45
  112. pixeltable/tool/create_test_db_dump.py +149 -0
  113. pixeltable/tool/create_test_video.py +81 -0
  114. pixeltable/type_system.py +445 -124
  115. pixeltable/utils/__init__.py +17 -46
  116. pixeltable/utils/arrow.py +98 -0
  117. pixeltable/utils/clip.py +12 -15
  118. pixeltable/utils/coco.py +136 -0
  119. pixeltable/utils/documents.py +39 -0
  120. pixeltable/utils/filecache.py +195 -0
  121. pixeltable/utils/help.py +11 -0
  122. pixeltable/utils/hf_datasets.py +157 -0
  123. pixeltable/utils/media_store.py +76 -0
  124. pixeltable/utils/parquet.py +167 -0
  125. pixeltable/utils/pytorch.py +91 -0
  126. pixeltable/utils/s3.py +13 -0
  127. pixeltable/utils/sql.py +17 -0
  128. pixeltable/utils/transactional_directory.py +35 -0
  129. pixeltable-0.2.4.dist-info/LICENSE +18 -0
  130. pixeltable-0.2.4.dist-info/METADATA +127 -0
  131. pixeltable-0.2.4.dist-info/RECORD +132 -0
  132. {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
  133. pixeltable/catalog.py +0 -1421
  134. pixeltable/exprs.py +0 -1745
  135. pixeltable/function.py +0 -269
  136. pixeltable/functions/clip.py +0 -10
  137. pixeltable/functions/pil/__init__.py +0 -23
  138. pixeltable/functions/tf.py +0 -21
  139. pixeltable/index.py +0 -57
  140. pixeltable/tests/test_dict.py +0 -24
  141. pixeltable/tests/test_functions.py +0 -11
  142. pixeltable/tests/test_tf.py +0 -69
  143. pixeltable/tf.py +0 -33
  144. pixeltable/utils/tf.py +0 -33
  145. pixeltable/utils/video.py +0 -32
  146. pixeltable-0.1.0.dist-info/METADATA +0 -34
  147. pixeltable-0.1.0.dist-info/RECORD +0 -36
pixeltable/plan.py ADDED
@@ -0,0 +1,656 @@
+ from typing import Tuple, Optional, List, Set, Any, Dict
+ from uuid import UUID
+
+ import sqlalchemy as sql
+
+ import pixeltable.exec as exec
+ import pixeltable.func as func
+ from pixeltable import catalog
+ from pixeltable import exceptions as excs
+ from pixeltable import exprs
+
+
+ def _is_agg_fn_call(e: exprs.Expr) -> bool:
+     return isinstance(e, exprs.FunctionCall) and e.is_agg_fn_call and not e.is_window_fn_call
+
+ def _get_combined_ordering(
+         o1: List[Tuple[exprs.Expr, bool]], o2: List[Tuple[exprs.Expr, bool]]
+ ) -> List[Tuple[exprs.Expr, bool]]:
+     """Returns an ordering that's compatible with both o1 and o2, or an empty list if no such ordering exists"""
+     result: List[Tuple[exprs.Expr, bool]] = []
+     # determine combined ordering
+     for (e1, asc1), (e2, asc2) in zip(o1, o2):
+         if e1.id != e2.id:
+             return []
+         if asc1 is not None and asc2 is not None and asc1 != asc2:
+             return []
+         asc = asc1 if asc1 is not None else asc2
+         result.append((e1, asc))
+
+     # add remaining ordering of the longer list
+     prefix_len = min(len(o1), len(o2))
+     if len(o1) > prefix_len:
+         result.extend(o1[prefix_len:])
+     elif len(o2) > prefix_len:
+         result.extend(o2[prefix_len:])
+     return result
+
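To see what _get_combined_ordering computes, here is a self-contained sketch that mirrors the logic with plain (name, asc) tuples instead of exprs.Expr objects (the tuple-based stand-ins are illustrative only, not part of the module):

    from typing import List, Optional, Tuple

    def combine(o1: List[Tuple[str, Optional[bool]]],
                o2: List[Tuple[str, Optional[bool]]]) -> List[Tuple[str, Optional[bool]]]:
        # same algorithm, keyed on names instead of expr ids
        result: List[Tuple[str, Optional[bool]]] = []
        for (e1, asc1), (e2, asc2) in zip(o1, o2):
            if e1 != e2:
                return []  # the orderings diverge on a shared prefix element
            if asc1 is not None and asc2 is not None and asc1 != asc2:
                return []  # same element, contradictory directions
            result.append((e1, asc1 if asc1 is not None else asc2))
        longer = o1 if len(o1) > len(o2) else o2
        result.extend(longer[len(result):])  # keep the remainder of the longer list
        return result

    assert combine([('a', True)], [('a', None), ('b', False)]) == [('a', True), ('b', False)]
    assert combine([('a', True)], [('a', False)]) == []  # contradictory directions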
+ class Analyzer:
+     """Class to perform semantic analysis of a query and to store the analysis state"""
+
+     def __init__(
+             self, tbl: catalog.TableVersionPath, select_list: List[exprs.Expr],
+             where_clause: Optional[exprs.Predicate] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
+             order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None):
+         if group_by_clause is None:
+             group_by_clause = []
+         if order_by_clause is None:
+             order_by_clause = []
+         self.tbl = tbl
+
+         # remove references to unstored computed cols
+         self.select_list = [e.resolve_computed_cols() for e in select_list]
+         if where_clause is not None:
+             where_clause = where_clause.resolve_computed_cols()
+         self.group_by_clause = [e.resolve_computed_cols() for e in group_by_clause]
+         self.order_by_clause = [(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]
+
+         # Where clause of the Select stmt of the SQL scan
+         self.sql_where_clause: Optional[exprs.Expr] = None
+         # filter predicate applied to output rows of the SQL scan
+         self.filter: Optional[exprs.Predicate] = None
+         # not executable
+         self.similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None
+         if where_clause is not None:
+             where_clause_conjuncts, self.filter = where_clause.split_conjuncts(lambda e: e.sql_expr() is not None)
+             self.sql_where_clause = exprs.CompoundPredicate.make_conjunction(where_clause_conjuncts)
+         if self.filter is not None:
+             similarity_clauses, self.filter = self.filter.split_conjuncts(
+                 lambda e: isinstance(e, exprs.ImageSimilarityPredicate))
+             if len(similarity_clauses) > 1:
+                 raise excs.Error('More than one nearest() not supported')
+             if len(similarity_clauses) == 1:
+                 if len(self.order_by_clause) > 0:
+                     raise excs.Error(
+                         'nearest() returns results in order of proximity and cannot be used in conjunction with '
+                         'order_by()')
+                 self.similarity_clause = similarity_clauses[0]
+                 img_col = self.similarity_clause.img_col_ref.col
+                 if not img_col.is_indexed:
+                     raise excs.Error(f'nearest() not available for unindexed column {img_col.name}')
+
+         # all exprs that are evaluated in Python; not executable
+         self.all_exprs = self.select_list.copy()
+         self.all_exprs.extend(self.group_by_clause)
+         self.all_exprs.extend([e for e, _ in self.order_by_clause])
+         if self.filter is not None:
+             self.all_exprs.append(self.filter)
+         self.sql_exprs = list(exprs.Expr.list_subexprs(
+             self.all_exprs, filter=lambda e: e.sql_expr() is not None, traverse_matches=False))
+
+         # sql_exprs: exprs that can be expressed via SQL and are retrieved directly from the store
+         # (we don't want to materialize literals via SQL, so we remove them here)
+         self.sql_exprs = [e for e in self.sql_exprs if not isinstance(e, exprs.Literal)]
+
+         self.agg_fn_calls: List[exprs.FunctionCall] = []
+         self.agg_order_by: List[exprs.Expr] = []
+         self._analyze_agg()
+
+     def _analyze_agg(self) -> None:
+         """Check semantic correctness of aggregation and fill in agg-specific fields of Analyzer"""
+         self.agg_fn_calls = [e for e in self.all_exprs if _is_agg_fn_call(e)]
+         if len(self.agg_fn_calls) == 0:
+             # nothing to do
+             return
+
+         # check that select list only contains aggregate output
+         grouping_expr_ids = {e.id for e in self.group_by_clause}
+         is_agg_output = [self._determine_agg_status(e, grouping_expr_ids)[0] for e in self.select_list]
+         if is_agg_output.count(False) > 0:
+             raise excs.Error(
+                 f'Invalid non-aggregate expression in aggregate query: {self.select_list[is_agg_output.index(False)]}')
+
+         # check that the filter doesn't contain aggregates
+         if self.filter is not None:
+             agg_fn_calls = [e for e in self.filter.subexprs(filter=lambda e: _is_agg_fn_call(e))]
+             if len(agg_fn_calls) > 0:
+                 raise excs.Error(f'Filter cannot contain aggregate functions: {self.filter}')
+
+         # check that grouping exprs don't contain aggregates and can be expressed in SQL (we perform sort-based
+         # aggregation and rely on the SqlScanNode returning data in the correct order)
+         for e in self.group_by_clause:
+             if e.sql_expr() is None:
+                 raise excs.Error(f'Invalid grouping expression, needs to be expressible in SQL: {e}')
+             if e.contains(filter=lambda e: _is_agg_fn_call(e)):
+                 raise excs.Error(f'Grouping expression contains aggregate function: {e}')
+
+         # check that agg fn calls don't have contradictory ordering requirements
+         order_by: List[exprs.Expr] = []
+         order_by_origin: Optional[exprs.Expr] = None  # the expr that determines the ordering
+         for agg_fn_call in self.agg_fn_calls:
+             fn_call_order_by = agg_fn_call.get_agg_order_by()
+             if len(fn_call_order_by) == 0:
+                 continue
+             if len(order_by) == 0:
+                 order_by = fn_call_order_by
+                 order_by_origin = agg_fn_call
+             else:
+                 combined = _get_combined_ordering(
+                     [(e, True) for e in order_by], [(e, True) for e in fn_call_order_by])
+                 if len(combined) == 0:
+                     raise excs.Error((
+                         f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                         f"'{agg_fn_call}':\n"
+                         f"{exprs.Expr.print_list(order_by)} vs {exprs.Expr.print_list(fn_call_order_by)}"
+                     ))
+         self.agg_order_by = order_by
+
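The comment above about sort-based aggregation is the key constraint: rows arrive at the aggregation operator already sorted on the grouping exprs, so each group is complete as soon as the key changes. A minimal standalone sketch of that evaluation strategy (toy rows, not pixeltable classes):

    from itertools import groupby

    rows = [('a', 1), ('a', 2), ('b', 5)]  # already sorted on the grouping key
    # because the input is sorted, each key appears exactly once and every group
    # can be aggregated and emitted without buffering the rest of the input
    result = [(k, sum(v for _, v in g)) for k, g in groupby(rows, key=lambda r: r[0])]
    assert result == [('a', 3), ('b', 5)]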
+     def _determine_agg_status(self, e: exprs.Expr, grouping_expr_ids: Set[int]) -> Tuple[bool, bool]:
+         """Determine whether expr is the input to or output of an aggregate function.
+         Returns:
+             (<is output>, <is input>)
+         """
+         if e.id in grouping_expr_ids:
+             return True, True
+         elif _is_agg_fn_call(e):
+             for c in e.components:
+                 _, is_input = self._determine_agg_status(c, grouping_expr_ids)
+                 if not is_input:
+                     raise excs.Error(f'Invalid nested aggregates: {e}')
+             return True, False
+         elif isinstance(e, exprs.Literal):
+             return True, True
+         elif isinstance(e, (exprs.ColumnRef, exprs.RowidRef)):
+             # we already know that this isn't a grouping expr
+             return False, True
+         else:
+             # an expression such as <grouping expr 1> + <grouping expr 2> can be both the output and input of agg
+             assert len(e.components) > 0
+             component_is_output, component_is_input = zip(
+                 *[self._determine_agg_status(c, grouping_expr_ids) for c in e.components])
+             is_output = component_is_output.count(True) == len(e.components)
+             is_input = component_is_input.count(True) == len(e.components)
+             if not is_output and not is_input:
+                 raise excs.Error(f'Invalid expression, mixes aggregate with non-aggregate: {e}')
+             return is_output, is_input
+
+
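The recursion above classifies every node of the expression tree via two flags. A compact standalone rendition on a toy tree (hypothetical Lit/Col/Agg/Add node types; the real code operates on exprs.Expr) makes the flags easier to follow:

    from dataclasses import dataclass
    from typing import Set, Tuple

    @dataclass
    class Lit:  # literal: valid as both agg output and agg input
        v: int

    @dataclass
    class Col:  # column ref: agg input only, unless it's a grouping expr
        name: str

    @dataclass
    class Agg:  # aggregate call: agg output only
        arg: object

    @dataclass
    class Add:  # composite: status is derived from its components
        l: object
        r: object

    def agg_status(e, grouping: Set[str]) -> Tuple[bool, bool]:
        """Returns (is agg output, is agg input), mirroring _determine_agg_status."""
        if isinstance(e, Col) and e.name in grouping:
            return True, True
        if isinstance(e, Agg):
            _, is_input = agg_status(e.arg, grouping)
            if not is_input:
                raise ValueError('invalid nested aggregate')
            return True, False
        if isinstance(e, Lit):
            return True, True
        if isinstance(e, Col):
            return False, True
        out_l, in_l = agg_status(e.l, grouping)
        out_r, in_r = agg_status(e.r, grouping)
        is_output, is_input = out_l and out_r, in_l and in_r
        if not is_output and not is_input:
            raise ValueError('mixes aggregate with non-aggregate')
        return is_output, is_input

    # sum(b) + a is valid output in a query grouped by a ...
    assert agg_status(Add(Agg(Col('b')), Col('a')), {'a'}) == (True, False)
    # ... but sum(b) + b mixes aggregate with non-aggregate and raises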
+     def finalize(self, row_builder: exprs.RowBuilder) -> None:
+         """Make all exprs executable
+         TODO: add EvalCtx for each expr list?
+         """
+         # maintain original composition of select list
+         row_builder.substitute_exprs(self.select_list, remove_duplicates=False)
+         row_builder.substitute_exprs(self.group_by_clause)
+         order_by_exprs = [e for e, _ in self.order_by_clause]
+         row_builder.substitute_exprs(order_by_exprs)
+         self.order_by_clause = [(e, asc) for e, (_, asc) in zip(order_by_exprs, self.order_by_clause)]
+         row_builder.substitute_exprs(self.all_exprs)
+         row_builder.substitute_exprs(self.sql_exprs)
+         if self.filter is not None:
+             self.filter = row_builder.unique_exprs[self.filter]
+         row_builder.substitute_exprs(self.agg_fn_calls)
+         row_builder.substitute_exprs(self.agg_order_by)
+
+
+ class Planner:
+     # TODO: create an exec.CountNode and change this to create_count_plan()
+     @classmethod
+     def create_count_stmt(
+             cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Predicate] = None
+     ) -> sql.Select:
+         stmt = sql.select(sql.func.count('*'))
+         refd_tbl_ids: Set[UUID] = set()
+         if where_clause is not None:
+             analyzer = cls.analyze(tbl, where_clause)
+             if analyzer.similarity_clause is not None:
+                 raise excs.Error('nearest() cannot be used with count()')
+             if analyzer.filter is not None:
+                 raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
+             clause_element = analyzer.sql_where_clause.sql_expr()
+             assert clause_element is not None
+             stmt = stmt.where(clause_element)
+             refd_tbl_ids = where_clause.tbl_ids()
+         stmt = exec.SqlScanNode.create_from_clause(tbl, stmt, refd_tbl_ids)
+         return stmt
+
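create_count_stmt pushes the entire count into the database rather than materializing rows. The underlying SQLAlchemy pattern looks roughly like this (the table t is a stand-in for illustration; in the module the FROM clause is assembled by SqlScanNode.create_from_clause):

    import sqlalchemy as sql

    md = sql.MetaData()
    t = sql.Table('t', md, sql.Column('a', sql.Integer))

    stmt = sql.select(sql.func.count('*'))  # SELECT count(*)
    stmt = stmt.where(t.c.a > 0)            # ... WHERE t.a > 0
    stmt = stmt.select_from(t)              # ... FROM t
    print(stmt)                             # renders the generated SQL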
+     @classmethod
+     def create_insert_plan(
+             cls, tbl: catalog.TableVersion, rows: List[Dict[str, Any]], ignore_errors: bool
+     ) -> exec.ExecNode:
+         """Creates a plan for TableVersion.insert()"""
+         assert not tbl.is_view()
+         # things we need to materialize:
+         # 1. stored_cols: all cols we need to store, incl computed cols (and indices)
+         stored_cols = [c for c in tbl.cols if c.is_stored]
+         assert len(stored_cols) > 0
+         # 2. values to insert into indices
+         indexed_cols = [c for c in tbl.cols if c.is_indexed]
+         index_info: List[Tuple[catalog.Column, func.Function]] = []
+         if len(indexed_cols) > 0:
+             from pixeltable.functions.nos.image_embedding import openai_clip
+             index_info = [(c, openai_clip) for c in tbl.cols if c.is_indexed]
+
+         row_builder = exprs.RowBuilder([], stored_cols, index_info, [])
+
+         # create InMemoryDataNode for 'rows'
+         stored_col_info = row_builder.output_slot_idxs()
+         stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
+         input_col_info = [info for info in stored_col_info if not info.col.is_computed]
+         plan = exec.InMemoryDataNode(tbl, rows, row_builder, tbl.next_rowid)
+
+         media_input_cols = [info for info in input_col_info if info.col.col_type.is_media_type()]
+
+         # prefetch external files for all input column refs for validation
+         plan = exec.CachePrefetchNode(tbl.id, media_input_cols, plan)
+         plan = exec.MediaValidationNode(row_builder, media_input_cols, input=plan)
+
+         computed_exprs = row_builder.default_eval_ctx.target_exprs
+         if len(computed_exprs) > 0:
+             # add an ExprEvalNode when there are exprs to compute
+             plan = exec.ExprEvalNode(row_builder, computed_exprs, [], input=plan)
+
+         plan.set_stored_img_cols(stored_img_col_info)
+         plan.set_ctx(
+             exec.ExecContext(
+                 row_builder, batch_size=0, show_pbar=True, num_computed_exprs=len(computed_exprs),
+                 ignore_errors=ignore_errors))
+         return plan
+
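The insert plan is a chain of operators, each pulling batches from its input: InMemoryDataNode -> CachePrefetchNode -> MediaValidationNode -> ExprEvalNode. A stripped-down sketch of that pull-based composition pattern (toy node classes assuming batches are plain lists of dicts, not the real exec API):

    from typing import Callable, Dict, Iterator, List, Optional

    class ToyExecNode:
        """Pull-based operator: iterate it to get row batches."""
        def __init__(self, input: Optional['ToyExecNode'] = None):
            self.input = input
        def __iter__(self) -> Iterator[List[Dict]]:
            raise NotImplementedError

    class InMemorySource(ToyExecNode):
        def __init__(self, rows: List[Dict], batch_size: int = 2):
            super().__init__()
            self.rows, self.batch_size = rows, batch_size
        def __iter__(self):
            for i in range(0, len(self.rows), self.batch_size):
                yield self.rows[i:i + self.batch_size]

    class MapNode(ToyExecNode):
        """Applies a per-row function; stands in for validation/expr-eval nodes."""
        def __init__(self, fn: Callable[[Dict], Dict], input: ToyExecNode):
            super().__init__(input)
            self.fn = fn
        def __iter__(self):
            for batch in self.input:
                yield [self.fn(row) for row in batch]

    plan = InMemorySource([{'x': 1}, {'x': 2}, {'x': 3}])
    plan = MapNode(lambda r: {**r, 'y': r['x'] * 2}, input=plan)  # a "computed column"
    print([row for batch in plan for row in batch])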
+     @classmethod
+     def create_update_plan(
+             cls, tbl: catalog.TableVersionPath,
+             update_targets: List[Tuple[catalog.Column, exprs.Expr]],
+             recompute_targets: List[catalog.Column],
+             where_clause: Optional[exprs.Predicate], cascade: bool
+     ) -> Tuple[exec.ExecNode, List[str], List[catalog.Column]]:
+         """Creates a plan to materialize updated rows.
+         The plan:
+         - retrieves rows that are visible at the current version of the table
+         - materializes all stored columns and the update targets
+         - if cascade is True, recomputes all computed columns that transitively depend on the updated columns
+           and copies the values of all other stored columns
+         - if cascade is False, copies all columns that aren't update targets from the original rows
+         Returns:
+         - root node of the plan
+         - list of qualified column names that are getting updated
+         - list of columns that are being recomputed
+         """
+         # retrieve all stored cols and all target exprs
+         assert isinstance(tbl, catalog.TableVersionPath)
+         target = tbl.tbl_version  # the one we need to update
+         updated_cols = [col for col, _ in update_targets]
+         if len(recompute_targets) > 0:
+             recomputed_cols = recompute_targets.copy()
+         else:
+             recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
+         # we only need to recompute stored columns (unstored ones are substituted away)
+         recomputed_cols = {c for c in recomputed_cols if c.is_stored}
+         recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
+         copied_cols = [
+             col for col in target.cols
+             if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
+         ]
+         select_list = [exprs.ColumnRef(col) for col in copied_cols]
+         select_list.extend([expr for _, expr in update_targets])
+
+         recomputed_exprs = \
+             [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols]
+         # recomputed cols reference the new values of the updated cols
+         for col, e in update_targets:
+             exprs.Expr.list_substitute(recomputed_exprs, exprs.ColumnRef(col), e)
+         select_list.extend(recomputed_exprs)
+
+         # we need to retrieve the PK columns of the existing rows
+         plan = cls.create_query_plan(tbl, select_list, where_clause=where_clause, with_pk=True, ignore_errors=True)
+         all_base_cols = copied_cols + updated_cols + list(recomputed_base_cols)  # same order as select_list
+         # update the row builder with column information
+         for i, col in enumerate(all_base_cols):
+             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
+         return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + list(recomputed_cols)], list(recomputed_cols)
+
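With cascade=True, the recompute set is the transitive closure of the column dependency graph over the updated columns (what get_dependent_columns provides). A minimal sketch of that closure computation (the deps mapping from column to the columns computed from it is hypothetical):

    from collections import deque

    # hypothetical dependency edges: column -> columns whose value_expr reads it
    deps = {'price': ['total'], 'qty': ['total'], 'total': ['total_with_tax']}

    def transitive_dependents(updated):
        seen, queue = set(), deque(updated)
        while queue:
            col = queue.popleft()
            for dep in deps.get(col, []):
                if dep not in seen:
                    seen.add(dep)
                    queue.append(dep)
        return seen

    assert transitive_dependents(['price']) == {'total', 'total_with_tax'}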
+     @classmethod
+     def create_view_update_plan(
+             cls, view: catalog.TableVersionPath, recompute_targets: List[catalog.Column]
+     ) -> exec.ExecNode:
+         """Creates a plan to materialize updated rows for a view, given that the base table has been updated.
+         The plan:
+         - retrieves rows that are visible at the current version of the base and satisfy the view predicate
+         - copies all stored columns that aren't being recomputed
+         - recomputes the columns in recompute_targets
+
+         TODO: unify with create_view_load_plan()
+
+         Returns:
+         - root node of the plan
+         """
+         assert isinstance(view, catalog.TableVersionPath)
+         assert view.is_view()
+         target = view.tbl_version  # the one we need to update
+         # retrieve all stored cols and all target exprs
+         recomputed_cols = set(recompute_targets.copy())
+         copied_cols = [col for col in target.cols if col.is_stored and col not in recomputed_cols]
+         select_list = [exprs.ColumnRef(col) for col in copied_cols]
+         # resolve recomputed exprs to stored columns in the base
+         recomputed_exprs = \
+             [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_cols) for c in recomputed_cols]
+         select_list.extend(recomputed_exprs)
+
+         # we need to retrieve the PK columns of the existing rows
+         plan = cls.create_query_plan(
+             view, select_list, where_clause=target.predicate, with_pk=True, ignore_errors=True,
+             exact_version_only=view.get_bases())
+         for i, col in enumerate(copied_cols + list(recomputed_cols)):  # same order as select_list
+             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
+         # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
+         stored_img_col_info = \
+             [info for info in plan.row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
+         plan.set_stored_img_cols(stored_img_col_info)
+         return plan
+
+     @classmethod
+     def create_view_load_plan(
+             cls, view: catalog.TableVersionPath, propagates_insert: bool = False
+     ) -> Tuple[exec.ExecNode, int]:
+         """Creates a query plan for populating a view.
+
+         Args:
+             view: the view to populate
+             propagates_insert: if True, we're propagating a base update to this view
+
+         Returns:
+             - root node of the plan
+             - number of materialized values per row
+         """
+         assert isinstance(view, catalog.TableVersionPath)
+         assert view.is_view()
+         # things we need to materialize as DataRows:
+         # 1. stored computed cols
+         #    - iterator columns are effectively computed, just not with a value_expr
+         #    - we can ignore stored non-computed columns because they have a default value that is supplied directly
+         #      by the store
+         target = view.tbl_version  # the one we need to populate
+         stored_cols = [c for c in target.cols if c.is_stored and (c.is_computed or target.is_iterator_column(c))]
+         # 2. index values
+         indexed_cols = [c for c in target.cols if c.is_indexed]
+         index_info: List[Tuple[catalog.Column, func.Function]] = []
+         if len(indexed_cols) > 0:
+             from pixeltable.functions.nos.image_embedding import openai_clip
+             index_info = [(c, openai_clip) for c in target.cols if c.is_indexed]
+         # 3. for component views: iterator args
+         iterator_args = [target.iterator_args] if target.iterator_args is not None else []
+
+         row_builder = exprs.RowBuilder(iterator_args, stored_cols, index_info, [])
+
+         # execution plan:
+         # 1. materialize exprs computed from the base that are needed for stored view columns
+         # 2. if it's an iterator view, expand the base rows into component rows
+         # 3. materialize stored view columns that haven't been produced by step 1
+         base_output_exprs = [e for e in row_builder.default_eval_ctx.exprs if e.is_bound_by(view.base)]
+         view_output_exprs = [
+             e for e in row_builder.default_eval_ctx.target_exprs
+             if e.is_bound_by(view) and not e.is_bound_by(view.base)
+         ]
+         # if we're propagating an insert, we only want to see those base rows that were created for the current version
+         base_analyzer = Analyzer(view, base_output_exprs, where_clause=target.predicate)
+         plan = cls._create_query_plan(
+             view.base, row_builder=row_builder, analyzer=base_analyzer, with_pk=True,
+             exact_version_only=view.get_bases() if propagates_insert else [])
+         exec_ctx = plan.ctx
+         if target.is_component_view():
+             plan = exec.ComponentIterationNode(target, plan)
+         if len(view_output_exprs) > 0:
+             plan = exec.ExprEvalNode(
+                 row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan)
+
+         stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
+         plan.set_stored_img_cols(stored_img_col_info)
+         exec_ctx.ignore_errors = True
+         plan.set_ctx(exec_ctx)
+         return plan, len(row_builder.default_eval_ctx.target_exprs)
+
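Step 2 above turns each base row into zero or more component rows (e.g. a video row into one row per frame). A generator-based sketch of that expansion (toy rows and a toy iterator callback, standing in for ComponentIterationNode and the iterator classes):

    from typing import Callable, Dict, Iterable, Iterator

    def expand_components(
            base_rows: Iterable[Dict], make_components: Callable[[Dict], Iterable[Dict]]
    ) -> Iterator[Dict]:
        """Yields one output row per component, tagged with its position within the base row."""
        for base_row in base_rows:
            for pos, component in enumerate(make_components(base_row)):
                yield {**base_row, **component, 'pos': pos}

    rows = [{'video': 'a.mp4', 'n_frames': 2}, {'video': 'b.mp4', 'n_frames': 1}]
    frames = expand_components(rows, lambda r: ({'frame_idx': i} for i in range(r['n_frames'])))
    print(list(frames))  # 3 component rows: 2 for a.mp4, 1 for b.mp4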
+     @classmethod
+     def _determine_ordering(cls, analyzer: Analyzer) -> List[Tuple[exprs.Expr, bool]]:
+         """Returns the exprs for the ORDER BY clause of the SqlScanNode"""
+         order_by_items: List[Tuple[exprs.Expr, Optional[bool]]] = []
+         order_by_origin: Optional[exprs.Expr] = None  # the expr that determines the ordering
+
+         # window functions require ordering by the group_by/order_by clauses
+         window_fn_calls = [
+             e for e in analyzer.all_exprs if isinstance(e, exprs.FunctionCall) and e.is_window_fn_call
+         ]
+         if len(window_fn_calls) > 0:
+             for fn_call in window_fn_calls:
+                 gb, ob = fn_call.get_window_sort_exprs()
+                 # for now, the ordering is implicitly ascending
+                 fn_call_ordering = [(e, None) for e in gb] + [(e, True) for e in ob]
+                 if len(order_by_items) == 0:
+                     order_by_items = fn_call_ordering
+                     order_by_origin = fn_call
+                 else:
+                     # check for compatibility
+                     other_order_by_clauses = fn_call_ordering
+                     combined = _get_combined_ordering(order_by_items, other_order_by_clauses)
+                     if len(combined) == 0:
+                         raise excs.Error((
+                             f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                             f"'{fn_call}':\n"
+                             f"{exprs.Expr.print_list(order_by_items)} vs {exprs.Expr.print_list(other_order_by_clauses)}"
+                         ))
+                     order_by_items = combined
+
+         if len(analyzer.group_by_clause) > 0:
+             agg_ordering = [(e, None) for e in analyzer.group_by_clause] + [(e, True) for e in analyzer.agg_order_by]
+             if len(order_by_items) > 0:
+                 # check for compatibility
+                 combined = _get_combined_ordering(order_by_items, agg_ordering)
+                 if len(combined) == 0:
+                     raise excs.Error((
+                         f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                         f"grouping expressions:\n"
+                         f"{exprs.Expr.print_list([e for e, _ in order_by_items])} vs "
+                         f"{exprs.Expr.print_list([e for e, _ in agg_ordering])}"
+                     ))
+                 order_by_items = combined
+             else:
+                 order_by_items = agg_ordering
+
+         if len(analyzer.order_by_clause) > 0:
+             if len(order_by_items) > 0:
+                 # check for compatibility
+                 combined = _get_combined_ordering(order_by_items, analyzer.order_by_clause)
+                 if len(combined) == 0:
+                     raise excs.Error((
+                         f"Incompatible ordering requirements between expressions '{order_by_origin}' and "
+                         f"order-by expressions:\n"
+                         f"{exprs.Expr.print_list([e for e, _ in order_by_items])} vs "
+                         f"{exprs.Expr.print_list([e for e, _ in analyzer.order_by_clause])}"
+                     ))
+                 order_by_items = combined
+             else:
+                 order_by_items = analyzer.order_by_clause
+
+         # TODO: can this be unified with the same logic in RowBuilder?
+         def refs_unstored_iter_col(e: exprs.Expr) -> bool:
+             if not isinstance(e, exprs.ColumnRef):
+                 return False
+             tbl = e.col.tbl
+             return tbl.is_component_view() and tbl.is_iterator_column(e.col) and not e.col.is_stored
+         unstored_iter_col_refs = list(exprs.Expr.list_subexprs(analyzer.all_exprs, filter=refs_unstored_iter_col))
+         if len(unstored_iter_col_refs) > 0 and len(order_by_items) == 0:
+             # we don't already have a user-requested ordering and we access unstored iterator columns:
+             # order by the primary key of the component view, which minimizes the number of iterator instantiations
+             component_views = {e.col.tbl for e in unstored_iter_col_refs}
+             # TODO: generalize this to multi-level iteration
+             assert len(component_views) == 1
+             component_view = list(component_views)[0]
+             order_by_items = [
+                 (exprs.RowidRef(component_view, idx), None)
+                 for idx in range(len(component_view.store_tbl.rowid_columns()))
+             ]
+             order_by_origin = unstored_iter_col_refs[0]
+
+         for e in [e for e, _ in order_by_items]:
+             if e.sql_expr() is None:
+                 raise excs.Error(f'order_by element cannot be expressed in SQL: {e}')
+         # we do ascending ordering by default, if not specified otherwise
+         order_by_items = [(e, True) if asc is None else (e, asc) for e, asc in order_by_items]
+         return order_by_items
+
+     @classmethod
+     def _is_contained_in(cls, l1: List[exprs.Expr], l2: List[exprs.Expr]) -> bool:
+         """Returns True if l1 is contained in l2"""
+         s1, s2 = {e.id for e in l1}, {e.id for e in l2}
+         return s1 <= s2
+
+     @classmethod
+     def _insert_prefetch_node(
+             cls, tbl_id: UUID, output_exprs: List[exprs.Expr], row_builder: exprs.RowBuilder, input: exec.ExecNode
+     ) -> exec.ExecNode:
+         """Inserts a CachePrefetchNode into the plan if needed, otherwise returns input"""
+         # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
+         # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
+         # aren't explicitly captured as dependencies
+         media_col_refs = [
+             e for e in list(row_builder.unique_exprs) if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()
+         ]
+         if len(media_col_refs) == 0:
+             return input
+         # we need to prefetch external files for media column types
+         file_col_info = [exprs.ColumnSlotIdx(e.col, e.slot_idx) for e in media_col_refs]
+         prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input)
+         return prefetch_node
+
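CachePrefetchNode hides download latency by fetching external media files before the operators that open them run. A minimal sketch of the idea with a thread pool (generic URL fetching into a local directory; the real node works against pixeltable's file cache):

    import concurrent.futures as cf
    import urllib.request
    from pathlib import Path
    from typing import Dict, Iterable

    def prefetch(urls: Iterable[str], cache_dir: Path, max_workers: int = 8) -> Dict[str, Path]:
        """Downloads all URLs concurrently; returns url -> local path."""
        cache_dir.mkdir(parents=True, exist_ok=True)
        def fetch(url: str):
            local = cache_dir / url.rsplit('/', 1)[-1]
            if not local.exists():  # naive cache-hit check
                urllib.request.urlretrieve(url, local)
            return url, local
        with cf.ThreadPoolExecutor(max_workers=max_workers) as pool:
            return dict(pool.map(fetch, urls))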
+     @classmethod
+     def create_query_plan(
+             cls, tbl: catalog.TableVersionPath, select_list: Optional[List[exprs.Expr]] = None,
+             where_clause: Optional[exprs.Predicate] = None, group_by_clause: Optional[List[exprs.Expr]] = None,
+             order_by_clause: Optional[List[Tuple[exprs.Expr, bool]]] = None, limit: Optional[int] = None,
+             with_pk: bool = False, ignore_errors: bool = False,
+             exact_version_only: Optional[List[catalog.TableVersion]] = None
+     ) -> exec.ExecNode:
+         """Return plan for executing a query.
+         Updates 'select_list' in place to make it executable.
+         TODO: make exact_version_only a flag and use the versions from tbl
+         """
+         if select_list is None:
+             select_list = []
+         if group_by_clause is None:
+             group_by_clause = []
+         if order_by_clause is None:
+             order_by_clause = []
+         if exact_version_only is None:
+             exact_version_only = []
+         analyzer = Analyzer(
+             tbl, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
+             order_by_clause=order_by_clause)
+         row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [], analyzer.sql_exprs)
+
+         analyzer.finalize(row_builder)
+         # select_list: we need to materialize everything that's been collected
+         # with_pk: for now, we always retrieve the PK, because we need it for the file cache
+         plan = cls._create_query_plan(
+             tbl, row_builder, analyzer=analyzer, limit=limit, with_pk=True, exact_version_only=exact_version_only)
+         plan.ctx.ignore_errors = ignore_errors
+         select_list.clear()
+         select_list.extend(analyzer.select_list)
+         return plan
+
+     @classmethod
+     def _create_query_plan(
+             cls, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder, analyzer: Analyzer,
+             limit: Optional[int] = None, with_pk: bool = False,
+             exact_version_only: Optional[List[catalog.TableVersion]] = None
+     ) -> exec.ExecNode:
+         """
+         Args:
+             plan_target: if not None, generate a plan that materializes only expressions that can be evaluated
+                 in the context of that table version (eg, if 'tbl' is a view, 'plan_target' might be the base)
+         TODO: make exact_version_only a flag and use the versions from tbl
+         """
+         if exact_version_only is None:
+             exact_version_only = []
+         assert isinstance(tbl, catalog.TableVersionPath)
+         is_agg_query = len(analyzer.group_by_clause) > 0 or len(analyzer.agg_fn_calls) > 0
+         ctx = exec.ExecContext(row_builder)
+
+         order_by_items = cls._determine_ordering(analyzer)
+         sql_limit = 0 if is_agg_query else limit  # if we're aggregating, the limit applies to the agg output
+         sql_select_list = analyzer.sql_exprs.copy()
+         plan = exec.SqlScanNode(
+             tbl, row_builder, select_list=sql_select_list, where_clause=analyzer.sql_where_clause,
+             filter=analyzer.filter, similarity_clause=analyzer.similarity_clause, order_by_items=order_by_items,
+             limit=sql_limit, set_pk=with_pk, exact_version_only=exact_version_only)
+         plan = cls._insert_prefetch_node(tbl.tbl_version.id, analyzer.select_list, row_builder, plan)
+
+         if is_agg_query:
+             # we're doing aggregation; the input of the AggregationNode are the grouping exprs plus the
+             # args of the agg fn calls
+             agg_input = exprs.ExprSet(analyzer.group_by_clause.copy())
+             for fn_call in analyzer.agg_fn_calls:
+                 agg_input.extend(fn_call.components)
+             if not cls._is_contained_in(agg_input, analyzer.sql_exprs):
+                 # we need an ExprEvalNode
+                 plan = exec.ExprEvalNode(row_builder, agg_input, analyzer.sql_exprs, input=plan)
+
+             # batch size for aggregation input: this could be the entire table, so we need to divide it into
+             # smaller batches; at the same time, we need to make the batches large enough to amortize the
+             # function call overhead
+             # TODO: increase this if we have NOS calls in order to reduce the cost of switching models, but take
+             # into account the amount of memory needed for intermediate images
+             ctx.batch_size = 16
+
+             plan = exec.AggregationNode(
+                 tbl.tbl_version, row_builder, analyzer.group_by_clause, analyzer.agg_fn_calls, agg_input, input=plan)
+             agg_output = analyzer.group_by_clause + analyzer.agg_fn_calls
+             if not cls._is_contained_in(analyzer.select_list, agg_output):
+                 # we need an ExprEvalNode to evaluate the remaining output exprs
+                 plan = exec.ExprEvalNode(
+                     row_builder, analyzer.select_list, agg_output, input=plan)
+         else:
+             if not cls._is_contained_in(analyzer.select_list, analyzer.sql_exprs):
+                 # we need an ExprEvalNode to evaluate the remaining output exprs
+                 plan = exec.ExprEvalNode(row_builder, analyzer.select_list, analyzer.sql_exprs, input=plan)
+             # we're returning everything to the user, so we might as well do it in a single batch
+             ctx.batch_size = 0
+
+         plan.set_ctx(ctx)
+         return plan
+
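For a grouped query, the resulting operator chain is SqlScanNode -> (ExprEvalNode) -> AggregationNode -> (ExprEvalNode), with batch_size=16 feeding the aggregation; non-aggregate queries run as a single batch (batch_size=0). The optional ExprEvalNodes are driven by the containment check, sketched here with toy identity-based exprs:

    def is_contained_in(l1, l2) -> bool:
        # mirrors Planner._is_contained_in: set containment on expr ids
        return {id(e) for e in l1} <= {id(e) for e in l2}

    a, b, c = object(), object(), object()
    sql_exprs = [a, b]
    assert is_contained_in([a, b], sql_exprs)      # fully covered by the SQL scan: no ExprEvalNode
    assert not is_contained_in([a, c], sql_exprs)  # c must be computed in Python: add an ExprEvalNode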
+     @classmethod
+     def analyze(cls, tbl: catalog.TableVersionPath, where_clause: exprs.Predicate) -> Analyzer:
+         return Analyzer(tbl, [], where_clause=where_clause)
+
+     @classmethod
+     def create_add_column_plan(
+             cls, tbl: catalog.TableVersionPath, col: catalog.Column
+     ) -> Tuple[exec.ExecNode, Optional[int], Optional[int]]:
+         """Creates a plan for InsertableTable.add_column()
+         Returns:
+         - plan: the plan to execute
+         - value_expr slot idx for the plan output (for computed cols)
+         - embedding slot idx for the plan output (for indexed image cols)
+         """
+         assert isinstance(tbl, catalog.TableVersionPath)
+         index_info: List[Tuple[catalog.Column, func.Function]] = []
+         if col.is_indexed:
+             from pixeltable.functions.nos.image_embedding import openai_clip
+             index_info = [(col, openai_clip)]
+         row_builder = exprs.RowBuilder(
+             output_exprs=[], columns=[col], indices=index_info, input_exprs=[])
+         analyzer = Analyzer(tbl, row_builder.default_eval_ctx.target_exprs)
+         plan = cls._create_query_plan(tbl, row_builder=row_builder, analyzer=analyzer, with_pk=True)
+         plan.ctx.batch_size = 16
+         plan.ctx.show_pbar = True
+         plan.ctx.ignore_errors = True
+
+         # we want to flush images
+         if col.is_computed and col.is_stored and col.col_type.is_image_type():
+             plan.set_stored_img_cols(row_builder.output_slot_idxs())
+         value_expr_slot_idx: Optional[int] = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
+         embedding_slot_idx: Optional[int] = row_builder.index_slot_idxs()[0].slot_idx if col.is_indexed else None
+         return plan, value_expr_slot_idx, embedding_slot_idx