pixeltable 0.4.0rc3__py3-none-any.whl → 0.4.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (202) hide show
  1. pixeltable/__init__.py +23 -5
  2. pixeltable/_version.py +1 -0
  3. pixeltable/catalog/__init__.py +5 -3
  4. pixeltable/catalog/catalog.py +1318 -404
  5. pixeltable/catalog/column.py +186 -115
  6. pixeltable/catalog/dir.py +1 -2
  7. pixeltable/catalog/globals.py +11 -43
  8. pixeltable/catalog/insertable_table.py +167 -79
  9. pixeltable/catalog/path.py +61 -23
  10. pixeltable/catalog/schema_object.py +9 -10
  11. pixeltable/catalog/table.py +626 -308
  12. pixeltable/catalog/table_metadata.py +101 -0
  13. pixeltable/catalog/table_version.py +713 -569
  14. pixeltable/catalog/table_version_handle.py +37 -6
  15. pixeltable/catalog/table_version_path.py +42 -29
  16. pixeltable/catalog/tbl_ops.py +50 -0
  17. pixeltable/catalog/update_status.py +191 -0
  18. pixeltable/catalog/view.py +108 -94
  19. pixeltable/config.py +128 -22
  20. pixeltable/dataframe.py +188 -100
  21. pixeltable/env.py +407 -136
  22. pixeltable/exceptions.py +6 -0
  23. pixeltable/exec/__init__.py +3 -0
  24. pixeltable/exec/aggregation_node.py +7 -8
  25. pixeltable/exec/cache_prefetch_node.py +83 -110
  26. pixeltable/exec/cell_materialization_node.py +231 -0
  27. pixeltable/exec/cell_reconstruction_node.py +135 -0
  28. pixeltable/exec/component_iteration_node.py +4 -3
  29. pixeltable/exec/data_row_batch.py +8 -65
  30. pixeltable/exec/exec_context.py +16 -4
  31. pixeltable/exec/exec_node.py +13 -36
  32. pixeltable/exec/expr_eval/evaluators.py +7 -6
  33. pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
  34. pixeltable/exec/expr_eval/globals.py +8 -5
  35. pixeltable/exec/expr_eval/row_buffer.py +1 -2
  36. pixeltable/exec/expr_eval/schedulers.py +190 -30
  37. pixeltable/exec/globals.py +32 -0
  38. pixeltable/exec/in_memory_data_node.py +18 -18
  39. pixeltable/exec/object_store_save_node.py +293 -0
  40. pixeltable/exec/row_update_node.py +16 -9
  41. pixeltable/exec/sql_node.py +206 -101
  42. pixeltable/exprs/__init__.py +1 -1
  43. pixeltable/exprs/arithmetic_expr.py +27 -22
  44. pixeltable/exprs/array_slice.py +3 -3
  45. pixeltable/exprs/column_property_ref.py +34 -30
  46. pixeltable/exprs/column_ref.py +92 -96
  47. pixeltable/exprs/comparison.py +5 -5
  48. pixeltable/exprs/compound_predicate.py +5 -4
  49. pixeltable/exprs/data_row.py +152 -55
  50. pixeltable/exprs/expr.py +62 -43
  51. pixeltable/exprs/expr_dict.py +3 -3
  52. pixeltable/exprs/expr_set.py +17 -10
  53. pixeltable/exprs/function_call.py +75 -37
  54. pixeltable/exprs/globals.py +1 -2
  55. pixeltable/exprs/in_predicate.py +4 -4
  56. pixeltable/exprs/inline_expr.py +10 -27
  57. pixeltable/exprs/is_null.py +1 -3
  58. pixeltable/exprs/json_mapper.py +8 -8
  59. pixeltable/exprs/json_path.py +56 -22
  60. pixeltable/exprs/literal.py +5 -5
  61. pixeltable/exprs/method_ref.py +2 -2
  62. pixeltable/exprs/object_ref.py +2 -2
  63. pixeltable/exprs/row_builder.py +127 -53
  64. pixeltable/exprs/rowid_ref.py +8 -12
  65. pixeltable/exprs/similarity_expr.py +50 -25
  66. pixeltable/exprs/sql_element_cache.py +4 -4
  67. pixeltable/exprs/string_op.py +5 -5
  68. pixeltable/exprs/type_cast.py +3 -5
  69. pixeltable/func/__init__.py +1 -0
  70. pixeltable/func/aggregate_function.py +8 -8
  71. pixeltable/func/callable_function.py +9 -9
  72. pixeltable/func/expr_template_function.py +10 -10
  73. pixeltable/func/function.py +18 -20
  74. pixeltable/func/function_registry.py +6 -7
  75. pixeltable/func/globals.py +2 -3
  76. pixeltable/func/mcp.py +74 -0
  77. pixeltable/func/query_template_function.py +20 -18
  78. pixeltable/func/signature.py +43 -16
  79. pixeltable/func/tools.py +23 -13
  80. pixeltable/func/udf.py +18 -20
  81. pixeltable/functions/__init__.py +6 -0
  82. pixeltable/functions/anthropic.py +93 -33
  83. pixeltable/functions/audio.py +114 -10
  84. pixeltable/functions/bedrock.py +13 -6
  85. pixeltable/functions/date.py +1 -1
  86. pixeltable/functions/deepseek.py +20 -9
  87. pixeltable/functions/fireworks.py +2 -2
  88. pixeltable/functions/gemini.py +28 -11
  89. pixeltable/functions/globals.py +13 -13
  90. pixeltable/functions/groq.py +108 -0
  91. pixeltable/functions/huggingface.py +1046 -23
  92. pixeltable/functions/image.py +9 -18
  93. pixeltable/functions/llama_cpp.py +23 -8
  94. pixeltable/functions/math.py +3 -4
  95. pixeltable/functions/mistralai.py +4 -15
  96. pixeltable/functions/ollama.py +16 -9
  97. pixeltable/functions/openai.py +104 -82
  98. pixeltable/functions/openrouter.py +143 -0
  99. pixeltable/functions/replicate.py +2 -2
  100. pixeltable/functions/reve.py +250 -0
  101. pixeltable/functions/string.py +21 -28
  102. pixeltable/functions/timestamp.py +13 -14
  103. pixeltable/functions/together.py +4 -6
  104. pixeltable/functions/twelvelabs.py +92 -0
  105. pixeltable/functions/util.py +6 -1
  106. pixeltable/functions/video.py +1388 -106
  107. pixeltable/functions/vision.py +7 -7
  108. pixeltable/functions/whisper.py +15 -7
  109. pixeltable/functions/whisperx.py +179 -0
  110. pixeltable/{ext/functions → functions}/yolox.py +2 -4
  111. pixeltable/globals.py +332 -105
  112. pixeltable/index/base.py +13 -22
  113. pixeltable/index/btree.py +23 -22
  114. pixeltable/index/embedding_index.py +32 -44
  115. pixeltable/io/__init__.py +4 -2
  116. pixeltable/io/datarows.py +7 -6
  117. pixeltable/io/external_store.py +49 -77
  118. pixeltable/io/fiftyone.py +11 -11
  119. pixeltable/io/globals.py +29 -28
  120. pixeltable/io/hf_datasets.py +17 -9
  121. pixeltable/io/label_studio.py +70 -66
  122. pixeltable/io/lancedb.py +3 -0
  123. pixeltable/io/pandas.py +12 -11
  124. pixeltable/io/parquet.py +13 -93
  125. pixeltable/io/table_data_conduit.py +71 -47
  126. pixeltable/io/utils.py +3 -3
  127. pixeltable/iterators/__init__.py +2 -1
  128. pixeltable/iterators/audio.py +21 -11
  129. pixeltable/iterators/document.py +116 -55
  130. pixeltable/iterators/image.py +5 -2
  131. pixeltable/iterators/video.py +293 -13
  132. pixeltable/metadata/__init__.py +4 -2
  133. pixeltable/metadata/converters/convert_18.py +2 -2
  134. pixeltable/metadata/converters/convert_19.py +2 -2
  135. pixeltable/metadata/converters/convert_20.py +2 -2
  136. pixeltable/metadata/converters/convert_21.py +2 -2
  137. pixeltable/metadata/converters/convert_22.py +2 -2
  138. pixeltable/metadata/converters/convert_24.py +2 -2
  139. pixeltable/metadata/converters/convert_25.py +2 -2
  140. pixeltable/metadata/converters/convert_26.py +2 -2
  141. pixeltable/metadata/converters/convert_29.py +4 -4
  142. pixeltable/metadata/converters/convert_34.py +2 -2
  143. pixeltable/metadata/converters/convert_36.py +2 -2
  144. pixeltable/metadata/converters/convert_37.py +15 -0
  145. pixeltable/metadata/converters/convert_38.py +39 -0
  146. pixeltable/metadata/converters/convert_39.py +124 -0
  147. pixeltable/metadata/converters/convert_40.py +73 -0
  148. pixeltable/metadata/converters/util.py +13 -12
  149. pixeltable/metadata/notes.py +4 -0
  150. pixeltable/metadata/schema.py +79 -42
  151. pixeltable/metadata/utils.py +74 -0
  152. pixeltable/mypy/__init__.py +3 -0
  153. pixeltable/mypy/mypy_plugin.py +123 -0
  154. pixeltable/plan.py +274 -223
  155. pixeltable/share/__init__.py +1 -1
  156. pixeltable/share/packager.py +259 -129
  157. pixeltable/share/protocol/__init__.py +34 -0
  158. pixeltable/share/protocol/common.py +170 -0
  159. pixeltable/share/protocol/operation_types.py +33 -0
  160. pixeltable/share/protocol/replica.py +109 -0
  161. pixeltable/share/publish.py +213 -57
  162. pixeltable/store.py +238 -175
  163. pixeltable/type_system.py +104 -63
  164. pixeltable/utils/__init__.py +2 -3
  165. pixeltable/utils/arrow.py +108 -13
  166. pixeltable/utils/av.py +298 -0
  167. pixeltable/utils/azure_store.py +305 -0
  168. pixeltable/utils/code.py +3 -3
  169. pixeltable/utils/console_output.py +4 -1
  170. pixeltable/utils/coroutine.py +6 -23
  171. pixeltable/utils/dbms.py +31 -5
  172. pixeltable/utils/description_helper.py +4 -5
  173. pixeltable/utils/documents.py +5 -6
  174. pixeltable/utils/exception_handler.py +7 -30
  175. pixeltable/utils/filecache.py +6 -6
  176. pixeltable/utils/formatter.py +4 -6
  177. pixeltable/utils/gcs_store.py +283 -0
  178. pixeltable/utils/http_server.py +2 -3
  179. pixeltable/utils/iceberg.py +1 -2
  180. pixeltable/utils/image.py +17 -0
  181. pixeltable/utils/lancedb.py +88 -0
  182. pixeltable/utils/local_store.py +316 -0
  183. pixeltable/utils/misc.py +5 -0
  184. pixeltable/utils/object_stores.py +528 -0
  185. pixeltable/utils/pydantic.py +60 -0
  186. pixeltable/utils/pytorch.py +5 -6
  187. pixeltable/utils/s3_store.py +392 -0
  188. pixeltable-0.4.20.dist-info/METADATA +587 -0
  189. pixeltable-0.4.20.dist-info/RECORD +218 -0
  190. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.20.dist-info}/WHEEL +1 -1
  191. pixeltable-0.4.20.dist-info/entry_points.txt +2 -0
  192. pixeltable/__version__.py +0 -3
  193. pixeltable/ext/__init__.py +0 -17
  194. pixeltable/ext/functions/__init__.py +0 -11
  195. pixeltable/ext/functions/whisperx.py +0 -77
  196. pixeltable/utils/media_store.py +0 -77
  197. pixeltable/utils/s3.py +0 -17
  198. pixeltable/utils/sample.py +0 -25
  199. pixeltable-0.4.0rc3.dist-info/METADATA +0 -435
  200. pixeltable-0.4.0rc3.dist-info/RECORD +0 -189
  201. pixeltable-0.4.0rc3.dist-info/entry_points.txt +0 -3
  202. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.20.dist-info/licenses}/LICENSE +0 -0
