pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. pixeltable/__init__.py +42 -8
  2. pixeltable/{dataframe.py → _query.py} +470 -206
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +5 -4
  5. pixeltable/catalog/catalog.py +1785 -432
  6. pixeltable/catalog/column.py +190 -113
  7. pixeltable/catalog/dir.py +2 -4
  8. pixeltable/catalog/globals.py +19 -46
  9. pixeltable/catalog/insertable_table.py +191 -98
  10. pixeltable/catalog/path.py +63 -23
  11. pixeltable/catalog/schema_object.py +11 -15
  12. pixeltable/catalog/table.py +843 -436
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +978 -657
  15. pixeltable/catalog/table_version_handle.py +72 -16
  16. pixeltable/catalog/table_version_path.py +112 -43
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +134 -90
  20. pixeltable/config.py +134 -22
  21. pixeltable/env.py +471 -157
  22. pixeltable/exceptions.py +6 -0
  23. pixeltable/exec/__init__.py +4 -1
  24. pixeltable/exec/aggregation_node.py +7 -8
  25. pixeltable/exec/cache_prefetch_node.py +83 -110
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +4 -3
  29. pixeltable/exec/data_row_batch.py +8 -65
  30. pixeltable/exec/exec_context.py +16 -4
  31. pixeltable/exec/exec_node.py +13 -36
  32. pixeltable/exec/expr_eval/evaluators.py +11 -7
  33. pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
  34. pixeltable/exec/expr_eval/globals.py +8 -5
  35. pixeltable/exec/expr_eval/row_buffer.py +1 -2
  36. pixeltable/exec/expr_eval/schedulers.py +106 -56
  37. pixeltable/exec/globals.py +35 -0
  38. pixeltable/exec/in_memory_data_node.py +19 -19
  39. pixeltable/exec/object_store_save_node.py +293 -0
  40. pixeltable/exec/row_update_node.py +16 -9
  41. pixeltable/exec/sql_node.py +351 -84
  42. pixeltable/exprs/__init__.py +1 -1
  43. pixeltable/exprs/arithmetic_expr.py +27 -22
  44. pixeltable/exprs/array_slice.py +3 -3
  45. pixeltable/exprs/column_property_ref.py +36 -23
  46. pixeltable/exprs/column_ref.py +213 -89
  47. pixeltable/exprs/comparison.py +5 -5
  48. pixeltable/exprs/compound_predicate.py +5 -4
  49. pixeltable/exprs/data_row.py +164 -54
  50. pixeltable/exprs/expr.py +70 -44
  51. pixeltable/exprs/expr_dict.py +3 -3
  52. pixeltable/exprs/expr_set.py +17 -10
  53. pixeltable/exprs/function_call.py +100 -40
  54. pixeltable/exprs/globals.py +2 -2
  55. pixeltable/exprs/in_predicate.py +4 -4
  56. pixeltable/exprs/inline_expr.py +18 -32
  57. pixeltable/exprs/is_null.py +7 -3
  58. pixeltable/exprs/json_mapper.py +8 -8
  59. pixeltable/exprs/json_path.py +56 -22
  60. pixeltable/exprs/literal.py +27 -5
  61. pixeltable/exprs/method_ref.py +2 -2
  62. pixeltable/exprs/object_ref.py +2 -2
  63. pixeltable/exprs/row_builder.py +167 -67
  64. pixeltable/exprs/rowid_ref.py +25 -10
  65. pixeltable/exprs/similarity_expr.py +58 -40
  66. pixeltable/exprs/sql_element_cache.py +4 -4
  67. pixeltable/exprs/string_op.py +5 -5
  68. pixeltable/exprs/type_cast.py +3 -5
  69. pixeltable/func/__init__.py +1 -0
  70. pixeltable/func/aggregate_function.py +8 -8
  71. pixeltable/func/callable_function.py +9 -9
  72. pixeltable/func/expr_template_function.py +17 -11
  73. pixeltable/func/function.py +18 -20
  74. pixeltable/func/function_registry.py +6 -7
  75. pixeltable/func/globals.py +2 -3
  76. pixeltable/func/mcp.py +74 -0
  77. pixeltable/func/query_template_function.py +29 -27
  78. pixeltable/func/signature.py +46 -19
  79. pixeltable/func/tools.py +31 -13
  80. pixeltable/func/udf.py +18 -20
  81. pixeltable/functions/__init__.py +16 -0
  82. pixeltable/functions/anthropic.py +123 -77
  83. pixeltable/functions/audio.py +147 -10
  84. pixeltable/functions/bedrock.py +13 -6
  85. pixeltable/functions/date.py +7 -4
  86. pixeltable/functions/deepseek.py +35 -43
  87. pixeltable/functions/document.py +81 -0
  88. pixeltable/functions/fal.py +76 -0
  89. pixeltable/functions/fireworks.py +11 -20
  90. pixeltable/functions/gemini.py +195 -39
  91. pixeltable/functions/globals.py +142 -14
  92. pixeltable/functions/groq.py +108 -0
  93. pixeltable/functions/huggingface.py +1056 -24
  94. pixeltable/functions/image.py +115 -57
  95. pixeltable/functions/json.py +1 -1
  96. pixeltable/functions/llama_cpp.py +28 -13
  97. pixeltable/functions/math.py +67 -5
  98. pixeltable/functions/mistralai.py +18 -55
  99. pixeltable/functions/net.py +70 -0
  100. pixeltable/functions/ollama.py +20 -13
  101. pixeltable/functions/openai.py +240 -226
  102. pixeltable/functions/openrouter.py +143 -0
  103. pixeltable/functions/replicate.py +4 -4
  104. pixeltable/functions/reve.py +250 -0
  105. pixeltable/functions/string.py +239 -69
  106. pixeltable/functions/timestamp.py +16 -16
  107. pixeltable/functions/together.py +24 -84
  108. pixeltable/functions/twelvelabs.py +188 -0
  109. pixeltable/functions/util.py +6 -1
  110. pixeltable/functions/uuid.py +30 -0
  111. pixeltable/functions/video.py +1515 -107
  112. pixeltable/functions/vision.py +8 -8
  113. pixeltable/functions/voyageai.py +289 -0
  114. pixeltable/functions/whisper.py +16 -8
  115. pixeltable/functions/whisperx.py +179 -0
  116. pixeltable/{ext/functions → functions}/yolox.py +2 -4
  117. pixeltable/globals.py +362 -115
  118. pixeltable/index/base.py +17 -21
  119. pixeltable/index/btree.py +28 -22
  120. pixeltable/index/embedding_index.py +100 -118
  121. pixeltable/io/__init__.py +4 -2
  122. pixeltable/io/datarows.py +8 -7
  123. pixeltable/io/external_store.py +56 -105
  124. pixeltable/io/fiftyone.py +13 -13
  125. pixeltable/io/globals.py +31 -30
  126. pixeltable/io/hf_datasets.py +61 -16
  127. pixeltable/io/label_studio.py +74 -70
  128. pixeltable/io/lancedb.py +3 -0
  129. pixeltable/io/pandas.py +21 -12
  130. pixeltable/io/parquet.py +25 -105
  131. pixeltable/io/table_data_conduit.py +250 -123
  132. pixeltable/io/utils.py +4 -4
  133. pixeltable/iterators/__init__.py +2 -1
  134. pixeltable/iterators/audio.py +26 -25
  135. pixeltable/iterators/base.py +9 -3
  136. pixeltable/iterators/document.py +112 -78
  137. pixeltable/iterators/image.py +12 -15
  138. pixeltable/iterators/string.py +11 -4
  139. pixeltable/iterators/video.py +523 -120
  140. pixeltable/metadata/__init__.py +14 -3
  141. pixeltable/metadata/converters/convert_13.py +2 -2
  142. pixeltable/metadata/converters/convert_18.py +2 -2
  143. pixeltable/metadata/converters/convert_19.py +2 -2
  144. pixeltable/metadata/converters/convert_20.py +2 -2
  145. pixeltable/metadata/converters/convert_21.py +2 -2
  146. pixeltable/metadata/converters/convert_22.py +2 -2
  147. pixeltable/metadata/converters/convert_24.py +2 -2
  148. pixeltable/metadata/converters/convert_25.py +2 -2
  149. pixeltable/metadata/converters/convert_26.py +2 -2
  150. pixeltable/metadata/converters/convert_29.py +4 -4
  151. pixeltable/metadata/converters/convert_30.py +34 -21
  152. pixeltable/metadata/converters/convert_34.py +2 -2
  153. pixeltable/metadata/converters/convert_35.py +9 -0
  154. pixeltable/metadata/converters/convert_36.py +38 -0
  155. pixeltable/metadata/converters/convert_37.py +15 -0
  156. pixeltable/metadata/converters/convert_38.py +39 -0
  157. pixeltable/metadata/converters/convert_39.py +124 -0
  158. pixeltable/metadata/converters/convert_40.py +73 -0
  159. pixeltable/metadata/converters/convert_41.py +12 -0
  160. pixeltable/metadata/converters/convert_42.py +9 -0
  161. pixeltable/metadata/converters/convert_43.py +44 -0
  162. pixeltable/metadata/converters/util.py +20 -31
  163. pixeltable/metadata/notes.py +9 -0
  164. pixeltable/metadata/schema.py +140 -53
  165. pixeltable/metadata/utils.py +74 -0
  166. pixeltable/mypy/__init__.py +3 -0
  167. pixeltable/mypy/mypy_plugin.py +123 -0
  168. pixeltable/plan.py +382 -115
  169. pixeltable/share/__init__.py +1 -1
  170. pixeltable/share/packager.py +547 -83
  171. pixeltable/share/protocol/__init__.py +33 -0
  172. pixeltable/share/protocol/common.py +165 -0
  173. pixeltable/share/protocol/operation_types.py +33 -0
  174. pixeltable/share/protocol/replica.py +119 -0
  175. pixeltable/share/publish.py +257 -59
  176. pixeltable/store.py +311 -194
  177. pixeltable/type_system.py +373 -211
  178. pixeltable/utils/__init__.py +2 -3
  179. pixeltable/utils/arrow.py +131 -17
  180. pixeltable/utils/av.py +298 -0
  181. pixeltable/utils/azure_store.py +346 -0
  182. pixeltable/utils/coco.py +6 -6
  183. pixeltable/utils/code.py +3 -3
  184. pixeltable/utils/console_output.py +4 -1
  185. pixeltable/utils/coroutine.py +6 -23
  186. pixeltable/utils/dbms.py +32 -6
  187. pixeltable/utils/description_helper.py +4 -5
  188. pixeltable/utils/documents.py +7 -18
  189. pixeltable/utils/exception_handler.py +7 -30
  190. pixeltable/utils/filecache.py +6 -6
  191. pixeltable/utils/formatter.py +86 -48
  192. pixeltable/utils/gcs_store.py +295 -0
  193. pixeltable/utils/http.py +133 -0
  194. pixeltable/utils/http_server.py +2 -3
  195. pixeltable/utils/iceberg.py +1 -2
  196. pixeltable/utils/image.py +17 -0
  197. pixeltable/utils/lancedb.py +90 -0
  198. pixeltable/utils/local_store.py +322 -0
  199. pixeltable/utils/misc.py +5 -0
  200. pixeltable/utils/object_stores.py +573 -0
  201. pixeltable/utils/pydantic.py +60 -0
  202. pixeltable/utils/pytorch.py +5 -6
  203. pixeltable/utils/s3_store.py +527 -0
  204. pixeltable/utils/sql.py +26 -0
  205. pixeltable/utils/system.py +30 -0
  206. pixeltable-0.5.7.dist-info/METADATA +579 -0
  207. pixeltable-0.5.7.dist-info/RECORD +227 -0
  208. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  209. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  210. pixeltable/__version__.py +0 -3
  211. pixeltable/catalog/named_function.py +0 -40
  212. pixeltable/ext/__init__.py +0 -17
  213. pixeltable/ext/functions/__init__.py +0 -11
  214. pixeltable/ext/functions/whisperx.py +0 -77
  215. pixeltable/utils/media_store.py +0 -77
  216. pixeltable/utils/s3.py +0 -17
  217. pixeltable-0.3.14.dist-info/METADATA +0 -434
  218. pixeltable-0.3.14.dist-info/RECORD +0 -186
  219. pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
  220. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/plan.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
 import dataclasses
 import enum
 from textwrap import dedent
-from typing import Any, Iterable, Literal, Optional, Sequence
+from typing import Any, Iterable, Literal, Sequence, cast
 from uuid import UUID
 
+import pgvector.sqlalchemy  # type: ignore[import-untyped]
 import sqlalchemy as sql
 
 import pixeltable as pxt
@@ -65,7 +66,7 @@ class JoinClause:
     """Corresponds to a single 'JOIN ... ON (...)' clause in a SELECT statement; excludes the joined table."""
 
     join_type: JoinType
-    join_predicate: Optional[exprs.Expr]  # None for join_type == CROSS
+    join_predicate: exprs.Expr | None  # None for join_type == CROSS
 
 
 @dataclasses.dataclass
@@ -75,6 +76,83 @@ class FromClause:
     tbls: list[catalog.TableVersionPath]
     join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)
 
+    @property
+    def _first_tbl(self) -> catalog.TableVersionPath:
+        assert len(self.tbls) == 1
+        return self.tbls[0]
+
+
+@dataclasses.dataclass
+class SampleClause:
+    """Defines a sampling clause for a table."""
+
+    version: int | None
+    n: int | None
+    n_per_stratum: int | None
+    fraction: float | None
+    seed: int | None
+    stratify_exprs: list[exprs.Expr] | None
+
+    # The version of the hashing algorithm used for ordering and fractional sampling.
+    CURRENT_VERSION = 1
+
+    def __post_init__(self) -> None:
+        # If no version was provided, provide the default version
+        if self.version is None:
+            self.version = self.CURRENT_VERSION
+
+    @property
+    def is_stratified(self) -> bool:
+        """Check if the sampling is stratified"""
+        return self.stratify_exprs is not None and len(self.stratify_exprs) > 0
+
+    @property
+    def is_repeatable(self) -> bool:
+        """Return true if the same rows will continue to be sampled if source rows are added or deleted."""
+        return not self.is_stratified and self.fraction is not None
+
+    def display_str(self, inline: bool = False) -> str:
+        return str(self)
+
+    def as_dict(self) -> dict:
+        """Return a dictionary representation of the object"""
+        d = dataclasses.asdict(self)
+        d['_classname'] = self.__class__.__name__
+        if self.is_stratified:
+            d['stratify_exprs'] = [e.as_dict() for e in self.stratify_exprs]
+        return d
+
+    @classmethod
+    def from_dict(cls, d: dict) -> SampleClause:
+        """Create a SampleClause from a dictionary representation"""
+        d_cleaned = {key: value for key, value in d.items() if key != '_classname'}
+        s = cls(**d_cleaned)
+        if s.is_stratified:
+            s.stratify_exprs = [exprs.Expr.from_dict(e) for e in d_cleaned.get('stratify_exprs', [])]
+        return s
+
+    def __repr__(self) -> str:
+        s = ','.join(e.display_str(inline=True) for e in self.stratify_exprs)
+        return (
+            f'sample_{self.version}(n={self.n}, n_per_stratum={self.n_per_stratum}, '
+            f'fraction={self.fraction}, seed={self.seed}, [{s}])'
+        )
+
+    @classmethod
+    def fraction_to_md5_hex(cls, fraction: float) -> str:
+        """Return the string representation of an approximation (to ~1e-9) of a fraction of the total space
+        of md5 hash values.
+        This is used for fractional sampling.
+        """
+        # Maximum count for the upper 32 bits of MD5: 2^32
+        max_md5_value = (2**32) - 1
+
+        # Calculate the fraction of this value
+        threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000
+
+        # Convert to hexadecimal string with padding
+        return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
+
 
 class Analyzer:
     """
@@ -84,17 +162,19 @@ class Analyzer:
     from_clause: FromClause
     all_exprs: list[exprs.Expr]  # union of all exprs, aside from sql_where_clause
     select_list: list[exprs.Expr]
-    group_by_clause: Optional[list[exprs.Expr]]  # None for non-aggregate queries; [] for agg query w/o grouping
+    group_by_clause: list[exprs.Expr] | None  # None for non-aggregate queries; [] for agg query w/o grouping
     grouping_exprs: list[exprs.Expr]  # [] for non-aggregate queries or agg query w/o grouping
     order_by_clause: OrderByClause
+    stratify_exprs: list[exprs.Expr]  # [] if no stratification is required
+    sample_clause: SampleClause | None  # None if no sampling clause is present
 
     sql_elements: exprs.SqlElementCache
 
     # Where clause of the Select stmt of the SQL scan
-    sql_where_clause: Optional[exprs.Expr]
+    sql_where_clause: exprs.Expr | None
 
     # filter predicate applied to output rows of the SQL scan
-    filter: Optional[exprs.Expr]
+    filter: exprs.Expr | None
 
     agg_fn_calls: list[exprs.FunctionCall]  # grouping aggregation (ie, not window functions)
     window_fn_calls: list[exprs.FunctionCall]
@@ -104,9 +184,10 @@
         self,
         from_clause: FromClause,
         select_list: Sequence[exprs.Expr],
-        where_clause: Optional[exprs.Expr] = None,
-        group_by_clause: Optional[list[exprs.Expr]] = None,
-        order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
+        where_clause: exprs.Expr | None = None,
+        group_by_clause: list[exprs.Expr] | None = None,
+        order_by_clause: list[tuple[exprs.Expr, bool]] | None = None,
+        sample_clause: SampleClause | None = None,
     ):
         if order_by_clause is None:
             order_by_clause = []
@@ -120,6 +201,11 @@
         self.group_by_clause = (
             [e.resolve_computed_cols() for e in group_by_clause] if group_by_clause is not None else None
         )
+        self.sample_clause = sample_clause
+        if self.sample_clause is not None and self.sample_clause.is_stratified:
+            self.stratify_exprs = [e.resolve_computed_cols() for e in sample_clause.stratify_exprs]
+        else:
+            self.stratify_exprs = []
         self.order_by_clause = [OrderByItem(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]
 
         self.sql_where_clause = None
@@ -135,8 +221,11 @@
                 self.all_exprs.append(join_clause.join_predicate)
         if self.group_by_clause is not None:
             self.all_exprs.extend(self.group_by_clause)
+        self.all_exprs.extend(self.stratify_exprs)
         self.all_exprs.extend(e for e, _ in self.order_by_clause)
         if self.filter is not None:
+            if sample_clause is not None:
+                raise excs.Error(f'Filter {self.filter} not expressible in SQL')
             self.all_exprs.append(self.filter)
 
         self.agg_order_by = []
@@ -241,7 +330,7 @@
         row_builder.set_slot_idxs(self.agg_fn_calls)
         row_builder.set_slot_idxs(self.agg_order_by)
 
-    def get_window_fn_ob_clause(self) -> Optional[OrderByClause]:
+    def get_window_fn_ob_clause(self) -> OrderByClause | None:
         clause: list[OrderByClause] = []
         for fn_call in self.window_fn_calls:
             # window functions require ordering by the group_by/order_by clauses
@@ -257,21 +346,19 @@
 
 
 class Planner:
-    # TODO: create an exec.CountNode and change this to create_count_plan()
     @classmethod
-    def create_count_stmt(cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None) -> sql.Select:
-        stmt = sql.select(sql.func.count())
-        refd_tbl_ids: set[UUID] = set()
-        if where_clause is not None:
-            analyzer = cls.analyze(tbl, where_clause)
-            if analyzer.filter is not None:
-                raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
-            clause_element = analyzer.sql_where_clause.sql_expr(analyzer.sql_elements)
-            assert clause_element is not None
-            stmt = stmt.where(clause_element)
-            refd_tbl_ids = where_clause.tbl_ids()
-        stmt = exec.SqlScanNode.create_from_clause(tbl, stmt, refd_tbl_ids)
-        return stmt
+    def create_count_stmt(cls, query: 'pxt.Query') -> sql.Select:
+        """Creates a SQL SELECT COUNT(*) statement for counting rows in a Query."""
+        # Create the query plan
+        plan = query._create_query_plan()
+        sql_node = plan.get_node(exec.SqlNode)
+        assert sql_node is not None
+        if sql_node.py_filter is not None:
+            raise excs.Error('count() cannot be used with Python-only filters. Use collect() instead.')
+        # Get the SQL statement from the SqlNode as a CTE
+        cte, _ = sql_node.to_cte(keep_pk=True)
+        count_stmt = sql.select(sql.func.count().label('all_count')).select_from(cte)
+        return count_stmt
 
     @classmethod
     def create_insert_plan(
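For context on the rewritten create_count_stmt(): instead of assembling its own WHERE clause, it now takes the query's SQL scan as a CTE and counts over it. A minimal standalone SQLAlchemy sketch of the same pattern (the table, columns, and filter here are hypothetical stand-ins for the SqlNode's SELECT):

    import sqlalchemy as sql

    t = sql.table('media', sql.column('id'), sql.column('width'))
    scan = sql.select(t.c.id).where(t.c.width > 512)  # stand-in for sql_node.to_cte()
    cte = scan.cte()
    count_stmt = sql.select(sql.func.count().label('all_count')).select_from(cte)
    # renders roughly as: SELECT count(*) AS all_count FROM (SELECT id FROM media WHERE width > 512) AS anon_1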
@@ -285,21 +372,12 @@
 
         cls.__check_valid_columns(tbl, stored_cols, 'inserted into')
 
-        row_builder = exprs.RowBuilder([], stored_cols, [])
+        row_builder = exprs.RowBuilder([], stored_cols, [], tbl)
 
         # create InMemoryDataNode for 'rows'
-        plan: exec.ExecNode = exec.InMemoryDataNode(
-            TableVersionHandle(tbl.id, tbl.effective_version), rows, row_builder, tbl.next_rowid
-        )
+        plan: exec.ExecNode = exec.InMemoryDataNode(tbl.handle, rows, row_builder, tbl.next_row_id)
 
-        media_input_col_info = [
-            exprs.ColumnSlotIdx(col_ref.col, col_ref.slot_idx)
-            for col_ref in row_builder.input_exprs
-            if isinstance(col_ref, exprs.ColumnRef) and col_ref.col_type.is_media_type()
-        ]
-        if len(media_input_col_info) > 0:
-            # prefetch external files for all input column refs
-            plan = exec.CachePrefetchNode(tbl.id, media_input_col_info, input=plan)
+        plan = cls._add_prefetch_node(tbl.id, row_builder.input_exprs, input_node=plan)
 
         computed_exprs = row_builder.output_exprs - row_builder.input_exprs
         if len(computed_exprs) > 0:
@@ -307,10 +385,9 @@
             plan = exec.ExprEvalNode(
                 row_builder, computed_exprs, plan.output_exprs, input=plan, maintain_input_order=False
             )
+        if any(c.col_type.supports_file_offloading() for c in stored_cols):
+            plan = exec.CellMaterializationNode(plan)
 
-        stored_col_info = row_builder.output_slot_idxs()
-        stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
         plan.set_ctx(
             exec.ExecContext(
                 row_builder,
@@ -320,24 +397,34 @@
                 ignore_errors=ignore_errors,
             )
         )
+        plan = cls._add_save_node(plan)
+
         return plan
 
     @classmethod
-    def create_df_insert_plan(
-        cls, tbl: catalog.TableVersion, df: 'pxt.DataFrame', ignore_errors: bool
+    def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: int | None = None) -> list[exprs.Expr]:
+        """Return list of RowidRef for the given number of associated rowids"""
+        if num_rowid_cols is None:
+            num_rowid_cols = target.get().num_rowid_columns()
+        return [exprs.RowidRef(target, i) for i in range(num_rowid_cols)]
+
+    @classmethod
+    def create_query_insert_plan(
+        cls, tbl: catalog.TableVersion, query: 'pxt.Query', ignore_errors: bool
     ) -> exec.ExecNode:
         assert not tbl.is_view
-        plan = df._create_query_plan()  # ExecNode constructed by the DataFrame
+        plan = query._create_query_plan()  # ExecNode constructed by the Query
 
         # Modify the plan RowBuilder to register the output columns
-        for col_name, expr in zip(df.schema.keys(), df._select_list_exprs):
+        needs_cell_materialization = False
+        for col_name, expr in zip(query.schema.keys(), query._select_list_exprs):
             assert col_name in tbl.cols_by_name
             col = tbl.cols_by_name[col_name]
             plan.row_builder.add_table_column(col, expr.slot_idx)
+            needs_cell_materialization = needs_cell_materialization or col.col_type.supports_file_offloading()
 
-        stored_col_info = plan.row_builder.output_slot_idxs()
-        stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
+        if needs_cell_materialization:
+            plan = exec.CellMaterializationNode(plan)
 
         plan.set_ctx(
             exec.ExecContext(
@@ -354,16 +441,18 @@
         tbl: catalog.TableVersionPath,
         update_targets: dict[catalog.Column, exprs.Expr],
         recompute_targets: list[catalog.Column],
-        where_clause: Optional[exprs.Expr],
+        where_clause: exprs.Expr | None,
         cascade: bool,
     ) -> tuple[exec.ExecNode, list[str], list[catalog.Column]]:
         """Creates a plan to materialize updated rows.
+
         The plan:
         - retrieves rows that are visible at the current version of the table
         - materializes all stored columns and the update targets
        - if cascade is True, recomputes all computed columns that transitively depend on the updated columns
           and copies the values of all other stored columns
        - if cascade is False, copies all columns that aren't update targets from the original rows
+
        Returns:
        - root node of the plan
        - list of qualified column names that are getting updated
@@ -373,26 +462,33 @@
         assert isinstance(tbl, catalog.TableVersionPath)
         target = tbl.tbl_version.get()  # the one we need to update
         updated_cols = list(update_targets.keys())
+        recomputed_cols: set[Column]
         if len(recompute_targets) > 0:
-            recomputed_cols = set(recompute_targets)
+            assert len(update_targets) == 0
+            recomputed_cols = {*recompute_targets}
+            if cascade:
+                recomputed_cols |= target.get_dependent_columns(recomputed_cols)
         else:
             recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
-        # regardless of cascade, we need to update all indices on any updated column
-        idx_val_cols = target.get_idx_val_columns(updated_cols)
-        recomputed_cols.update(idx_val_cols)
-        # we only need to recompute stored columns (unstored ones are substituted away)
-        recomputed_cols = {c for c in recomputed_cols if c.is_stored}
+        # regardless of cascade, we need to update all indices on any updated/recomputed column
+        modified_base_cols = [c for c in set(updated_cols) | recomputed_cols if c.get_tbl().id == target.id]
+        idx_val_cols = target.get_idx_val_columns(modified_base_cols)
+        recomputed_cols.update(idx_val_cols)
+        # we only need to recompute stored columns (unstored ones are substituted away)
+        recomputed_cols = {c for c in recomputed_cols if c.is_stored}
 
         cls.__check_valid_columns(tbl.tbl_version.get(), recomputed_cols, 'updated in')
 
-        recomputed_base_cols = {col for col in recomputed_cols if col.tbl == tbl.tbl_version}
+        # our query plan
+        # - evaluates the update targets and recomputed columns
+        # - copies all other stored columns
+        recomputed_base_cols = {col for col in recomputed_cols if col.get_tbl().id == tbl.tbl_version.id}
         copied_cols = [
             col
             for col in target.cols_by_id.values()
             if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
         ]
-        select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
-        select_list.extend(update_targets.values())
+        select_list: list[exprs.Expr] = list(update_targets.values())
 
         recomputed_exprs = [
             c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols
@@ -403,13 +499,25 @@
         select_list.extend(recomputed_exprs)
 
         # we need to retrieve the PK columns of the existing rows
-        plan = cls.create_query_plan(FromClause(tbls=[tbl]), select_list, where_clause=where_clause, ignore_errors=True)
-        all_base_cols = copied_cols + updated_cols + list(recomputed_base_cols)  # same order as select_list
+        plan = cls.create_query_plan(
+            FromClause(tbls=[tbl]),
+            select_list=select_list,
+            columns=copied_cols,
+            where_clause=where_clause,
+            ignore_errors=True,
+        )
+        evaluated_cols = updated_cols + list(recomputed_base_cols)  # same order as select_list
         # update row builder with column information
-        for i, col in enumerate(all_base_cols):
+        plan.row_builder.add_table_columns(copied_cols)
+        for i, col in enumerate(evaluated_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
+        plan.ctx.num_computed_exprs = len(recomputed_exprs)
+
+        plan = cls._add_cell_materialization_node(plan)
+        plan = cls._add_save_node(plan)
+
         recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
-        return plan, [f'{c.tbl.get().name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
+        return plan, [f'{c.get_tbl().name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
 
     @classmethod
     def __check_valid_columns(
@@ -429,6 +537,94 @@
                 .format(validation_error=col.value_expr.validation_error)
             )
 
+    @classmethod
+    def _cell_md_col_refs(cls, expr_list: Iterable[exprs.Expr]) -> list[exprs.ColumnRef]:
+        """Return list of ColumnRefs that need their cellmd values for reconstruction"""
+        json_col_refs = list(
+            exprs.Expr.list_subexprs(
+                expr_list,
+                expr_class=exprs.ColumnRef,
+                filter=lambda e: cast(exprs.ColumnRef, e).col.col_type.is_json_type(),
+                traverse_matches=False,
+            )
+        )
+
+        def needs_reconstruction(e: exprs.Expr) -> bool:
+            assert isinstance(e, exprs.ColumnRef)
+            # Vector-typed array columns are used for vector indexes, and are stored in the db
+            return e.col.col_type.is_array_type() and not isinstance(e.col.sa_col_type, pgvector.sqlalchemy.Vector)
+
+        array_col_refs = list(
+            exprs.Expr.list_subexprs(
+                expr_list, expr_class=exprs.ColumnRef, filter=needs_reconstruction, traverse_matches=False
+            )
+        )
+
+        binary_col_refs = list(
+            exprs.Expr.list_subexprs(
+                expr_list,
+                expr_class=exprs.ColumnRef,
+                filter=lambda e: cast(exprs.ColumnRef, e).col.col_type.is_binary_type(),
+                traverse_matches=False,
+            )
+        )
+
+        return json_col_refs + array_col_refs + binary_col_refs
+
+    @classmethod
+    def _add_cell_materialization_node(cls, input: exec.ExecNode) -> exec.ExecNode:
+        # we need a CellMaterializationNode if any of the evaluated output columns are json or array-typed
+        has_target_cols = any(
+            col.col_type.supports_file_offloading()
+            for col, slot_idx in input.row_builder.table_columns.items()
+            if slot_idx is not None
+        )
+        if has_target_cols:
+            return exec.CellMaterializationNode(input)
+        else:
+            return input
+
+    @classmethod
+    def _add_cell_reconstruction_node(cls, expr_list: list[exprs.Expr], input: exec.ExecNode) -> exec.ExecNode:
+        """
+        Add a CellReconstructionNode, if required by any of the exprs in expr_list.
+
+        Cell reconstruction is required for
+        1) all json-typed ColumnRefs that are not used as part of a JsonPath (the latter does its own reconstruction)
+           or as part of a ColumnPropertyRef
+        2) all array-typed ColumnRefs that are not used as part of a ColumnPropertyRef
+        """
+
+        def json_filter(e: exprs.Expr) -> bool:
+            if isinstance(e, exprs.JsonPath):
+                return not e.is_relative_path() and isinstance(e.anchor, exprs.ColumnRef)
+            if isinstance(e, exprs.ColumnPropertyRef):
+                return e.col_ref.col.col_type.is_json_type()
+            return isinstance(e, exprs.ColumnRef) and e.col.col_type.is_json_type()
+
+        def array_filter(e: exprs.Expr) -> bool:
+            if isinstance(e, exprs.ColumnPropertyRef):
+                return e.col_ref.col.col_type.is_array_type()
+            if not isinstance(e, exprs.ColumnRef):
+                return False
+            # Vector-typed array columns are used for vector indexes, and are stored in the db
+            return e.col.col_type.is_array_type() and not isinstance(e.col.sa_col_type, pgvector.sqlalchemy.Vector)
+
+        def binary_filter(e: exprs.Expr) -> bool:
+            return isinstance(e, exprs.ColumnRef) and e.col.col_type.is_binary_type()
+
+        json_candidates = list(exprs.Expr.list_subexprs(expr_list, filter=json_filter, traverse_matches=False))
+        json_refs = [e for e in json_candidates if isinstance(e, exprs.ColumnRef)]
+        array_candidates = list(exprs.Expr.list_subexprs(expr_list, filter=array_filter, traverse_matches=False))
+        array_refs = [e for e in array_candidates if isinstance(e, exprs.ColumnRef)]
+        binary_refs = list(
+            exprs.Expr.list_subexprs(expr_list, exprs.ColumnRef, filter=binary_filter, traverse_matches=False)
+        )
+        if len(json_refs) > 0 or len(array_refs) > 0 or len(binary_refs) > 0:
+            return exec.CellReconstructionNode(json_refs, array_refs, binary_refs, input.row_builder, input=input)
+        else:
+            return input
+
     @classmethod
     def create_batch_update_plan(
         cls,
@@ -447,8 +643,8 @@
         """
         assert isinstance(tbl, catalog.TableVersionPath)
         target = tbl.tbl_version.get()  # the one we need to update
-        sa_key_cols: list[sql.Column] = []
-        key_vals: list[tuple] = []
+        sa_key_cols: list[sql.Column]
+        key_vals: list[tuple]
         if len(rowids) > 0:
             sa_key_cols = target.store_tbl.rowid_columns()
             key_vals = rowids
@@ -461,18 +657,18 @@
         updated_cols = batch[0].keys() - target.primary_key_columns()
         recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
         # regardless of cascade, we need to update all indices on any updated column
-        idx_val_cols = target.get_idx_val_columns(updated_cols)
+        modified_base_cols = [c for c in set(updated_cols) | recomputed_cols if c.get_tbl().id == target.id]
+        idx_val_cols = target.get_idx_val_columns(modified_base_cols)
         recomputed_cols.update(idx_val_cols)
         # we only need to recompute stored columns (unstored ones are substituted away)
         recomputed_cols = {c for c in recomputed_cols if c.is_stored}
-        recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
+        recomputed_base_cols = {col for col in recomputed_cols if col.get_tbl().id == target.id}
         copied_cols = [
             col
             for col in target.cols_by_id.values()
             if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
         ]
-        select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
-        select_list.extend(exprs.ColumnRef(col) for col in updated_cols)
+        select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in updated_cols]
 
         recomputed_exprs = [
             c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols
@@ -488,25 +684,39 @@
         sql_exprs = list(
             exprs.Expr.list_subexprs(analyzer.all_exprs, filter=analyzer.sql_elements.contains, traverse_matches=False)
         )
-        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs)
+        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs, target)
         analyzer.finalize(row_builder)
-        sql_lookup_node = exec.SqlLookupNode(tbl, row_builder, sql_exprs, sa_key_cols, key_vals)
+
+        cell_md_col_refs = cls._cell_md_col_refs(sql_exprs)
+        sql_lookup_node = exec.SqlLookupNode(
+            tbl,
+            row_builder,
+            sql_exprs,
+            columns=copied_cols,
+            sa_key_cols=sa_key_cols,
+            key_vals=key_vals,
+            cell_md_col_refs=cell_md_col_refs,
+        )
         col_vals = [{col: row[col].val for col in updated_cols} for row in batch]
         row_update_node = exec.RowUpdateNode(tbl, key_vals, len(rowids) > 0, col_vals, row_builder, sql_lookup_node)
         plan: exec.ExecNode = row_update_node
         if not cls._is_contained_in(analyzer.select_list, sql_exprs):
             # we need an ExprEvalNode to evaluate the remaining output exprs
             plan = exec.ExprEvalNode(row_builder, analyzer.select_list, sql_exprs, input=plan)
+
         # update row builder with column information
-        all_base_cols = copied_cols + list(updated_cols) + list(recomputed_base_cols)  # same order as select_list
+        evaluated_cols = list(updated_cols) + list(recomputed_base_cols)  # same order as select_list
         row_builder.set_slot_idxs(select_list, remove_duplicates=False)
-        for i, col in enumerate(all_base_cols):
+        plan.row_builder.add_table_columns(copied_cols)
+        for i, col in enumerate(evaluated_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
-
-        ctx = exec.ExecContext(row_builder)
-        # we're returning everything to the user, so we might as well do it in a single batch
+        ctx = exec.ExecContext(row_builder, num_computed_exprs=len(recomputed_exprs))
+        # TODO: correct batch size?
         ctx.batch_size = 0
         plan.set_ctx(ctx)
+
+        plan = cls._add_cell_materialization_node(plan)
+        plan = cls._add_save_node(plan)
         recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
         return (
             plan,
@@ -556,13 +766,13 @@
             ignore_errors=True,
             exact_version_only=view.get_bases(),
         )
-        for i, col in enumerate(copied_cols + list(recomputed_cols)):  # same order as select_list
+        plan.ctx.num_computed_exprs = len(recomputed_exprs)
+        materialized_cols = copied_cols + list(recomputed_cols)  # same order as select_list
+        for i, col in enumerate(materialized_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
-        # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
-        stored_img_col_info = [
-            info for info in plan.row_builder.output_slot_idxs() if info.col.col_type.is_image_type()
-        ]
-        plan.set_stored_img_cols(stored_img_col_info)
+        plan = cls._add_cell_materialization_node(plan)
+        plan = cls._add_save_node(plan)
+
         return plan
 
     @classmethod
@@ -591,8 +801,13 @@
         # 2. for component views: iterator args
         iterator_args = [target.iterator_args] if target.iterator_args is not None else []
 
-        row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
+        from_clause = FromClause(tbls=[view.base])
+        base_analyzer = Analyzer(
+            from_clause, iterator_args, where_clause=target.predicate, sample_clause=target.sample_clause
+        )
+        row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [], target, for_view_load=True)
 
+        # if we're propagating an insert, we only want to see those base rows that were created for the current version
         # execution plan:
         # 1. materialize exprs computed from the base that are needed for stored view columns
         # 2. if it's an iterator view, expand the base rows into component rows
@@ -603,8 +818,11 @@
             for e in row_builder.default_eval_ctx.target_exprs
             if e.is_bound_by([view]) and not e.is_bound_by([view.base])
         ]
-        # if we're propagating an insert, we only want to see those base rows that were created for the current version
-        base_analyzer = Analyzer(FromClause(tbls=[view.base]), base_output_exprs, where_clause=target.predicate)
+
+        # Create a new analyzer reflecting exactly what is required from the base table
+        base_analyzer = Analyzer(
+            from_clause, base_output_exprs, where_clause=target.predicate, sample_clause=target.sample_clause
+        )
         base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
         plan = cls._create_query_plan(
             row_builder=row_builder,
@@ -621,10 +839,12 @@
             row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan
         )
 
-        stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
         exec_ctx.ignore_errors = True
         plan.set_ctx(exec_ctx)
+        if any(c.col_type.supports_file_offloading() for c in stored_cols):
+            plan = exec.CellMaterializationNode(plan)
+        plan = cls._add_save_node(plan)
+
        return plan, len(row_builder.default_eval_ctx.target_exprs)
 
     @classmethod
@@ -635,7 +855,7 @@
             raise excs.Error(f'Join predicate {join_clause.join_predicate} not expressible in SQL')
 
     @classmethod
-    def _create_combined_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> Optional[OrderByClause]:
+    def _create_combined_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> OrderByClause | None:
         """Verify that the various ordering requirements don't conflict and return a combined ordering"""
         ob_clauses: list[OrderByClause] = [analyzer.order_by_clause.copy()]
@@ -669,22 +889,29 @@
             combined_ordering = combined
         return combined_ordering
 
+    @classmethod
+    def _add_save_node(cls, input_node: exec.ExecNode) -> exec.ExecNode:
+        """Add an ObjectStoreSaveNode, if needed."""
+        media_col_info = input_node.row_builder.media_output_col_info
+        if len(media_col_info) == 0:
+            return input_node
+        else:
+            return exec.ObjectStoreSaveNode(media_col_info, input_node)
+
     @classmethod
     def _is_contained_in(cls, l1: Iterable[exprs.Expr], l2: Iterable[exprs.Expr]) -> bool:
         """Returns True if l1 is contained in l2"""
         return {e.id for e in l1} <= {e.id for e in l2}
 
     @classmethod
-    def _insert_prefetch_node(
-        cls, tbl_id: UUID, row_builder: exprs.RowBuilder, input_node: exec.ExecNode
+    def _add_prefetch_node(
+        cls, tbl_id: UUID, expressions: Iterable[exprs.Expr], input_node: exec.ExecNode
     ) -> exec.ExecNode:
-        """Returns a CachePrefetchNode into the plan if needed, otherwise returns input"""
+        """Add a CachePrefetch node, if needed."""
         # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
         # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
         # aren't explicitly captured as dependencies
-        media_col_refs = [
-            e for e in list(row_builder.unique_exprs) if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()
-        ]
+        media_col_refs = [e for e in expressions if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()]
         if len(media_col_refs) == 0:
             return input_node
         # we need to prefetch external files for media column types
@@ -696,32 +923,48 @@
     def create_query_plan(
         cls,
         from_clause: FromClause,
-        select_list: Optional[list[exprs.Expr]] = None,
-        where_clause: Optional[exprs.Expr] = None,
-        group_by_clause: Optional[list[exprs.Expr]] = None,
-        order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None,
-        limit: Optional[exprs.Expr] = None,
+        select_list: list[exprs.Expr] | None = None,
+        columns: list[catalog.Column] | None = None,
+        where_clause: exprs.Expr | None = None,
+        group_by_clause: list[exprs.Expr] | None = None,
+        order_by_clause: list[tuple[exprs.Expr, bool]] | None = None,
+        limit: exprs.Expr | None = None,
+        sample_clause: SampleClause | None = None,
         ignore_errors: bool = False,
-        exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
+        exact_version_only: list[catalog.TableVersionHandle] | None = None,
     ) -> exec.ExecNode:
-        """Return plan for executing a query.
+        """
+        Return plan for executing a query.
+
+        The plan:
+        - materializes the values of select_list exprs into their respective slots
+        - materializes cell values of 'columns' (and their cellmd, if applicable) into DataRow.cell_vals/cell_md
+
         Updates 'select_list' in place to make it executable.
        TODO: make exact_version_only a flag and use the versions from tbl
        """
        if select_list is None:
            select_list = []
+        if columns is None:
+            columns = []
        if order_by_clause is None:
            order_by_clause = []
        if exact_version_only is None:
            exact_version_only = []
+
        analyzer = Analyzer(
            from_clause,
            select_list,
            where_clause=where_clause,
            group_by_clause=group_by_clause,
            order_by_clause=order_by_clause,
+            sample_clause=sample_clause,
        )
-        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
+        # If the from_clause has a single table, we can use it as the context table for the RowBuilder.
+        # Otherwise there is no context table, but that's ok, because the context table is only needed for
+        # table mutations, which can't happen during a join.
+        context_tbl = from_clause.tbls[0].tbl_version.get() if len(from_clause.tbls) == 1 else None
+        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [], context_tbl)
 
         analyzer.finalize(row_builder)
         # select_list: we need to materialize everything that's been collected
@@ -731,6 +974,7 @@
             row_builder=row_builder,
             analyzer=analyzer,
             eval_ctx=eval_ctx,
+            columns=columns,
             limit=limit,
             with_pk=True,
             exact_version_only=exact_version_only,
@@ -746,9 +990,10 @@
         row_builder: exprs.RowBuilder,
         analyzer: Analyzer,
         eval_ctx: exprs.RowBuilder.EvalCtx,
-        limit: Optional[exprs.Expr] = None,
+        columns: list[catalog.Column] | None = None,
+        limit: exprs.Expr | None = None,
         with_pk: bool = False,
-        exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
+        exact_version_only: list[catalog.TableVersionHandle] | None = None,
     ) -> exec.ExecNode:
         """
         Create plan to materialize eval_ctx.
@@ -758,6 +1003,8 @@
         in the context of that table version (eg, if 'tbl' is a view, 'plan_target' might be the base)
         TODO: make exact_version_only a flag and use the versions from tbl
         """
+        if columns is None:
+            columns = []
         if exact_version_only is None:
             exact_version_only = []
         sql_elements = analyzer.sql_elements
@@ -765,6 +1012,7 @@
             analyzer.window_fn_calls
         )
         ctx = exec.ExecContext(row_builder)
+
         combined_ordering = cls._create_combined_ordering(analyzer, verify_agg=is_python_agg)
         cls._verify_join_clauses(analyzer)
 
@@ -773,6 +1021,7 @@
         # - join clause subexprs
         # - subexprs of Where clause conjuncts that can't be run in SQL
         # - all grouping exprs
+        # - all stratify exprs
         candidates = list(
             exprs.Expr.list_subexprs(
                 analyzer.select_list,
@@ -787,10 +1036,12 @@
             candidates.extend(
                 exprs.Expr.subexprs(analyzer.filter, filter=sql_elements.contains, traverse_matches=False)
             )
-        if analyzer.group_by_clause is not None:
-            candidates.extend(
-                exprs.Expr.list_subexprs(analyzer.group_by_clause, filter=sql_elements.contains, traverse_matches=False)
-            )
+        candidates.extend(
+            exprs.Expr.list_subexprs(analyzer.grouping_exprs, filter=sql_elements.contains, traverse_matches=False)
+        )
+        candidates.extend(
+            exprs.Expr.list_subexprs(analyzer.stratify_exprs, filter=sql_elements.contains, traverse_matches=False)
+        )
         # not isinstance(...): we don't want to materialize Literals via a Select
         sql_exprs = exprs.ExprSet(e for e in candidates if not isinstance(e, exprs.Literal))
 
@@ -812,8 +1063,15 @@
                     traverse_matches=False,
                 )
             )
+
             plan = exec.SqlScanNode(
-                tbl, row_builder, select_list=tbl_scan_exprs, set_pk=with_pk, exact_version_only=exact_version_only
+                tbl,
+                row_builder,
+                select_list=tbl_scan_exprs,
+                columns=[c for c in columns if c.get_tbl().id == tbl.tbl_id],
+                set_pk=with_pk,
+                cell_md_col_refs=cls._cell_md_col_refs(tbl_scan_exprs),
+                exact_version_only=exact_version_only,
             )
             tbl_scan_plans.append(plan)
 
@@ -835,7 +1093,17 @@
             # we need to order the input for window functions
             plan.set_order_by(analyzer.get_window_fn_ob_clause())
 
-        plan = cls._insert_prefetch_node(tbl.tbl_version.id, row_builder, plan)
+        if analyzer.sample_clause is not None:
+            plan = exec.SqlSampleNode(
+                row_builder,
+                input=plan,
+                select_list=tbl_scan_exprs,
+                sample_clause=analyzer.sample_clause,
+                stratify_exprs=analyzer.stratify_exprs,
+            )
+
+        plan = cls._add_prefetch_node(tbl.tbl_version.id, row_builder.unique_exprs, plan)
+        plan = cls._add_cell_reconstruction_node(analyzer.all_exprs, plan)
 
         if analyzer.group_by_clause is not None:
             # we're doing grouping aggregation; the input of the AggregateNode are the grouping exprs plus the
@@ -879,6 +1147,7 @@
             if not agg_output.issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
                 # we need an ExprEvalNode to evaluate the remaining output exprs
                 plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
+            plan = cls._add_save_node(plan)
         else:
             if not exprs.ExprSet(sql_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
                 # we need an ExprEvalNode to evaluate the remaining output exprs
@@ -912,26 +1181,24 @@
         return Analyzer(FromClause(tbls=[tbl]), [], where_clause=where_clause)
 
     @classmethod
-    def create_add_column_plan(
-        cls, tbl: catalog.TableVersionPath, col: catalog.Column
-    ) -> tuple[exec.ExecNode, Optional[int]]:
+    def create_add_column_plan(cls, tbl: catalog.TableVersionPath, col: catalog.Column) -> exec.ExecNode:
         """Creates a plan for InsertableTable.add_column()
         Returns:
             plan: the plan to execute
             value_expr slot idx for the plan output (for computed cols)
         """
         assert isinstance(tbl, catalog.TableVersionPath)
-        row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
+        row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[], tbl=tbl.tbl_version.get())
         analyzer = Analyzer(FromClause(tbls=[tbl]), row_builder.default_eval_ctx.target_exprs)
         plan = cls._create_query_plan(
             row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True
         )
+
         plan.ctx.batch_size = 16
         plan.ctx.show_pbar = True
         plan.ctx.ignore_errors = True
+        computed_exprs = row_builder.output_exprs - row_builder.input_exprs
+        plan.ctx.num_computed_exprs = len(computed_exprs)  # we are adding a computed column, so we need to evaluate it
+        plan = cls._add_save_node(plan)
 
-        # we want to flush images
-        if col.is_computed and col.is_stored and col.col_type.is_image_type():
-            plan.set_stored_img_cols(row_builder.output_slot_idxs())
-        value_expr_slot_idx = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
-        return plan, value_expr_slot_idx
+        return plan
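Usage note on the new sampling metadata: a fraction-based, non-stratified SampleClause is the 'repeatable' case, and it round-trips through as_dict()/from_dict(). A small sketch using only what this diff defines (import path assumed as above):

    from pixeltable.plan import SampleClause

    clause = SampleClause(version=None, n=None, n_per_stratum=None, fraction=0.1, seed=42, stratify_exprs=None)
    assert clause.version == SampleClause.CURRENT_VERSION  # __post_init__ fills in the default version
    assert not clause.is_stratified and clause.is_repeatable
    restored = SampleClause.from_dict(clause.as_dict())  # '_classname' is stripped on deserialization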