pixeltable/plan.py CHANGED
@@ -3,16 +3,16 @@ from __future__ import annotations
3
3
  import dataclasses
4
4
  import enum
5
5
  from textwrap import dedent
6
- from typing import Any, Iterable, Literal, NamedTuple, Optional, Sequence
6
+ from typing import Any, Iterable, Literal, Sequence, cast
7
7
  from uuid import UUID
8
8
 
9
+ import pgvector.sqlalchemy # type: ignore[import-untyped]
9
10
  import sqlalchemy as sql
10
11
 
11
12
  import pixeltable as pxt
12
13
  from pixeltable import catalog, exceptions as excs, exec, exprs
13
14
  from pixeltable.catalog import Column, TableVersionHandle
14
15
  from pixeltable.exec.sql_node import OrderByClause, OrderByItem, combine_order_by_clauses, print_order_by_clause
15
- from pixeltable.utils.sample import sample_key
16
16
 
17
17
 
18
18
  def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -66,7 +66,7 @@ class JoinClause:
66
66
  """Corresponds to a single 'JOIN ... ON (...)' clause in a SELECT statement; excludes the joined table."""
67
67
 
68
68
  join_type: JoinType
69
- join_predicate: Optional[exprs.Expr] # None for join_type == CROSS
69
+ join_predicate: exprs.Expr | None # None for join_type == CROSS
70
70
 
71
71
 
72
72
  @dataclasses.dataclass
@@ -86,25 +86,20 @@ class FromClause:
86
86
  class SampleClause:
87
87
  """Defines a sampling clause for a table."""
88
88
 
89
- version: Optional[int]
90
- n: Optional[int]
91
- n_per_stratum: Optional[int]
92
- fraction: Optional[float]
93
- seed: Optional[int]
94
- stratify_exprs: Optional[list[exprs.Expr]]
95
-
96
- # This seed value is used if one is not supplied
97
- DEFAULT_SEED = 0
89
+ version: int | None
90
+ n: int | None
91
+ n_per_stratum: int | None
92
+ fraction: float | None
93
+ seed: int | None
94
+ stratify_exprs: list[exprs.Expr] | None
98
95
 
99
96
  # The version of the hashing algorithm used for ordering and fractional sampling.
100
97
  CURRENT_VERSION = 1
101
98
 
102
99
  def __post_init__(self) -> None:
103
- """If no version was provided, provide the default version"""
100
+ # If no version was provided, provide the default version
104
101
  if self.version is None:
105
102
  self.version = self.CURRENT_VERSION
106
- if self.seed is None:
107
- self.seed = self.DEFAULT_SEED
108
103
 
109
104
  @property
110
105
  def is_stratified(self) -> bool:
@@ -159,16 +154,6 @@ class SampleClause:
159
154
  return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
160
155
 
161
156
 
162
- class SamplingClauses(NamedTuple):
163
- """Clauses provided when rewriting a SampleClause"""
164
-
165
- where: exprs.Expr
166
- group_by_clause: Optional[list[exprs.Expr]]
167
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]]
168
- limit: Optional[exprs.Expr]
169
- sample_clause: Optional[SampleClause]
170
-
171
-
172
157
  class Analyzer:
173
158
  """
174
159
  Performs semantic analysis of a query and stores the analysis state.
@@ -177,17 +162,19 @@ class Analyzer:
177
162
  from_clause: FromClause
178
163
  all_exprs: list[exprs.Expr] # union of all exprs, aside from sql_where_clause
179
164
  select_list: list[exprs.Expr]
180
- group_by_clause: Optional[list[exprs.Expr]] # None for non-aggregate queries; [] for agg query w/o grouping
165
+ group_by_clause: list[exprs.Expr] | None # None for non-aggregate queries; [] for agg query w/o grouping
181
166
  grouping_exprs: list[exprs.Expr] # [] for non-aggregate queries or agg query w/o grouping
182
167
  order_by_clause: OrderByClause
168
+ stratify_exprs: list[exprs.Expr] # [] if no stratification is required
169
+ sample_clause: SampleClause | None # None if no sampling clause is present
183
170
 
184
171
  sql_elements: exprs.SqlElementCache
185
172
 
186
173
  # Where clause of the Select stmt of the SQL scan
187
- sql_where_clause: Optional[exprs.Expr]
174
+ sql_where_clause: exprs.Expr | None
188
175
 
189
176
  # filter predicate applied to output rows of the SQL scan
190
- filter: Optional[exprs.Expr]
177
+ filter: exprs.Expr | None
191
178
 
192
179
  agg_fn_calls: list[exprs.FunctionCall] # grouping aggregation (ie, not window functions)
193
180
  window_fn_calls: list[exprs.FunctionCall]
@@ -197,9 +184,10 @@ class Analyzer:
197
184
  self,
198
185
  from_clause: FromClause,
199
186
  select_list: Sequence[exprs.Expr],
200
- where_clause: Optional[exprs.Expr] = None,
201
- group_by_clause: Optional[list[exprs.Expr]] = None,
202
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
187
+ where_clause: exprs.Expr | None = None,
188
+ group_by_clause: list[exprs.Expr] | None = None,
189
+ order_by_clause: list[tuple[exprs.Expr, bool]] | None = None,
190
+ sample_clause: SampleClause | None = None,
203
191
  ):
204
192
  if order_by_clause is None:
205
193
  order_by_clause = []
@@ -213,6 +201,11 @@ class Analyzer:
213
201
  self.group_by_clause = (
214
202
  [e.resolve_computed_cols() for e in group_by_clause] if group_by_clause is not None else None
215
203
  )
204
+ self.sample_clause = sample_clause
205
+ if self.sample_clause is not None and self.sample_clause.is_stratified:
206
+ self.stratify_exprs = [e.resolve_computed_cols() for e in sample_clause.stratify_exprs]
207
+ else:
208
+ self.stratify_exprs = []
216
209
  self.order_by_clause = [OrderByItem(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]
217
210
 
218
211
  self.sql_where_clause = None
@@ -228,8 +221,11 @@ class Analyzer:
228
221
  self.all_exprs.append(join_clause.join_predicate)
229
222
  if self.group_by_clause is not None:
230
223
  self.all_exprs.extend(self.group_by_clause)
224
+ self.all_exprs.extend(self.stratify_exprs)
231
225
  self.all_exprs.extend(e for e, _ in self.order_by_clause)
232
226
  if self.filter is not None:
227
+ if sample_clause is not None:
228
+ raise excs.Error(f'Filter {self.filter} not expressible in SQL')
233
229
  self.all_exprs.append(self.filter)
234
230
 
235
231
  self.agg_order_by = []
@@ -334,7 +330,7 @@ class Analyzer:
334
330
  row_builder.set_slot_idxs(self.agg_fn_calls)
335
331
  row_builder.set_slot_idxs(self.agg_order_by)
336
332
 
337
- def get_window_fn_ob_clause(self) -> Optional[OrderByClause]:
333
+ def get_window_fn_ob_clause(self) -> OrderByClause | None:
338
334
  clause: list[OrderByClause] = []
339
335
  for fn_call in self.window_fn_calls:
340
336
  # window functions require ordering by the group_by/order_by clauses
@@ -352,7 +348,7 @@ class Analyzer:
352
348
  class Planner:
353
349
  # TODO: create an exec.CountNode and change this to create_count_plan()
354
350
  @classmethod
355
- def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None) -> sql.Select:
351
+ def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: exprs.Expr | None = None) -> sql.Select:
356
352
  stmt = sql.select(sql.func.count().label('all_count'))
357
353
  refd_tbl_ids: set[UUID] = set()
358
354
  if where_clause is not None:
@@ -378,21 +374,14 @@ class Planner:
378
374
 
379
375
  cls.__check_valid_columns(tbl, stored_cols, 'inserted into')
380
376
 
381
- row_builder = exprs.RowBuilder([], stored_cols, [])
377
+ row_builder = exprs.RowBuilder([], stored_cols, [], tbl)
382
378
 
383
379
  # create InMemoryDataNode for 'rows'
384
380
  plan: exec.ExecNode = exec.InMemoryDataNode(
385
381
  TableVersionHandle(tbl.id, tbl.effective_version), rows, row_builder, tbl.next_row_id
386
382
  )
387
383
 
388
- media_input_col_info = [
389
- exprs.ColumnSlotIdx(col_ref.col, col_ref.slot_idx)
390
- for col_ref in row_builder.input_exprs
391
- if isinstance(col_ref, exprs.ColumnRef) and col_ref.col_type.is_media_type()
392
- ]
393
- if len(media_input_col_info) > 0:
394
- # prefetch external files for all input column refs
395
- plan = exec.CachePrefetchNode(tbl.id, media_input_col_info, input=plan)
384
+ plan = cls._add_prefetch_node(tbl.id, row_builder.input_exprs, input_node=plan)
396
385
 
397
386
  computed_exprs = row_builder.output_exprs - row_builder.input_exprs
398
387
  if len(computed_exprs) > 0:
@@ -400,10 +389,9 @@ class Planner:
400
389
  plan = exec.ExprEvalNode(
401
390
  row_builder, computed_exprs, plan.output_exprs, input=plan, maintain_input_order=False
402
391
  )
392
+ if any(c.col_type.is_json_type() or c.col_type.is_array_type() for c in stored_cols):
393
+ plan = exec.CellMaterializationNode(plan)
403
394
 
404
- stored_col_info = row_builder.output_slot_idxs()
405
- stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
406
- plan.set_stored_img_cols(stored_img_col_info)
407
395
  plan.set_ctx(
408
396
  exec.ExecContext(
409
397
  row_builder,
@@ -413,10 +401,12 @@ class Planner:
413
401
  ignore_errors=ignore_errors,
414
402
  )
415
403
  )
404
+ plan = cls._add_save_node(plan)
405
+
416
406
  return plan
417
407
 
418
408
  @classmethod
419
- def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: Optional[int] = None) -> list[exprs.Expr]:
409
+ def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: int | None = None) -> list[exprs.Expr]:
420
410
  """Return list of RowidRef for the given number of associated rowids"""
421
411
  if num_rowid_cols is None:
422
412
  num_rowid_cols = target.get().num_rowid_columns()
@@ -430,14 +420,17 @@ class Planner:
430
420
  plan = df._create_query_plan() # ExecNode constructed by the DataFrame
431
421
 
432
422
  # Modify the plan RowBuilder to register the output columns
423
+ needs_cell_materialization = False
433
424
  for col_name, expr in zip(df.schema.keys(), df._select_list_exprs):
434
425
  assert col_name in tbl.cols_by_name
435
426
  col = tbl.cols_by_name[col_name]
436
427
  plan.row_builder.add_table_column(col, expr.slot_idx)
428
+ needs_cell_materialization = (
429
+ needs_cell_materialization or col.col_type.is_json_type() or col.col_type.is_array_type()
430
+ )
437
431
 
438
- stored_col_info = plan.row_builder.output_slot_idxs()
439
- stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
440
- plan.set_stored_img_cols(stored_img_col_info)
432
+ if needs_cell_materialization:
433
+ plan = exec.CellMaterializationNode(plan)
441
434
 
442
435
  plan.set_ctx(
443
436
  exec.ExecContext(
@@ -454,16 +447,18 @@ class Planner:
454
447
  tbl: catalog.TableVersionPath,
455
448
  update_targets: dict[catalog.Column, exprs.Expr],
456
449
  recompute_targets: list[catalog.Column],
457
- where_clause: Optional[exprs.Expr],
450
+ where_clause: exprs.Expr | None,
458
451
  cascade: bool,
459
452
  ) -> tuple[exec.ExecNode, list[str], list[catalog.Column]]:
460
453
  """Creates a plan to materialize updated rows.
454
+
461
455
  The plan:
462
456
  - retrieves rows that are visible at the current version of the table
463
457
  - materializes all stored columns and the update targets
464
458
  - if cascade is True, recomputes all computed columns that transitively depend on the updated columns
465
459
  and copies the values of all other stored columns
466
460
  - if cascade is False, copies all columns that aren't update targets from the original rows
461
+
467
462
  Returns:
468
463
  - root node of the plan
469
464
  - list of qualified column names that are getting updated
@@ -473,26 +468,33 @@ class Planner:
473
468
  assert isinstance(tbl, catalog.TableVersionPath)
474
469
  target = tbl.tbl_version.get() # the one we need to update
475
470
  updated_cols = list(update_targets.keys())
471
+ recomputed_cols: set[Column]
476
472
  if len(recompute_targets) > 0:
477
- recomputed_cols = set(recompute_targets)
473
+ assert len(update_targets) == 0
474
+ recomputed_cols = {*recompute_targets}
475
+ if cascade:
476
+ recomputed_cols |= target.get_dependent_columns(recomputed_cols)
478
477
  else:
479
478
  recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
480
- # regardless of cascade, we need to update all indices on any updated column
481
- idx_val_cols = target.get_idx_val_columns(updated_cols)
482
- recomputed_cols.update(idx_val_cols)
483
- # we only need to recompute stored columns (unstored ones are substituted away)
484
- recomputed_cols = {c for c in recomputed_cols if c.is_stored}
479
+ # regardless of cascade, we need to update all indices on any updated/recomputed column
480
+ modified_base_cols = [c for c in set(updated_cols) | recomputed_cols if c.get_tbl().id == target.id]
481
+ idx_val_cols = target.get_idx_val_columns(modified_base_cols)
482
+ recomputed_cols.update(idx_val_cols)
483
+ # we only need to recompute stored columns (unstored ones are substituted away)
484
+ recomputed_cols = {c for c in recomputed_cols if c.is_stored}
485
485
 
486
486
  cls.__check_valid_columns(tbl.tbl_version.get(), recomputed_cols, 'updated in')
487
487
 
488
- recomputed_base_cols = {col for col in recomputed_cols if col.tbl.id == tbl.tbl_version.id}
488
+ # our query plan
489
+ # - evaluates the update targets and recomputed columns
490
+ # - copies all other stored columns
491
+ recomputed_base_cols = {col for col in recomputed_cols if col.get_tbl().id == tbl.tbl_version.id}
489
492
  copied_cols = [
490
493
  col
491
494
  for col in target.cols_by_id.values()
492
495
  if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
493
496
  ]
494
- select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
495
- select_list.extend(update_targets.values())
497
+ select_list: list[exprs.Expr] = list(update_targets.values())
496
498
 
497
499
  recomputed_exprs = [
498
500
  c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols
@@ -503,13 +505,25 @@ class Planner:
503
505
  select_list.extend(recomputed_exprs)
504
506
 
505
507
  # we need to retrieve the PK columns of the existing rows
506
- plan = cls.create_query_plan(FromClause(tbls=[tbl]), select_list, where_clause=where_clause, ignore_errors=True)
507
- all_base_cols = copied_cols + updated_cols + list(recomputed_base_cols) # same order as select_list
508
+ plan = cls.create_query_plan(
509
+ FromClause(tbls=[tbl]),
510
+ select_list=select_list,
511
+ columns=copied_cols,
512
+ where_clause=where_clause,
513
+ ignore_errors=True,
514
+ )
515
+ evaluated_cols = updated_cols + list(recomputed_base_cols) # same order as select_list
508
516
  # update row builder with column information
509
- for i, col in enumerate(all_base_cols):
517
+ plan.row_builder.add_table_columns(copied_cols)
518
+ for i, col in enumerate(evaluated_cols):
510
519
  plan.row_builder.add_table_column(col, select_list[i].slot_idx)
520
+ plan.ctx.num_computed_exprs = len(recomputed_exprs)
521
+
522
+ plan = cls._add_cell_materialization_node(plan)
523
+ plan = cls._add_save_node(plan)
524
+
511
525
  recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
512
- return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
526
+ return plan, [f'{c.get_tbl().name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
513
527
 
514
528
  @classmethod
515
529
  def __check_valid_columns(
@@ -529,6 +543,79 @@ class Planner:
529
543
  .format(validation_error=col.value_expr.validation_error)
530
544
  )
531
545
 
546
+ @classmethod
547
+ def _cell_md_col_refs(cls, expr_list: Iterable[exprs.Expr]) -> list[exprs.ColumnRef]:
548
+ """Return list of ColumnRefs that need their cellmd values for reconstruction"""
549
+ json_col_refs = list(
550
+ exprs.Expr.list_subexprs(
551
+ expr_list,
552
+ expr_class=exprs.ColumnRef,
553
+ filter=lambda e: cast(exprs.ColumnRef, e).col.col_type.is_json_type(),
554
+ traverse_matches=False,
555
+ )
556
+ )
557
+
558
+ def needs_reconstruction(e: exprs.Expr) -> bool:
559
+ assert isinstance(e, exprs.ColumnRef)
560
+ # Vector-typed array columns are used for vector indexes, and are stored in the db
561
+ return e.col.col_type.is_array_type() and not isinstance(e.col.sa_col_type, pgvector.sqlalchemy.Vector)
562
+
563
+ array_col_refs = list(
564
+ exprs.Expr.list_subexprs(
565
+ expr_list, expr_class=exprs.ColumnRef, filter=needs_reconstruction, traverse_matches=False
566
+ )
567
+ )
568
+
569
+ return json_col_refs + array_col_refs
570
+
571
+ @classmethod
572
+ def _add_cell_materialization_node(cls, input: exec.ExecNode) -> exec.ExecNode:
573
+ # we need a CellMaterializationNode if any of the evaluated output columns are json or array-typed
574
+ has_target_cols = any(
575
+ col.col_type.is_json_type() or col.col_type.is_array_type()
576
+ for col, slot_idx in input.row_builder.table_columns.items()
577
+ if slot_idx is not None
578
+ )
579
+ if has_target_cols:
580
+ return exec.CellMaterializationNode(input)
581
+ else:
582
+ return input
583
+
584
+ @classmethod
585
+ def _add_cell_reconstruction_node(cls, expr_list: list[exprs.Expr], input: exec.ExecNode) -> exec.ExecNode:
586
+ """
587
+ Add a CellReconstructionNode, if required by any of the exprs in expr_list.
588
+
589
+ Cell reconstruction is required for
590
+ 1) all json-typed ColumnRefs that are not used as part of a JsonPath (the latter does its own reconstruction)
591
+ or as part of a ColumnPropertyRef
592
+ 2) all array-typed ColumnRefs that are not used as part of a ColumnPropertyRef
593
+ """
594
+
595
+ def json_filter(e: exprs.Expr) -> bool:
596
+ if isinstance(e, exprs.JsonPath):
597
+ return not e.is_relative_path() and isinstance(e.anchor, exprs.ColumnRef)
598
+ if isinstance(e, exprs.ColumnPropertyRef):
599
+ return e.col_ref.col.col_type.is_json_type()
600
+ return isinstance(e, exprs.ColumnRef) and e.col.col_type.is_json_type()
601
+
602
+ def array_filter(e: exprs.Expr) -> bool:
603
+ if isinstance(e, exprs.ColumnPropertyRef):
604
+ return e.col_ref.col.col_type.is_array_type()
605
+ if not isinstance(e, exprs.ColumnRef):
606
+ return False
607
+ # Vector-typed array columns are used for vector indexes, and are stored in the db
608
+ return e.col.col_type.is_array_type() and not isinstance(e.col.sa_col_type, pgvector.sqlalchemy.Vector)
609
+
610
+ json_candidates = list(exprs.Expr.list_subexprs(expr_list, filter=json_filter, traverse_matches=False))
611
+ json_refs = [e for e in json_candidates if isinstance(e, exprs.ColumnRef)]
612
+ array_candidates = list(exprs.Expr.list_subexprs(expr_list, filter=array_filter, traverse_matches=False))
613
+ array_refs = [e for e in array_candidates if isinstance(e, exprs.ColumnRef)]
614
+ if len(json_refs) > 0 or len(array_refs) > 0:
615
+ return exec.CellReconstructionNode(json_refs, array_refs, input.row_builder, input=input)
616
+ else:
617
+ return input
618
+
532
619
  @classmethod
533
620
  def create_batch_update_plan(
534
621
  cls,
@@ -547,8 +634,8 @@ class Planner:
547
634
  """
548
635
  assert isinstance(tbl, catalog.TableVersionPath)
549
636
  target = tbl.tbl_version.get() # the one we need to update
550
- sa_key_cols: list[sql.Column] = []
551
- key_vals: list[tuple] = []
637
+ sa_key_cols: list[sql.Column]
638
+ key_vals: list[tuple]
552
639
  if len(rowids) > 0:
553
640
  sa_key_cols = target.store_tbl.rowid_columns()
554
641
  key_vals = rowids
@@ -561,18 +648,18 @@ class Planner:
561
648
  updated_cols = batch[0].keys() - target.primary_key_columns()
562
649
  recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
563
650
  # regardless of cascade, we need to update all indices on any updated column
564
- idx_val_cols = target.get_idx_val_columns(updated_cols)
651
+ modified_base_cols = [c for c in set(updated_cols) | recomputed_cols if c.get_tbl().id == target.id]
652
+ idx_val_cols = target.get_idx_val_columns(modified_base_cols)
565
653
  recomputed_cols.update(idx_val_cols)
566
654
  # we only need to recompute stored columns (unstored ones are substituted away)
567
655
  recomputed_cols = {c for c in recomputed_cols if c.is_stored}
568
- recomputed_base_cols = {col for col in recomputed_cols if col.tbl.id == target.id}
656
+ recomputed_base_cols = {col for col in recomputed_cols if col.get_tbl().id == target.id}
569
657
  copied_cols = [
570
658
  col
571
659
  for col in target.cols_by_id.values()
572
660
  if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
573
661
  ]
574
- select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
575
- select_list.extend(exprs.ColumnRef(col) for col in updated_cols)
662
+ select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in updated_cols]
576
663
 
577
664
  recomputed_exprs = [
578
665
  c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols
@@ -588,25 +675,39 @@ class Planner:
588
675
  sql_exprs = list(
589
676
  exprs.Expr.list_subexprs(analyzer.all_exprs, filter=analyzer.sql_elements.contains, traverse_matches=False)
590
677
  )
591
- row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs)
678
+ row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs, target)
592
679
  analyzer.finalize(row_builder)
593
- sql_lookup_node = exec.SqlLookupNode(tbl, row_builder, sql_exprs, sa_key_cols, key_vals)
680
+
681
+ cell_md_col_refs = cls._cell_md_col_refs(sql_exprs)
682
+ sql_lookup_node = exec.SqlLookupNode(
683
+ tbl,
684
+ row_builder,
685
+ sql_exprs,
686
+ columns=copied_cols,
687
+ sa_key_cols=sa_key_cols,
688
+ key_vals=key_vals,
689
+ cell_md_col_refs=cell_md_col_refs,
690
+ )
594
691
  col_vals = [{col: row[col].val for col in updated_cols} for row in batch]
595
692
  row_update_node = exec.RowUpdateNode(tbl, key_vals, len(rowids) > 0, col_vals, row_builder, sql_lookup_node)
596
693
  plan: exec.ExecNode = row_update_node
597
694
  if not cls._is_contained_in(analyzer.select_list, sql_exprs):
598
695
  # we need an ExprEvalNode to evaluate the remaining output exprs
599
696
  plan = exec.ExprEvalNode(row_builder, analyzer.select_list, sql_exprs, input=plan)
697
+
600
698
  # update row builder with column information
601
- all_base_cols = copied_cols + list(updated_cols) + list(recomputed_base_cols) # same order as select_list
699
+ evaluated_cols = list(updated_cols) + list(recomputed_base_cols) # same order as select_list
602
700
  row_builder.set_slot_idxs(select_list, remove_duplicates=False)
603
- for i, col in enumerate(all_base_cols):
701
+ plan.row_builder.add_table_columns(copied_cols)
702
+ for i, col in enumerate(evaluated_cols):
604
703
  plan.row_builder.add_table_column(col, select_list[i].slot_idx)
605
-
606
- ctx = exec.ExecContext(row_builder)
607
- # we're returning everything to the user, so we might as well do it in a single batch
704
+ ctx = exec.ExecContext(row_builder, num_computed_exprs=len(recomputed_exprs))
705
+ # TODO: correct batch size?
608
706
  ctx.batch_size = 0
609
707
  plan.set_ctx(ctx)
708
+
709
+ plan = cls._add_cell_materialization_node(plan)
710
+ plan = cls._add_save_node(plan)
610
711
  recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
611
712
  return (
612
713
  plan,
@@ -656,13 +757,13 @@ class Planner:
656
757
  ignore_errors=True,
657
758
  exact_version_only=view.get_bases(),
658
759
  )
659
- for i, col in enumerate(copied_cols + list(recomputed_cols)): # same order as select_list
760
+ plan.ctx.num_computed_exprs = len(recomputed_exprs)
761
+ materialized_cols = copied_cols + list(recomputed_cols) # same order as select_list
762
+ for i, col in enumerate(materialized_cols):
660
763
  plan.row_builder.add_table_column(col, select_list[i].slot_idx)
661
- # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
662
- stored_img_col_info = [
663
- info for info in plan.row_builder.output_slot_idxs() if info.col.col_type.is_image_type()
664
- ]
665
- plan.set_stored_img_cols(stored_img_col_info)
764
+ plan = cls._add_cell_materialization_node(plan)
765
+ plan = cls._add_save_node(plan)
766
+
666
767
  return plan
667
768
 
668
769
  @classmethod
@@ -691,25 +792,13 @@ class Planner:
691
792
  # 2. for component views: iterator args
692
793
  iterator_args = [target.iterator_args] if target.iterator_args is not None else []
693
794
 
694
- # If this contains a sample specification, modify / create where, group_by, order_by, and limit clauses
695
795
  from_clause = FromClause(tbls=[view.base])
696
- where, group_by_clause, order_by_clause, limit, sample_clause = cls.create_sample_clauses(
697
- from_clause, target.sample_clause, target.predicate, None, [], None
698
- )
699
-
700
- # if we're propagating an insert, we only want to see those base rows that were created for the current version
701
796
  base_analyzer = Analyzer(
702
- from_clause,
703
- iterator_args,
704
- where_clause=where,
705
- group_by_clause=group_by_clause,
706
- order_by_clause=order_by_clause,
797
+ from_clause, iterator_args, where_clause=target.predicate, sample_clause=target.sample_clause
707
798
  )
708
- row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [])
709
-
710
- if target.sample_clause is not None and base_analyzer.filter is not None:
711
- raise excs.Error(f'Filter {base_analyzer.filter} not expressible in SQL')
799
+ row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [], target)
712
800
 
801
+ # if we're propagating an insert, we only want to see those base rows that were created for the current version
713
802
  # execution plan:
714
803
  # 1. materialize exprs computed from the base that are needed for stored view columns
715
804
  # 2. if it's an iterator view, expand the base rows into component rows
@@ -723,19 +812,13 @@ class Planner:
723
812
 
724
813
  # Create a new analyzer reflecting exactly what is required from the base table
725
814
  base_analyzer = Analyzer(
726
- from_clause,
727
- base_output_exprs,
728
- where_clause=where,
729
- group_by_clause=group_by_clause,
730
- order_by_clause=order_by_clause,
815
+ from_clause, base_output_exprs, where_clause=target.predicate, sample_clause=target.sample_clause
731
816
  )
732
817
  base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
733
818
  plan = cls._create_query_plan(
734
819
  row_builder=row_builder,
735
820
  analyzer=base_analyzer,
736
821
  eval_ctx=base_eval_ctx,
737
- limit=limit,
738
- sample_clause=sample_clause,
739
822
  with_pk=True,
740
823
  exact_version_only=view.get_bases() if propagates_insert else [],
741
824
  )
@@ -747,10 +830,12 @@ class Planner:
747
830
  row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan
748
831
  )
749
832
 
750
- stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
751
- plan.set_stored_img_cols(stored_img_col_info)
752
833
  exec_ctx.ignore_errors = True
753
834
  plan.set_ctx(exec_ctx)
835
+ if any(c.col_type.is_json_type() or c.col_type.is_array_type() for c in stored_cols):
836
+ plan = exec.CellMaterializationNode(plan)
837
+ plan = cls._add_save_node(plan)
838
+
754
839
  return plan, len(row_builder.default_eval_ctx.target_exprs)
755
840
 
756
841
  @classmethod
@@ -761,7 +846,7 @@ class Planner:
761
846
  raise excs.Error(f'Join predicate {join_clause.join_predicate} not expressible in SQL')
762
847
 
763
848
  @classmethod
764
- def _create_combined_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> Optional[OrderByClause]:
849
+ def _create_combined_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> OrderByClause | None:
765
850
  """Verify that the various ordering requirements don't conflict and return a combined ordering"""
766
851
  ob_clauses: list[OrderByClause] = [analyzer.order_by_clause.copy()]
767
852
 
@@ -795,22 +880,29 @@ class Planner:
795
880
  combined_ordering = combined
796
881
  return combined_ordering
797
882
 
883
@classmethod
def _add_save_node(cls, input_node: exec.ExecNode) -> exec.ExecNode:
    """
    Wrap 'input_node' in an ObjectStoreSaveNode, if needed.

    The node is only added when the row builder of 'input_node' reports at least one entry in
    media_output_col_info; otherwise 'input_node' is returned unchanged.
    """
    media_cols = input_node.row_builder.media_output_col_info
    if len(media_cols) > 0:
        return exec.ObjectStoreSaveNode(media_cols, input_node)
    # no media output columns: the plan doesn't need a save step
    return input_node
891
+
798
892
  @classmethod
799
893
  def _is_contained_in(cls, l1: Iterable[exprs.Expr], l2: Iterable[exprs.Expr]) -> bool:
800
894
  """Returns True if l1 is contained in l2"""
801
895
  return {e.id for e in l1} <= {e.id for e in l2}
802
896
 
803
897
  @classmethod
804
- def _insert_prefetch_node(
805
- cls, tbl_id: UUID, row_builder: exprs.RowBuilder, input_node: exec.ExecNode
898
+ def _add_prefetch_node(
899
+ cls, tbl_id: UUID, expressions: Iterable[exprs.Expr], input_node: exec.ExecNode
806
900
  ) -> exec.ExecNode:
807
- """Returns a CachePrefetchNode into the plan if needed, otherwise returns input"""
901
+ """Add a CachePrefetch node, if needed."""
808
902
  # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
809
903
  # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
810
904
  # aren't explicitly captured as dependencies
811
- media_col_refs = [
812
- e for e in list(row_builder.unique_exprs) if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()
813
- ]
905
+ media_col_refs = [e for e in expressions if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()]
814
906
  if len(media_col_refs) == 0:
815
907
  return input_node
816
908
  # we need to prefetch external files for media column types
@@ -818,101 +910,52 @@ class Planner:
818
910
  prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input_node)
819
911
  return prefetch_node
820
912
 
821
- @classmethod
822
- def create_sample_clauses(
823
- cls,
824
- from_clause: FromClause,
825
- sample_clause: SampleClause,
826
- where_clause: Optional[exprs.Expr],
827
- group_by_clause: Optional[list[exprs.Expr]],
828
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]],
829
- limit: Optional[exprs.Expr],
830
- ) -> SamplingClauses:
831
- """tuple[
832
- exprs.Expr,
833
- Optional[list[exprs.Expr]],
834
- Optional[list[tuple[exprs.Expr, bool]]],
835
- Optional[exprs.Expr],
836
- Optional[SampleClause],
837
- ]:"""
838
- """Construct clauses required for sampling under various conditions.
839
- If there is no sampling, then return the original clauses.
840
- If the sample is stratified, then return only the group by clause. The rest of the
841
- mechanism for stratified sampling is provided by the SampleSqlNode.
842
- If the sample is non-stratified, then rewrite the query to accommodate the supplied where clause,
843
- and provide the other clauses required for sampling
844
- """
845
-
846
- # If no sample clause, return the original clauses
847
- if sample_clause is None:
848
- return SamplingClauses(where_clause, group_by_clause, order_by_clause, limit, None)
849
-
850
- # If the sample clause is stratified, create a group by clause
851
- if sample_clause.is_stratified:
852
- group_by = sample_clause.stratify_exprs
853
- # Note that limit is not possible here
854
- return SamplingClauses(where_clause, group_by, order_by_clause, None, sample_clause)
855
-
856
- else:
857
- # If non-stratified sampling, construct a where clause, order_by, and limit clauses
858
- # Construct an expression for sorting rows and limiting row counts
859
- s_key = sample_key(
860
- exprs.Literal(sample_clause.seed), *cls.rowid_columns(from_clause._first_tbl.tbl_version)
861
- )
862
-
863
- # Construct a suitable where clause
864
- where = where_clause
865
- if sample_clause.fraction is not None:
866
- fraction_md5_hex = exprs.Expr.from_object(
867
- sample_clause.fraction_to_md5_hex(float(sample_clause.fraction))
868
- )
869
- f_where = s_key < fraction_md5_hex
870
- where = where & f_where if where is not None else f_where
871
-
872
- order_by: list[tuple[exprs.Expr, bool]] = [(s_key, True)]
873
- limit = exprs.Literal(sample_clause.n)
874
- # Note that group_by is not possible here
875
- return SamplingClauses(where, None, order_by, limit, None)
876
-
877
913
  @classmethod
878
914
  def create_query_plan(
879
915
  cls,
880
916
  from_clause: FromClause,
881
- select_list: Optional[list[exprs.Expr]] = None,
882
- where_clause: Optional[exprs.Expr] = None,
883
- group_by_clause: Optional[list[exprs.Expr]] = None,
884
- order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
885
- limit: Optional[exprs.Expr] = None,
886
- sample_clause: Optional[SampleClause] = None,
917
+ select_list: list[exprs.Expr] | None = None,
918
+ columns: list[catalog.Column] | None = None,
919
+ where_clause: exprs.Expr | None = None,
920
+ group_by_clause: list[exprs.Expr] | None = None,
921
+ order_by_clause: list[tuple[exprs.Expr, bool]] | None = None,
922
+ limit: exprs.Expr | None = None,
923
+ sample_clause: SampleClause | None = None,
887
924
  ignore_errors: bool = False,
888
- exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
925
+ exact_version_only: list[catalog.TableVersionHandle] | None = None,
889
926
  ) -> exec.ExecNode:
890
- """Return plan for executing a query.
927
+ """
928
+ Return plan for executing a query.
929
+
930
+ The plan:
931
+ - materializes the values of select_list exprs into their respective slots
932
+ - materializes cell values of 'columns' (and their cellmd, if applicable) into DataRow.cell_vals/cell_md
933
+
891
934
  Updates 'select_list' in place to make it executable.
892
935
  TODO: make exact_version_only a flag and use the versions from tbl
893
936
  """
894
937
  if select_list is None:
895
938
  select_list = []
939
+ if columns is None:
940
+ columns = []
896
941
  if order_by_clause is None:
897
942
  order_by_clause = []
898
943
  if exact_version_only is None:
899
944
  exact_version_only = []
900
945
 
901
- # Modify clauses to include sample clause
902
- where, group_by_clause, order_by_clause, limit, sample = cls.create_sample_clauses(
903
- from_clause, sample_clause, where_clause, group_by_clause, order_by_clause, limit
904
- )
905
-
906
946
  analyzer = Analyzer(
907
947
  from_clause,
908
948
  select_list,
909
- where_clause=where,
949
+ where_clause=where_clause,
910
950
  group_by_clause=group_by_clause,
911
951
  order_by_clause=order_by_clause,
952
+ sample_clause=sample_clause,
912
953
  )
913
- row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
914
- if sample_clause is not None and analyzer.filter is not None:
915
- raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
954
+ # If the from_clause has a single table, we can use it as the context table for the RowBuilder.
955
+ # Otherwise there is no context table, but that's ok, because the context table is only needed for
956
+ # table mutations, which can't happen during a join.
957
+ context_tbl = from_clause.tbls[0].tbl_version.get() if len(from_clause.tbls) == 1 else None
958
+ row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [], context_tbl)
916
959
 
917
960
  analyzer.finalize(row_builder)
918
961
  # select_list: we need to materialize everything that's been collected
@@ -922,8 +965,8 @@ class Planner:
922
965
  row_builder=row_builder,
923
966
  analyzer=analyzer,
924
967
  eval_ctx=eval_ctx,
968
+ columns=columns,
925
969
  limit=limit,
926
- sample_clause=sample,
927
970
  with_pk=True,
928
971
  exact_version_only=exact_version_only,
929
972
  )
@@ -938,10 +981,10 @@ class Planner:
938
981
  row_builder: exprs.RowBuilder,
939
982
  analyzer: Analyzer,
940
983
  eval_ctx: exprs.RowBuilder.EvalCtx,
941
- limit: Optional[exprs.Expr] = None,
942
- sample_clause: Optional[SampleClause] = None,
984
+ columns: list[catalog.Column] | None = None,
985
+ limit: exprs.Expr | None = None,
943
986
  with_pk: bool = False,
944
- exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
987
+ exact_version_only: list[catalog.TableVersionHandle] | None = None,
945
988
  ) -> exec.ExecNode:
946
989
  """
947
990
  Create plan to materialize eval_ctx.
@@ -951,6 +994,8 @@ class Planner:
951
994
  in the context of that table version (eg, if 'tbl' is a view, 'plan_target' might be the base)
952
995
  TODO: make exact_version_only a flag and use the versions from tbl
953
996
  """
997
+ if columns is None:
998
+ columns = []
954
999
  if exact_version_only is None:
955
1000
  exact_version_only = []
956
1001
  sql_elements = analyzer.sql_elements
@@ -958,6 +1003,7 @@ class Planner:
958
1003
  analyzer.window_fn_calls
959
1004
  )
960
1005
  ctx = exec.ExecContext(row_builder)
1006
+
961
1007
  combined_ordering = cls._create_combined_ordering(analyzer, verify_agg=is_python_agg)
962
1008
  cls._verify_join_clauses(analyzer)
963
1009
 
@@ -966,6 +1012,7 @@ class Planner:
966
1012
  # - join clause subexprs
967
1013
  # - subexprs of Where clause conjuncts that can't be run in SQL
968
1014
  # - all grouping exprs
1015
+ # - all stratify exprs
969
1016
  candidates = list(
970
1017
  exprs.Expr.list_subexprs(
971
1018
  analyzer.select_list,
@@ -980,10 +1027,12 @@ class Planner:
980
1027
  candidates.extend(
981
1028
  exprs.Expr.subexprs(analyzer.filter, filter=sql_elements.contains, traverse_matches=False)
982
1029
  )
983
- if analyzer.group_by_clause is not None:
984
- candidates.extend(
985
- exprs.Expr.list_subexprs(analyzer.group_by_clause, filter=sql_elements.contains, traverse_matches=False)
986
- )
1030
+ candidates.extend(
1031
+ exprs.Expr.list_subexprs(analyzer.grouping_exprs, filter=sql_elements.contains, traverse_matches=False)
1032
+ )
1033
+ candidates.extend(
1034
+ exprs.Expr.list_subexprs(analyzer.stratify_exprs, filter=sql_elements.contains, traverse_matches=False)
1035
+ )
987
1036
  # not isinstance(...): we don't want to materialize Literals via a Select
988
1037
  sql_exprs = exprs.ExprSet(e for e in candidates if not isinstance(e, exprs.Literal))
989
1038
 
@@ -1005,8 +1054,15 @@ class Planner:
1005
1054
  traverse_matches=False,
1006
1055
  )
1007
1056
  )
1057
+
1008
1058
  plan = exec.SqlScanNode(
1009
- tbl, row_builder, select_list=tbl_scan_exprs, set_pk=with_pk, exact_version_only=exact_version_only
1059
+ tbl,
1060
+ row_builder,
1061
+ select_list=tbl_scan_exprs,
1062
+ columns=[c for c in columns if c.get_tbl().id == tbl.tbl_id],
1063
+ set_pk=with_pk,
1064
+ cell_md_col_refs=cls._cell_md_col_refs(tbl_scan_exprs),
1065
+ exact_version_only=exact_version_only,
1010
1066
  )
1011
1067
  tbl_scan_plans.append(plan)
1012
1068
 
@@ -1028,7 +1084,17 @@ class Planner:
1028
1084
  # we need to order the input for window functions
1029
1085
  plan.set_order_by(analyzer.get_window_fn_ob_clause())
1030
1086
 
1031
- plan = cls._insert_prefetch_node(tbl.tbl_version.id, row_builder, plan)
1087
+ if analyzer.sample_clause is not None:
1088
+ plan = exec.SqlSampleNode(
1089
+ row_builder,
1090
+ input=plan,
1091
+ select_list=tbl_scan_exprs,
1092
+ sample_clause=analyzer.sample_clause,
1093
+ stratify_exprs=analyzer.stratify_exprs,
1094
+ )
1095
+
1096
+ plan = cls._add_prefetch_node(tbl.tbl_version.id, row_builder.unique_exprs, plan)
1097
+ plan = cls._add_cell_reconstruction_node(analyzer.all_exprs, plan)
1032
1098
 
1033
1099
  if analyzer.group_by_clause is not None:
1034
1100
  # we're doing grouping aggregation; the input of the AggregateNode are the grouping exprs plus the
@@ -1050,26 +1116,12 @@ class Planner:
1050
1116
  sql_elements.contains_all(analyzer.select_list)
1051
1117
  and sql_elements.contains_all(analyzer.grouping_exprs)
1052
1118
  and isinstance(plan, exec.SqlNode)
1053
- and plan.to_cte(keep_pk=(sample_clause is not None)) is not None
1119
+ and plan.to_cte() is not None
1054
1120
  ):
1055
- if sample_clause is not None:
1056
- plan = exec.SqlSampleNode(
1057
- row_builder,
1058
- input=plan,
1059
- select_list=analyzer.select_list,
1060
- stratify_exprs=analyzer.group_by_clause,
1061
- sample_clause=sample_clause,
1062
- )
1063
- else:
1064
- plan = exec.SqlAggregationNode(
1065
- row_builder,
1066
- input=plan,
1067
- select_list=analyzer.select_list,
1068
- group_by_items=analyzer.group_by_clause,
1069
- )
1121
+ plan = exec.SqlAggregationNode(
1122
+ row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause
1123
+ )
1070
1124
  else:
1071
- if sample_clause is not None:
1072
- raise excs.Error('Sample clause not supported with Python aggregation')
1073
1125
  input_sql_node = plan.get_node(exec.SqlNode)
1074
1126
  assert combined_ordering is not None
1075
1127
  input_sql_node.set_order_by(combined_ordering)
@@ -1086,6 +1138,7 @@ class Planner:
1086
1138
  if not agg_output.issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
1087
1139
  # we need an ExprEvalNode to evaluate the remaining output exprs
1088
1140
  plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
1141
+ plan = cls._add_save_node(plan)
1089
1142
  else:
1090
1143
  if not exprs.ExprSet(sql_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
1091
1144
  # we need an ExprEvalNode to evaluate the remaining output exprs
@@ -1119,26 +1172,24 @@ class Planner:
1119
1172
  return Analyzer(FromClause(tbls=[tbl]), [], where_clause=where_clause)
1120
1173
 
1121
1174
  @classmethod
1122
- def create_add_column_plan(
1123
- cls, tbl: catalog.TableVersionPath, col: catalog.Column
1124
- ) -> tuple[exec.ExecNode, Optional[int]]:
1175
+ def create_add_column_plan(cls, tbl: catalog.TableVersionPath, col: catalog.Column) -> exec.ExecNode:
1125
1176
  """Creates a plan for InsertableTable.add_column()
1126
1177
  Returns:
1127
1178
  plan: the plan to execute
1128
1179
  value_expr slot idx for the plan output (for computed cols)
1129
1180
  """
1130
1181
  assert isinstance(tbl, catalog.TableVersionPath)
1131
- row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
1182
+ row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[], tbl=tbl.tbl_version.get())
1132
1183
  analyzer = Analyzer(FromClause(tbls=[tbl]), row_builder.default_eval_ctx.target_exprs)
1133
1184
  plan = cls._create_query_plan(
1134
1185
  row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True
1135
1186
  )
1187
+
1136
1188
  plan.ctx.batch_size = 16
1137
1189
  plan.ctx.show_pbar = True
1138
1190
  plan.ctx.ignore_errors = True
1191
+ computed_exprs = row_builder.output_exprs - row_builder.input_exprs
1192
+ plan.ctx.num_computed_exprs = len(computed_exprs) # we are adding a computed column, so we need to evaluate it
1193
+ plan = cls._add_save_node(plan)
1139
1194
 
1140
- # we want to flush images
1141
- if col.is_computed and col.is_stored and col.col_type.is_image_type():
1142
- plan.set_stored_img_cols(row_builder.output_slot_idxs())
1143
- value_expr_slot_idx = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
1144
- return plan, value_expr_slot_idx
1195
+ return plan