pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
pixeltable/plan.py CHANGED
@@ -2,17 +2,17 @@ from __future__ import annotations
 
 import dataclasses
 import enum
-from typing import Any, Iterable, Optional, Sequence, Literal
+from textwrap import dedent
+from typing import Any, Iterable, Literal, Sequence, cast
 from uuid import UUID
 
+import pgvector.sqlalchemy  # type: ignore[import-untyped]
 import sqlalchemy as sql
 
 import pixeltable as pxt
-import pixeltable.exec as exec
-from pixeltable import catalog
-from pixeltable import exceptions as excs
-from pixeltable import exprs
-from pixeltable.exec.sql_node import OrderByItem, OrderByClause, combine_order_by_clauses, print_order_by_clause
+from pixeltable import catalog, exceptions as excs, exec, exprs
+from pixeltable.catalog import Column, TableVersionHandle
+from pixeltable.exec.sql_node import OrderByClause, OrderByItem, combine_order_by_clauses, print_order_by_clause
 
 
 def _is_agg_fn_call(e: exprs.Expr) -> bool:
@@ -20,7 +20,7 @@ def _is_agg_fn_call(e: exprs.Expr) -> bool:
 
 
 def _get_combined_ordering(
-        o1: list[tuple[exprs.Expr, bool]], o2: list[tuple[exprs.Expr, bool]]
+    o1: list[tuple[exprs.Expr, bool]], o2: list[tuple[exprs.Expr, bool]]
 ) -> list[tuple[exprs.Expr, bool]]:
     """Returns an ordering that's compatible with both o1 and o2, or an empty list if no such ordering exists"""
     result: list[tuple[exprs.Expr, bool]] = []
@@ -56,24 +56,103 @@ class JoinType(enum.Enum):
     def validated(cls, name: str, error_prefix: str) -> JoinType:
         try:
             return cls[name.upper()]
-        except KeyError:
-            val_strs = ', '.join(f'{s.lower()!r}' for s in cls.__members__.keys())
-            raise excs.Error(f'{error_prefix} must be one of: [{val_strs}]')
+        except KeyError as exc:
+            val_strs = ', '.join(f'{s.lower()!r}' for s in cls.__members__)
+            raise excs.Error(f'{error_prefix} must be one of: [{val_strs}]') from exc
 
 
 @dataclasses.dataclass
 class JoinClause:
     """Corresponds to a single 'JOIN ... ON (...)' clause in a SELECT statement; excludes the joined table."""
+
     join_type: JoinType
-    join_predicate: Optional[exprs.Expr]  # None for join_type == CROSS
+    join_predicate: exprs.Expr | None  # None for join_type == CROSS
 
 
 @dataclasses.dataclass
 class FromClause:
-    """Corresponds to the From-clause ('FROM <tbl> JOIN ... ON (...) JOIN ...') of a SELECT statement """
+    """Corresponds to the From-clause ('FROM <tbl> JOIN ... ON (...) JOIN ...') of a SELECT statement"""
+
     tbls: list[catalog.TableVersionPath]
     join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)
 
+    @property
+    def _first_tbl(self) -> catalog.TableVersionPath:
+        assert len(self.tbls) == 1
+        return self.tbls[0]
+
+
+@dataclasses.dataclass
+class SampleClause:
+    """Defines a sampling clause for a table."""
+
+    version: int | None
+    n: int | None
+    n_per_stratum: int | None
+    fraction: float | None
+    seed: int | None
+    stratify_exprs: list[exprs.Expr] | None
+
+    # The version of the hashing algorithm used for ordering and fractional sampling.
+    CURRENT_VERSION = 1
+
+    def __post_init__(self) -> None:
+        # If no version was provided, provide the default version
+        if self.version is None:
+            self.version = self.CURRENT_VERSION
+
+    @property
+    def is_stratified(self) -> bool:
+        """Check if the sampling is stratified"""
+        return self.stratify_exprs is not None and len(self.stratify_exprs) > 0
+
+    @property
+    def is_repeatable(self) -> bool:
+        """Return true if the same rows will continue to be sampled if source rows are added or deleted."""
+        return not self.is_stratified and self.fraction is not None
+
+    def display_str(self, inline: bool = False) -> str:
+        return str(self)
+
+    def as_dict(self) -> dict:
+        """Return a dictionary representation of the object"""
+        d = dataclasses.asdict(self)
+        d['_classname'] = self.__class__.__name__
+        if self.is_stratified:
+            d['stratify_exprs'] = [e.as_dict() for e in self.stratify_exprs]
+        return d
+
+    @classmethod
+    def from_dict(cls, d: dict) -> SampleClause:
+        """Create a SampleClause from a dictionary representation"""
+        d_cleaned = {key: value for key, value in d.items() if key != '_classname'}
+        s = cls(**d_cleaned)
+        if s.is_stratified:
+            s.stratify_exprs = [exprs.Expr.from_dict(e) for e in d_cleaned.get('stratify_exprs', [])]
+        return s
+
+    def __repr__(self) -> str:
+        s = ','.join(e.display_str(inline=True) for e in self.stratify_exprs)
+        return (
+            f'sample_{self.version}(n={self.n}, n_per_stratum={self.n_per_stratum}, '
+            f'fraction={self.fraction}, seed={self.seed}, [{s}])'
+        )
+
+    @classmethod
+    def fraction_to_md5_hex(cls, fraction: float) -> str:
+        """Return the string representation of an approximation (to ~1e-9) of a fraction of the total space
+        of md5 hash values.
+        This is used for fractional sampling.
+        """
+        # Maximum count for the upper 32 bits of MD5: 2^32
+        max_md5_value = (2**32) - 1
+
+        # Calculate the fraction of this value
+        threshold_int = max_md5_value * int(1_000_000_000 * fraction) // 1_000_000_000
+
+        # Convert to hexadecimal string with padding
+        return format(threshold_int, '08x') + 'ffffffffffffffffffffffff'
+
 
 class Analyzer:
     """
@@ -83,26 +162,33 @@ class Analyzer:
     from_clause: FromClause
     all_exprs: list[exprs.Expr]  # union of all exprs, aside from sql_where_clause
     select_list: list[exprs.Expr]
-    group_by_clause: Optional[list[exprs.Expr]]  # None for non-aggregate queries; [] for agg query w/o grouping
+    group_by_clause: list[exprs.Expr] | None  # None for non-aggregate queries; [] for agg query w/o grouping
     grouping_exprs: list[exprs.Expr]  # [] for non-aggregate queries or agg query w/o grouping
     order_by_clause: OrderByClause
+    stratify_exprs: list[exprs.Expr]  # [] if no stratiifcation is required
+    sample_clause: SampleClause | None  # None if no sampling clause is present
 
     sql_elements: exprs.SqlElementCache
 
     # Where clause of the Select stmt of the SQL scan
-    sql_where_clause: Optional[exprs.Expr]
+    sql_where_clause: exprs.Expr | None
 
     # filter predicate applied to output rows of the SQL scan
-    filter: Optional[exprs.Expr]
+    filter: exprs.Expr | None
 
     agg_fn_calls: list[exprs.FunctionCall]  # grouping aggregation (ie, not window functions)
     window_fn_calls: list[exprs.FunctionCall]
     agg_order_by: list[exprs.Expr]
 
     def __init__(
-            self, from_clause: FromClause, select_list: Sequence[exprs.Expr],
-            where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[list[exprs.Expr]] = None,
-            order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None):
+        self,
+        from_clause: FromClause,
+        select_list: Sequence[exprs.Expr],
+        where_clause: exprs.Expr | None = None,
+        group_by_clause: list[exprs.Expr] | None = None,
+        order_by_clause: list[tuple[exprs.Expr, bool]] | None = None,
+        sample_clause: SampleClause | None = None,
+    ):
         if order_by_clause is None:
             order_by_clause = []
         self.from_clause = from_clause
@@ -115,6 +201,11 @@ class Analyzer:
         self.group_by_clause = (
             [e.resolve_computed_cols() for e in group_by_clause] if group_by_clause is not None else None
         )
+        self.sample_clause = sample_clause
+        if self.sample_clause is not None and self.sample_clause.is_stratified:
+            self.stratify_exprs = [e.resolve_computed_cols() for e in sample_clause.stratify_exprs]
+        else:
+            self.stratify_exprs = []
         self.order_by_clause = [OrderByItem(e.resolve_computed_cols(), asc) for e, asc in order_by_clause]
 
         self.sql_where_clause = None
@@ -130,8 +221,11 @@ class Analyzer:
                 self.all_exprs.append(join_clause.join_predicate)
         if self.group_by_clause is not None:
             self.all_exprs.extend(self.group_by_clause)
+        self.all_exprs.extend(self.stratify_exprs)
         self.all_exprs.extend(e for e, _ in self.order_by_clause)
         if self.filter is not None:
+            if sample_clause is not None:
+                raise excs.Error(f'Filter {self.filter} not expressible in SQL')
             self.all_exprs.append(self.filter)
 
         self.agg_order_by = []
@@ -145,12 +239,17 @@ class Analyzer:
         candidates = self.select_list
         agg_fn_calls = exprs.ExprSet(
             exprs.Expr.list_subexprs(
-                candidates, expr_class=exprs.FunctionCall,
-                filter=lambda e: bool(e.is_agg_fn_call and not e.is_window_fn_call)))
+                candidates,
+                expr_class=exprs.FunctionCall,
+                filter=lambda e: bool(e.is_agg_fn_call and not e.is_window_fn_call),
+            )
+        )
         self.agg_fn_calls = list(agg_fn_calls)
         window_fn_calls = exprs.ExprSet(
             exprs.Expr.list_subexprs(
-                candidates, expr_class=exprs.FunctionCall, filter=lambda e: bool(e.is_window_fn_call)))
+                candidates, expr_class=exprs.FunctionCall, filter=lambda e: bool(e.is_window_fn_call)
+            )
+        )
         self.window_fn_calls = list(window_fn_calls)
         if len(self.agg_fn_calls) == 0:
             # nothing to do
@@ -164,19 +263,25 @@ class Analyzer:
         is_agg_output = [self._determine_agg_status(e, grouping_expr_ids)[0] for e in self.select_list]
         if is_agg_output.count(False) > 0:
             raise excs.Error(
-                f'Invalid non-aggregate expression in aggregate query: {self.select_list[is_agg_output.index(False)]}')
-
-        # check that filter doesn't contain aggregates
-        if self.filter is not None:
-            if any(_is_agg_fn_call(e) for e in self.filter.subexprs(expr_class=exprs.FunctionCall)):
-                raise excs.Error(f'Filter cannot contain aggregate functions: {self.filter}')
+                f'Invalid non-aggregate expression in aggregate query: {self.select_list[is_agg_output.index(False)]}'
+            )
+
+        # check that Where clause and filter doesn't contain aggregates
+        if self.sql_where_clause is not None and any(
+            _is_agg_fn_call(e) for e in self.sql_where_clause.subexprs(expr_class=exprs.FunctionCall)
+        ):
+            raise excs.Error(f'where() cannot contain aggregate functions: {self.sql_where_clause}')
+        if self.filter is not None and any(
+            _is_agg_fn_call(e) for e in self.filter.subexprs(expr_class=exprs.FunctionCall)
+        ):
+            raise excs.Error(f'where() cannot contain aggregate functions: {self.filter}')
 
         # check that grouping exprs don't contain aggregates and can be expressed as SQL (we perform sort-based
         # aggregation and rely on the SqlScanNode returning data in the correct order)
         for e in self.group_by_clause:
             if not self.sql_elements.contains(e):
                 raise excs.Error(f'Invalid grouping expression, needs to be expressible in SQL: {e}')
-            if e._contains(filter=lambda e: _is_agg_fn_call(e)):
+            if e._contains(filter=_is_agg_fn_call):
                 raise excs.Error(f'Grouping expression contains aggregate function: {e}')
 
     def _determine_agg_status(self, e: exprs.Expr, grouping_expr_ids: set[int]) -> tuple[bool, bool]:
@@ -194,14 +299,15 @@ class Analyzer:
             return True, False
         elif isinstance(e, exprs.Literal):
             return True, True
-        elif isinstance(e, exprs.ColumnRef) or isinstance(e, exprs.RowidRef):
+        elif isinstance(e, (exprs.ColumnRef, exprs.RowidRef)):
            # we already know that this isn't a grouping expr
            return False, True
         else:
             # an expression such as <grouping expr 1> + <grouping expr 2> can both be the output and input of agg
             assert len(e.components) > 0
             component_is_output, component_is_input = zip(
-                *[self._determine_agg_status(c, grouping_expr_ids) for c in e.components])
+                *[self._determine_agg_status(c, grouping_expr_ids) for c in e.components]
+            )
             is_output = component_is_output.count(True) == len(e.components)
             is_input = component_is_input.count(True) == len(e.components)
             if not is_output and not is_input:
@@ -224,13 +330,14 @@ class Analyzer:
         row_builder.set_slot_idxs(self.agg_fn_calls)
         row_builder.set_slot_idxs(self.agg_order_by)
 
-    def get_window_fn_ob_clause(self) -> Optional[OrderByClause]:
+    def get_window_fn_ob_clause(self) -> OrderByClause | None:
         clause: list[OrderByClause] = []
         for fn_call in self.window_fn_calls:
             # window functions require ordering by the group_by/order_by clauses
             group_by_exprs, order_by_exprs = fn_call.get_window_sort_exprs()
             clause.append(
-                [OrderByItem(e, None) for e in group_by_exprs] + [OrderByItem(e, True) for e in order_by_exprs])
+                [OrderByItem(e, None) for e in group_by_exprs] + [OrderByItem(e, True) for e in order_by_exprs]
+            )
         return combine_order_by_clauses(clause)
 
     def has_agg(self) -> bool:
@@ -239,103 +346,113 @@ class Analyzer:
 
 
 class Planner:
-    # TODO: create an exec.CountNode and change this to create_count_plan()
     @classmethod
-    def create_count_stmt(
-            cls, tbl: catalog.TableVersionPath, where_clause: Optional[exprs.Expr] = None
-    ) -> sql.Select:
-        stmt = sql.select(sql.func.count())
-        refd_tbl_ids: set[UUID] = set()
-        if where_clause is not None:
-            analyzer = cls.analyze(tbl, where_clause)
-            if analyzer.filter is not None:
-                raise excs.Error(f'Filter {analyzer.filter} not expressible in SQL')
-            clause_element = analyzer.sql_where_clause.sql_expr(analyzer.sql_elements)
-            assert clause_element is not None
-            stmt = stmt.where(clause_element)
-            refd_tbl_ids = where_clause.tbl_ids()
-        stmt = exec.SqlScanNode.create_from_clause(tbl, stmt, refd_tbl_ids)
-        return stmt
+    def create_count_stmt(cls, query: 'pxt.Query') -> sql.Select:
+        """Creates a SQL SELECT COUNT(*) statement for counting rows in a Query."""
+        # Create the query plan
+        plan = query._create_query_plan()
+        sql_node = plan.get_node(exec.SqlNode)
+        assert sql_node is not None
+        if sql_node.py_filter is not None:
+            raise excs.Error('count() cannot be used with Python-only filters. Use collect() instead.')
+        # Get the SQL statement from the SqlNode as a CTE
+        cte, _ = sql_node.to_cte(keep_pk=True)
+        count_stmt = sql.select(sql.func.count().label('all_count')).select_from(cte)
+        return count_stmt
 
     @classmethod
     def create_insert_plan(
         cls, tbl: catalog.TableVersion, rows: list[dict[str, Any]], ignore_errors: bool
     ) -> exec.ExecNode:
         """Creates a plan for TableVersion.insert()"""
-        assert not tbl.is_view()
+        assert not tbl.is_view
         # stored_cols: all cols we need to store, incl computed cols (and indices)
         stored_cols = [c for c in tbl.cols_by_id.values() if c.is_stored]
         assert len(stored_cols) > 0  # there needs to be something to store
-        row_builder = exprs.RowBuilder([], stored_cols, [])
+
+        cls.__check_valid_columns(tbl, stored_cols, 'inserted into')
+
+        row_builder = exprs.RowBuilder([], stored_cols, [], tbl)
 
         # create InMemoryDataNode for 'rows'
-        plan: exec.ExecNode = exec.InMemoryDataNode(tbl, rows, row_builder, tbl.next_rowid)
+        plan: exec.ExecNode = exec.InMemoryDataNode(tbl.handle, rows, row_builder, tbl.next_row_id)
 
-        media_input_col_info = [
-            exprs.ColumnSlotIdx(col_ref.col, col_ref.slot_idx)
-            for col_ref in row_builder.input_exprs
-            if isinstance(col_ref, exprs.ColumnRef) and col_ref.col_type.is_media_type()
-        ]
-        if len(media_input_col_info) > 0:
-            # prefetch external files for all input column refs
-            plan = exec.CachePrefetchNode(tbl.id, media_input_col_info, input=plan)
+        plan = cls._add_prefetch_node(tbl.id, row_builder.input_exprs, input_node=plan)
 
         computed_exprs = row_builder.output_exprs - row_builder.input_exprs
         if len(computed_exprs) > 0:
             # add an ExprEvalNode when there are exprs to compute
-            plan = exec.ExprEvalNode(row_builder, computed_exprs, plan.output_exprs, input=plan)
+            plan = exec.ExprEvalNode(
+                row_builder, computed_exprs, plan.output_exprs, input=plan, maintain_input_order=False
+            )
+        if any(c.col_type.supports_file_offloading() for c in stored_cols):
+            plan = exec.CellMaterializationNode(plan)
 
-        stored_col_info = row_builder.output_slot_idxs()
-        stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
         plan.set_ctx(
             exec.ExecContext(
-                row_builder, batch_size=0, show_pbar=True, num_computed_exprs=len(computed_exprs),
-                ignore_errors=ignore_errors))
+                row_builder,
+                batch_size=0,
+                show_pbar=True,
+                num_computed_exprs=len(computed_exprs),
+                ignore_errors=ignore_errors,
+            )
+        )
+        plan = cls._add_save_node(plan)
+
 
         return plan
 
     @classmethod
-    def create_df_insert_plan(
-        cls,
-        tbl: catalog.TableVersion,
-        df: 'pxt.DataFrame',
-        ignore_errors: bool
+    def rowid_columns(cls, target: TableVersionHandle, num_rowid_cols: int | None = None) -> list[exprs.Expr]:
+        """Return list of RowidRef for the given number of associated rowids"""
+        if num_rowid_cols is None:
+            num_rowid_cols = target.get().num_rowid_columns()
+        return [exprs.RowidRef(target, i) for i in range(num_rowid_cols)]
+
+    @classmethod
+    def create_query_insert_plan(
+        cls, tbl: catalog.TableVersion, query: 'pxt.Query', ignore_errors: bool
     ) -> exec.ExecNode:
-        assert not tbl.is_view()
-        plan = df._create_query_plan()  # ExecNode constructed by the DataFrame
+        assert not tbl.is_view
+        plan = query._create_query_plan()  # ExecNode constructed by the Query
 
         # Modify the plan RowBuilder to register the output columns
-        for col_name, expr in zip(df.schema.keys(), df._select_list_exprs):
+        needs_cell_materialization = False
+        for col_name, expr in zip(query.schema.keys(), query._select_list_exprs):
             assert col_name in tbl.cols_by_name
             col = tbl.cols_by_name[col_name]
             plan.row_builder.add_table_column(col, expr.slot_idx)
+            needs_cell_materialization = needs_cell_materialization or col.col_type.supports_file_offloading()
 
-        stored_col_info = plan.row_builder.output_slot_idxs()
-        stored_img_col_info = [info for info in stored_col_info if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
+        if needs_cell_materialization:
+            plan = exec.CellMaterializationNode(plan)
 
         plan.set_ctx(
             exec.ExecContext(
-                plan.row_builder, batch_size=0, show_pbar=True, num_computed_exprs=0,
-                ignore_errors=ignore_errors))
+                plan.row_builder, batch_size=0, show_pbar=True, num_computed_exprs=0, ignore_errors=ignore_errors
+            )
+        )
         plan.ctx.num_rows = 0  # Unknown
 
         return plan
 
     @classmethod
     def create_update_plan(
-        cls, tbl: catalog.TableVersionPath,
-        update_targets: dict[catalog.Column, exprs.Expr],
-        recompute_targets: list[catalog.Column],
-        where_clause: Optional[exprs.Expr], cascade: bool
+        cls,
+        tbl: catalog.TableVersionPath,
+        update_targets: dict[catalog.Column, exprs.Expr],
+        recompute_targets: list[catalog.Column],
+        where_clause: exprs.Expr | None,
+        cascade: bool,
     ) -> tuple[exec.ExecNode, list[str], list[catalog.Column]]:
         """Creates a plan to materialize updated rows.
+
         The plan:
         - retrieves rows that are visible at the current version of the table
         - materializes all stored columns and the update targets
         - if cascade is True, recomputes all computed columns that transitively depend on the updated columns
           and copies the values of all other stored columns
         - if cascade is False, copies all columns that aren't update targets from the original rows
+
         Returns:
         - root node of the plan
         - list of qualified column names that are getting updated
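An aside on the rewritten `create_count_stmt()` in the hunk above: instead of hand-building a filtered `SELECT COUNT(*)`, it now wraps the query's own SQL as a common table expression and counts over it, which works for any SQL-expressible plan and fails fast when a Python-only filter is present. The following minimal SQLAlchemy sketch (not from the package; the table and column names are made up) shows the count-over-CTE pattern:

```python
import sqlalchemy as sql

md = sql.MetaData()
tbl = sql.Table('tbl', md, sql.Column('id', sql.Integer), sql.Column('val', sql.Integer))

# arbitrary inner query, turned into a CTE
cte = sql.select(tbl.c.id, tbl.c.val).where(tbl.c.val > 0).cte('inner_query')
# COUNT(*) over the CTE, mirroring the count_stmt built above
count_stmt = sql.select(sql.func.count().label('all_count')).select_from(cte)
print(count_stmt)
# WITH inner_query AS (SELECT tbl.id, tbl.val FROM tbl WHERE tbl.val > :val_1)
# SELECT count(*) AS all_count FROM inner_query
```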
@@ -343,46 +460,178 @@ class Planner:
         """
         # retrieve all stored cols and all target exprs
         assert isinstance(tbl, catalog.TableVersionPath)
-        target = tbl.tbl_version  # the one we need to update
+        target = tbl.tbl_version.get()  # the one we need to update
         updated_cols = list(update_targets.keys())
+        recomputed_cols: set[Column]
         if len(recompute_targets) > 0:
-            recomputed_cols = set(recompute_targets)
+            assert len(update_targets) == 0
+            recomputed_cols = {*recompute_targets}
+            if cascade:
+                recomputed_cols |= target.get_dependent_columns(recomputed_cols)
         else:
             recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
-        # regardless of cascade, we need to update all indices on any updated column
-        idx_val_cols = target.get_idx_val_columns(updated_cols)
-        recomputed_cols.update(idx_val_cols)
-        # we only need to recompute stored columns (unstored ones are substituted away)
-        recomputed_cols = {c for c in recomputed_cols if c.is_stored}
-        recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
+        # regardless of cascade, we need to update all indices on any updated/recomputed column
+        modified_base_cols = [c for c in set(updated_cols) | recomputed_cols if c.get_tbl().id == target.id]
+        idx_val_cols = target.get_idx_val_columns(modified_base_cols)
+        recomputed_cols.update(idx_val_cols)
+        # we only need to recompute stored columns (unstored ones are substituted away)
+        recomputed_cols = {c for c in recomputed_cols if c.is_stored}
+
+        cls.__check_valid_columns(tbl.tbl_version.get(), recomputed_cols, 'updated in')
+
+        # our query plan
+        # - evaluates the update targets and recomputed columns
+        # - copies all other stored columns
+        recomputed_base_cols = {col for col in recomputed_cols if col.get_tbl().id == tbl.tbl_version.id}
         copied_cols = [
-            col for col in target.cols_by_id.values()
-            if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
+            col
+            for col in target.cols_by_id.values()
+            if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
         ]
-        select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
-        select_list.extend(update_targets.values())
+        select_list: list[exprs.Expr] = list(update_targets.values())
 
-        recomputed_exprs = \
-            [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols]
+        recomputed_exprs = [
+            c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols
+        ]
         # recomputed cols reference the new values of the updated cols
         spec: dict[exprs.Expr, exprs.Expr] = {exprs.ColumnRef(col): e for col, e in update_targets.items()}
         exprs.Expr.list_substitute(recomputed_exprs, spec)
         select_list.extend(recomputed_exprs)
 
         # we need to retrieve the PK columns of the existing rows
-        plan = cls.create_query_plan(FromClause(tbls=[tbl]), select_list, where_clause=where_clause, ignore_errors=True)
-        all_base_cols = copied_cols + updated_cols + list(recomputed_base_cols)  # same order as select_list
+        plan = cls.create_query_plan(
+            FromClause(tbls=[tbl]),
+            select_list=select_list,
+            columns=copied_cols,
+            where_clause=where_clause,
+            ignore_errors=True,
+        )
+        evaluated_cols = updated_cols + list(recomputed_base_cols)  # same order as select_list
         # update row builder with column information
-        for i, col in enumerate(all_base_cols):
+        plan.row_builder.add_table_columns(copied_cols)
+        for i, col in enumerate(evaluated_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
+        plan.ctx.num_computed_exprs = len(recomputed_exprs)
+
+        plan = cls._add_cell_materialization_node(plan)
+        plan = cls._add_save_node(plan)
+
         recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
-        return plan, [f'{c.tbl.name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
+        return plan, [f'{c.get_tbl().name}.{c.name}' for c in updated_cols + recomputed_user_cols], recomputed_user_cols
+
+    @classmethod
+    def __check_valid_columns(
+        cls, tbl: catalog.TableVersion, cols: Iterable[Column], op_name: Literal['inserted into', 'updated in']
+    ) -> None:
+        for col in cols:
+            if col.value_expr is not None and not col.value_expr.is_valid:
+                raise excs.Error(
+                    dedent(
+                        f"""
+                        Data cannot be {op_name} the table {tbl.name!r},
+                        because the column {col.name!r} is currently invalid:
+                        {{validation_error}}
+                        """
+                    )
+                    .strip()
+                    .format(validation_error=col.value_expr.validation_error)
+                )
+
+    @classmethod
+    def _cell_md_col_refs(cls, expr_list: Iterable[exprs.Expr]) -> list[exprs.ColumnRef]:
+        """Return list of ColumnRefs that need their cellmd values for reconstruction"""
+        json_col_refs = list(
+            exprs.Expr.list_subexprs(
+                expr_list,
+                expr_class=exprs.ColumnRef,
+                filter=lambda e: cast(exprs.ColumnRef, e).col.col_type.is_json_type(),
+                traverse_matches=False,
+            )
+        )
+
+        def needs_reconstruction(e: exprs.Expr) -> bool:
+            assert isinstance(e, exprs.ColumnRef)
+            # Vector-typed array columns are used for vector indexes, and are stored in the db
+            return e.col.col_type.is_array_type() and not isinstance(e.col.sa_col_type, pgvector.sqlalchemy.Vector)
+
+        array_col_refs = list(
+            exprs.Expr.list_subexprs(
+                expr_list, expr_class=exprs.ColumnRef, filter=needs_reconstruction, traverse_matches=False
+            )
+        )
+
+        binary_col_refs = list(
+            exprs.Expr.list_subexprs(
+                expr_list,
+                expr_class=exprs.ColumnRef,
+                filter=lambda e: cast(exprs.ColumnRef, e).col.col_type.is_binary_type(),
+                traverse_matches=False,
+            )
+        )
+
+        return json_col_refs + array_col_refs + binary_col_refs
+
+    @classmethod
+    def _add_cell_materialization_node(cls, input: exec.ExecNode) -> exec.ExecNode:
+        # we need a CellMaterializationNode if any of the evaluated output columns are json or array-typed
+        has_target_cols = any(
+            col.col_type.supports_file_offloading()
+            for col, slot_idx in input.row_builder.table_columns.items()
+            if slot_idx is not None
+        )
+        if has_target_cols:
+            return exec.CellMaterializationNode(input)
+        else:
+            return input
+
+    @classmethod
+    def _add_cell_reconstruction_node(cls, expr_list: list[exprs.Expr], input: exec.ExecNode) -> exec.ExecNode:
+        """
+        Add a CellReconstructionNode, if required by any of the exprs in expr_list.
+
+        Cell reconstruction is required for
+        1) all json-typed ColumnRefs that are not used as part of a JsonPath (the latter does its own reconstruction)
+           or as part of a ColumnPropertyRef
+        2) all array-typed ColumnRefs that are not used as part of a ColumnPropertyRef
+        """
+
+        def json_filter(e: exprs.Expr) -> bool:
+            if isinstance(e, exprs.JsonPath):
+                return not e.is_relative_path() and isinstance(e.anchor, exprs.ColumnRef)
+            if isinstance(e, exprs.ColumnPropertyRef):
+                return e.col_ref.col.col_type.is_json_type()
+            return isinstance(e, exprs.ColumnRef) and e.col.col_type.is_json_type()
+
+        def array_filter(e: exprs.Expr) -> bool:
+            if isinstance(e, exprs.ColumnPropertyRef):
+                return e.col_ref.col.col_type.is_array_type()
+            if not isinstance(e, exprs.ColumnRef):
+                return False
+            # Vector-typed array columns are used for vector indexes, and are stored in the db
+            return e.col.col_type.is_array_type() and not isinstance(e.col.sa_col_type, pgvector.sqlalchemy.Vector)
+
+        def binary_filter(e: exprs.Expr) -> bool:
+            return isinstance(e, exprs.ColumnRef) and e.col.col_type.is_binary_type()
+
+        json_candidates = list(exprs.Expr.list_subexprs(expr_list, filter=json_filter, traverse_matches=False))
+        json_refs = [e for e in json_candidates if isinstance(e, exprs.ColumnRef)]
+        array_candidates = list(exprs.Expr.list_subexprs(expr_list, filter=array_filter, traverse_matches=False))
+        array_refs = [e for e in array_candidates if isinstance(e, exprs.ColumnRef)]
+        binary_refs = list(
+            exprs.Expr.list_subexprs(expr_list, exprs.ColumnRef, filter=binary_filter, traverse_matches=False)
+        )
+        if len(json_refs) > 0 or len(array_refs) > 0 or len(binary_refs) > 0:
+            return exec.CellReconstructionNode(json_refs, array_refs, binary_refs, input.row_builder, input=input)
+        else:
+            return input
 
     @classmethod
     def create_batch_update_plan(
-        cls, tbl: catalog.TableVersionPath,
-        batch: list[dict[catalog.Column, exprs.Expr]], rowids: list[tuple[int, ...]],
-        cascade: bool
+        cls,
+        tbl: catalog.TableVersionPath,
+        batch: list[dict[catalog.Column, exprs.Expr]],
+        rowids: list[tuple[int, ...]],
+        cascade: bool,
     ) -> tuple[exec.ExecNode, exec.RowUpdateNode, sql.ColumnElement[bool], list[catalog.Column], list[catalog.Column]]:
         """
         Returns:
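A small idiom worth flagging in the `__check_valid_columns()` helper added above: the error message is built in two stages. The f-string fills `{op_name}`, `{tbl.name!r}` and `{col.name!r}` immediately, while the doubled braces in `{{validation_error}}` survive as a literal `{validation_error}` placeholder that the trailing `.format()` call fills after `dedent()` and `.strip()` normalize the block. A self-contained sketch of the same pattern (stand-in values, not the package's API):

```python
from textwrap import dedent

op_name, tbl_name, col_name = 'inserted into', 'films', 'summary'

msg = (
    dedent(
        f"""
        Data cannot be {op_name} the table {tbl_name!r},
        because the column {col_name!r} is currently invalid:
        {{validation_error}}
        """
    )
    .strip()  # drop the blank lines surrounding the dedented block
    .format(validation_error='a column used by this computed column was dropped')
)
print(msg)
```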
@@ -393,9 +642,9 @@ class Planner:
         - list of user-visible columns that are being recomputed
         """
         assert isinstance(tbl, catalog.TableVersionPath)
-        target = tbl.tbl_version  # the one we need to update
-        sa_key_cols: list[sql.Column] = []
-        key_vals: list[tuple] = []
+        target = tbl.tbl_version.get()  # the one we need to update
+        sa_key_cols: list[sql.Column]
+        key_vals: list[tuple]
         if len(rowids) > 0:
             sa_key_cols = target.store_tbl.rowid_columns()
             key_vals = rowids
@@ -408,21 +657,23 @@ class Planner:
         updated_cols = batch[0].keys() - target.primary_key_columns()
         recomputed_cols = target.get_dependent_columns(updated_cols) if cascade else set()
         # regardless of cascade, we need to update all indices on any updated column
-        idx_val_cols = target.get_idx_val_columns(updated_cols)
+        modified_base_cols = [c for c in set(updated_cols) | recomputed_cols if c.get_tbl().id == target.id]
+        idx_val_cols = target.get_idx_val_columns(modified_base_cols)
         recomputed_cols.update(idx_val_cols)
         # we only need to recompute stored columns (unstored ones are substituted away)
         recomputed_cols = {c for c in recomputed_cols if c.is_stored}
-        recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
+        recomputed_base_cols = {col for col in recomputed_cols if col.get_tbl().id == target.id}
         copied_cols = [
-            col for col in target.cols_by_id.values()
-            if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
+            col
+            for col in target.cols_by_id.values()
+            if col.is_stored and col not in updated_cols and col not in recomputed_base_cols
         ]
-        select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
-        select_list.extend(exprs.ColumnRef(col) for col in updated_cols)
+        select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in updated_cols]
 
-        recomputed_exprs = \
-            [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols]
-        # the RowUpdateNode updates columns in-place, ie, in the original ColumnRef; no further sustitution is needed
+        recomputed_exprs = [
+            c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols
+        ]
+        # the RowUpdateNode updates columns in-place, ie, in the original ColumnRef; no further substitution is needed
         select_list.extend(recomputed_exprs)
 
         # ExecNode tree (from bottom to top):
@@ -430,36 +681,54 @@ class Planner:
         # - RowUpdateNode to update the retrieved rows
         # - ExprEvalNode to evaluate the remaining output exprs
         analyzer = Analyzer(FromClause(tbls=[tbl]), select_list)
-        sql_exprs = list(exprs.Expr.list_subexprs(
-            analyzer.all_exprs, filter=analyzer.sql_elements.contains, traverse_matches=False))
-        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs)
+        sql_exprs = list(
+            exprs.Expr.list_subexprs(analyzer.all_exprs, filter=analyzer.sql_elements.contains, traverse_matches=False)
+        )
+        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs, target)
         analyzer.finalize(row_builder)
-        sql_lookup_node = exec.SqlLookupNode(tbl, row_builder, sql_exprs, sa_key_cols, key_vals)
+
+        cell_md_col_refs = cls._cell_md_col_refs(sql_exprs)
+        sql_lookup_node = exec.SqlLookupNode(
+            tbl,
+            row_builder,
+            sql_exprs,
+            columns=copied_cols,
+            sa_key_cols=sa_key_cols,
+            key_vals=key_vals,
+            cell_md_col_refs=cell_md_col_refs,
+        )
         col_vals = [{col: row[col].val for col in updated_cols} for row in batch]
         row_update_node = exec.RowUpdateNode(tbl, key_vals, len(rowids) > 0, col_vals, row_builder, sql_lookup_node)
         plan: exec.ExecNode = row_update_node
         if not cls._is_contained_in(analyzer.select_list, sql_exprs):
             # we need an ExprEvalNode to evaluate the remaining output exprs
             plan = exec.ExprEvalNode(row_builder, analyzer.select_list, sql_exprs, input=plan)
+
         # update row builder with column information
-        all_base_cols = copied_cols + list(updated_cols) + list(recomputed_base_cols)  # same order as select_list
+        evaluated_cols = list(updated_cols) + list(recomputed_base_cols)  # same order as select_list
         row_builder.set_slot_idxs(select_list, remove_duplicates=False)
-        for i, col in enumerate(all_base_cols):
+        plan.row_builder.add_table_columns(copied_cols)
+        for i, col in enumerate(evaluated_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
-
-        ctx = exec.ExecContext(row_builder)
-        # we're returning everything to the user, so we might as well do it in a single batch
+        ctx = exec.ExecContext(row_builder, num_computed_exprs=len(recomputed_exprs))
+        # TODO: correct batch size?
         ctx.batch_size = 0
         plan.set_ctx(ctx)
+
+        plan = cls._add_cell_materialization_node(plan)
+        plan = cls._add_save_node(plan)
         recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
         return (
-            plan, row_update_node, sql_lookup_node.where_clause_element, list(updated_cols) + recomputed_user_cols,
-            recomputed_user_cols
+            plan,
+            row_update_node,
+            sql_lookup_node.where_clause_element,
+            list(updated_cols) + recomputed_user_cols,
+            recomputed_user_cols,
         )
 
     @classmethod
     def create_view_update_plan(
-            cls, view: catalog.TableVersionPath, recompute_targets: list[catalog.Column]
+        cls, view: catalog.TableVersionPath, recompute_targets: list[catalog.Column]
     ) -> exec.ExecNode:
         """Creates a plan to materialize updated rows for a view, given that the base table has been updated.
         The plan:
@@ -477,27 +746,33 @@ class Planner:
         - list of columns that are being recomputed
         """
         assert isinstance(view, catalog.TableVersionPath)
-        assert view.is_view()
-        target = view.tbl_version  # the one we need to update
+        assert view.is_view
+        target = view.tbl_version.get()  # the one we need to update
         # retrieve all stored cols and all target exprs
         recomputed_cols = set(recompute_targets.copy())
-        copied_cols = [col for col in target.cols_by_id.values() if col.is_stored and not col in recomputed_cols]
+        copied_cols = [col for col in target.cols_by_id.values() if col.is_stored and col not in recomputed_cols]
         select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
         # resolve recomputed exprs to stored columns in the base
-        recomputed_exprs = \
-            [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_cols) for c in recomputed_cols]
+        recomputed_exprs = [
+            c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_cols) for c in recomputed_cols
+        ]
         select_list.extend(recomputed_exprs)
 
         # we need to retrieve the PK columns of the existing rows
         plan = cls.create_query_plan(
-            FromClause(tbls=[view]), select_list, where_clause=target.predicate, ignore_errors=True,
-            exact_version_only=view.get_bases())
-        for i, col in enumerate(copied_cols + list(recomputed_cols)):  # same order as select_list
+            FromClause(tbls=[view]),
+            select_list,
+            where_clause=target.predicate,
+            ignore_errors=True,
+            exact_version_only=view.get_bases(),
+        )
+        plan.ctx.num_computed_exprs = len(recomputed_exprs)
+        materialized_cols = copied_cols + list(recomputed_cols)  # same order as select_list
+        for i, col in enumerate(materialized_cols):
             plan.row_builder.add_table_column(col, select_list[i].slot_idx)
-        # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
-        stored_img_col_info = \
-            [info for info in plan.row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
+        plan = cls._add_cell_materialization_node(plan)
+        plan = cls._add_save_node(plan)
+
         return plan
 
     @classmethod
@@ -515,45 +790,61 @@ class Planner:
         - number of materialized values per row
         """
         assert isinstance(view, catalog.TableVersionPath)
-        assert view.is_view()
+        assert view.is_view
         # things we need to materialize as DataRows:
         # 1. stored computed cols
         #    - iterator columns are effectively computed, just not with a value_expr
         #    - we can ignore stored non-computed columns because they have a default value that is supplied directly by
         #      the store
-        target = view.tbl_version  # the one we need to populate
+        target = view.tbl_version.get()  # the one we need to populate
         stored_cols = [c for c in target.cols_by_id.values() if c.is_stored]
         # 2. for component views: iterator args
         iterator_args = [target.iterator_args] if target.iterator_args is not None else []
 
-        row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
+        from_clause = FromClause(tbls=[view.base])
+        base_analyzer = Analyzer(
+            from_clause, iterator_args, where_clause=target.predicate, sample_clause=target.sample_clause
+        )
+        row_builder = exprs.RowBuilder(base_analyzer.all_exprs, stored_cols, [], target, for_view_load=True)
 
+        # if we're propagating an insert, we only want to see those base rows that were created for the current version
         # execution plan:
         # 1. materialize exprs computed from the base that are needed for stored view columns
         # 2. if it's an iterator view, expand the base rows into component rows
         # 3. materialize stored view columns that haven't been produced by step 1
         base_output_exprs = [e for e in row_builder.default_eval_ctx.exprs if e.is_bound_by([view.base])]
         view_output_exprs = [
-            e for e in row_builder.default_eval_ctx.target_exprs
+            e
+            for e in row_builder.default_eval_ctx.target_exprs
             if e.is_bound_by([view]) and not e.is_bound_by([view.base])
         ]
-        # if we're propagating an insert, we only want to see those base rows that were created for the current version
-        base_analyzer = Analyzer(FromClause(tbls=[view.base]), base_output_exprs, where_clause=target.predicate)
+
+        # Create a new analyzer reflecting exactly what is required from the base table
+        base_analyzer = Analyzer(
+            from_clause, base_output_exprs, where_clause=target.predicate, sample_clause=target.sample_clause
+        )
         base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
         plan = cls._create_query_plan(
-            row_builder=row_builder, analyzer=base_analyzer, eval_ctx=base_eval_ctx, with_pk=True,
-            exact_version_only=view.get_bases() if propagates_insert else [])
+            row_builder=row_builder,
+            analyzer=base_analyzer,
+            eval_ctx=base_eval_ctx,
+            with_pk=True,
+            exact_version_only=view.get_bases() if propagates_insert else [],
+        )
         exec_ctx = plan.ctx
-        if target.is_component_view():
-            plan = exec.ComponentIterationNode(target, plan)
+        if target.is_component_view:
+            plan = exec.ComponentIterationNode(view.tbl_version, plan)
         if len(view_output_exprs) > 0:
             plan = exec.ExprEvalNode(
-                row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs,input=plan)
+                row_builder, output_exprs=view_output_exprs, input_exprs=base_output_exprs, input=plan
+            )
 
-        stored_img_col_info = [info for info in row_builder.output_slot_idxs() if info.col.col_type.is_image_type()]
-        plan.set_stored_img_cols(stored_img_col_info)
         exec_ctx.ignore_errors = True
         plan.set_ctx(exec_ctx)
+        if any(c.col_type.supports_file_offloading() for c in stored_cols):
+            plan = exec.CellMaterializationNode(plan)
+        plan = cls._add_save_node(plan)
+
         return plan, len(row_builder.default_eval_ctx.target_exprs)
 
     @classmethod
@@ -564,8 +855,8 @@ class Planner:
             raise excs.Error(f'Join predicate {join_clause.join_predicate} not expressible in SQL')
 
     @classmethod
-    def _verify_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> None:
-        """Verify that the various ordering requirements don't conflict"""
+    def _create_combined_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> OrderByClause | None:
+        """Verify that the various ordering requirements don't conflict and return a combined ordering"""
         ob_clauses: list[OrderByClause] = [analyzer.order_by_clause.copy()]
 
         if verify_agg:
@@ -577,13 +868,15 @@ class Planner:
                 ob_clauses.append(ordering)
             for fn_call in analyzer.agg_fn_calls:
                 # agg functions with an ordering requirement are implicitly ascending
-                ordering = (
-                    [OrderByItem(e, None) for e in analyzer.group_by_clause]
-                    + [OrderByItem(e, True) for e in fn_call.get_agg_order_by()]
-                )
+                ordering = [OrderByItem(e, None) for e in analyzer.group_by_clause] + [
+                    OrderByItem(e, True) for e in fn_call.get_agg_order_by()
+                ]
                 ob_clauses.append(ordering)
-        if len(ob_clauses) <= 1:
-            return
+
+        if len(ob_clauses) == 0:
+            return None
+        elif len(ob_clauses) == 1:
+            return ob_clauses[0]
 
         combined_ordering = ob_clauses[0]
         for ordering in ob_clauses[1:]:
@@ -591,60 +884,101 @@ class Planner:
591
884
  if combined is None:
592
885
  raise excs.Error(
593
886
  f'Incompatible ordering requirements: '
594
- f'{print_order_by_clause(combined_ordering)} vs {print_order_by_clause(ordering)}')
887
+ f'{print_order_by_clause(combined_ordering)} vs {print_order_by_clause(ordering)}'
888
+ )
595
889
  combined_ordering = combined
890
+ return combined_ordering
891
+
+    @classmethod
+    def _add_save_node(cls, input_node: exec.ExecNode) -> exec.ExecNode:
+        """Add an ObjectStoreSaveNode, if needed."""
+        media_col_info = input_node.row_builder.media_output_col_info
+        if len(media_col_info) == 0:
+            return input_node
+        else:
+            return exec.ObjectStoreSaveNode(media_col_info, input_node)
 
     @classmethod
     def _is_contained_in(cls, l1: Iterable[exprs.Expr], l2: Iterable[exprs.Expr]) -> bool:
         """Returns True if l1 is contained in l2"""
-        s1, s2 = set(e.id for e in l1), set(e.id for e in l2)
-        return s1 <= s2
+        return {e.id for e in l1} <= {e.id for e in l2}
 
     @classmethod
-    def _insert_prefetch_node(cls, tbl_id: UUID, row_builder: exprs.RowBuilder, input: exec.ExecNode) -> exec.ExecNode:
-        """Returns a CachePrefetchNode into the plan if needed, otherwise returns input"""
+    def _add_prefetch_node(
+        cls, tbl_id: UUID, expressions: Iterable[exprs.Expr], input_node: exec.ExecNode
+    ) -> exec.ExecNode:
+        """Add a CachePrefetch node, if needed."""
         # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
         # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
         # aren't explicitly captured as dependencies
-        media_col_refs = [
-            e for e in list(row_builder.unique_exprs) if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()
-        ]
+        media_col_refs = [e for e in expressions if isinstance(e, exprs.ColumnRef) and e.col_type.is_media_type()]
         if len(media_col_refs) == 0:
-            return input
+            return input_node
         # we need to prefetch external files for media column types
         file_col_info = [exprs.ColumnSlotIdx(e.col, e.slot_idx) for e in media_col_refs]
-        prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input)
+        prefetch_node = exec.CachePrefetchNode(tbl_id, file_col_info, input_node)
         return prefetch_node
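
CachePrefetchNode exists because media cells often hold remote URLs, and opening them lazily, one file at a time, would stall every downstream node. A minimal sketch of the batching idea, assuming rows carry URLs under known media keys (hypothetical helper; the real node resolves files through pixeltable's file cache and records per-row errors):

    from concurrent.futures import ThreadPoolExecutor
    from typing import Iterator

    def prefetched(batches: Iterator[list[dict]], media_keys: list[str]) -> Iterator[list[dict]]:
        def localize(url: str) -> str:
            # stand-in for "download into the file cache, return the local path"
            return '/tmp/filecache/' + url.rsplit('/', 1)[-1]

        with ThreadPoolExecutor(max_workers=8) as pool:
            for batch in batches:
                url_list = list({row[k] for row in batch for k in media_keys})
                local = dict(zip(url_list, pool.map(localize, url_list)))
                # rewrite media cells to local paths before downstream nodes open them
                yield [{**row, **{k: local[row[k]] for k in media_keys}} for row in batch]

    batches = iter([[{'img': 'https://example.com/a.jpg'}, {'img': 'https://example.com/b.jpg'}]])
    print(next(prefetched(batches, ['img'])))  # [{'img': '/tmp/filecache/a.jpg'}, ...]
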
 
     @classmethod
     def create_query_plan(
-        cls, from_clause: FromClause, select_list: Optional[list[exprs.Expr]] = None,
-        where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[list[exprs.Expr]] = None,
-        order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None, limit: Optional[int] = None,
-        ignore_errors: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
+        cls,
+        from_clause: FromClause,
+        select_list: list[exprs.Expr] | None = None,
+        columns: list[catalog.Column] | None = None,
+        where_clause: exprs.Expr | None = None,
+        group_by_clause: list[exprs.Expr] | None = None,
+        order_by_clause: list[tuple[exprs.Expr, bool]] | None = None,
+        limit: exprs.Expr | None = None,
+        sample_clause: SampleClause | None = None,
+        ignore_errors: bool = False,
+        exact_version_only: list[catalog.TableVersionHandle] | None = None,
     ) -> exec.ExecNode:
-        """Return plan for executing a query.
+        """
+        Return plan for executing a query.
+
+        The plan:
+        - materializes the values of select_list exprs into their respective slots
+        - materializes cell values of 'columns' (and their cellmd, if applicable) into DataRow.cell_vals/cell_md
+
         Updates 'select_list' in place to make it executable.
         TODO: make exact_version_only a flag and use the versions from tbl
         """
         if select_list is None:
             select_list = []
+        if columns is None:
+            columns = []
         if order_by_clause is None:
             order_by_clause = []
         if exact_version_only is None:
             exact_version_only = []
+
         analyzer = Analyzer(
-            from_clause, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
-            order_by_clause=order_by_clause)
-        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
+            from_clause,
+            select_list,
+            where_clause=where_clause,
+            group_by_clause=group_by_clause,
+            order_by_clause=order_by_clause,
+            sample_clause=sample_clause,
+        )
+        # If the from_clause has a single table, we can use it as the context table for the RowBuilder.
+        # Otherwise there is no context table, but that's ok, because the context table is only needed for
+        # table mutations, which can't happen during a join.
+        context_tbl = from_clause.tbls[0].tbl_version.get() if len(from_clause.tbls) == 1 else None
+        row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [], context_tbl)
 
         analyzer.finalize(row_builder)
         # select_list: we need to materialize everything that's been collected
         # with_pk: for now, we always retrieve the PK, because we need it for the file cache
         eval_ctx = row_builder.create_eval_ctx(analyzer.select_list)
         plan = cls._create_query_plan(
-            row_builder=row_builder, analyzer=analyzer, eval_ctx=eval_ctx, limit=limit, with_pk=True,
-            exact_version_only=exact_version_only)
+            row_builder=row_builder,
+            analyzer=analyzer,
+            eval_ctx=eval_ctx,
+            columns=columns,
+            limit=limit,
+            with_pk=True,
+            exact_version_only=exact_version_only,
+        )
         plan.ctx.ignore_errors = ignore_errors
         select_list.clear()
         select_list.extend(analyzer.select_list)
@@ -652,9 +986,14 @@ class Planner:
 
     @classmethod
     def _create_query_plan(
-        cls, row_builder: exprs.RowBuilder, analyzer: Analyzer, eval_ctx: exprs.RowBuilder.EvalCtx,
-        limit: Optional[int] = None, with_pk: bool = False,
-        exact_version_only: Optional[list[catalog.TableVersion]] = None
+        cls,
+        row_builder: exprs.RowBuilder,
+        analyzer: Analyzer,
+        eval_ctx: exprs.RowBuilder.EvalCtx,
+        columns: list[catalog.Column] | None = None,
+        limit: exprs.Expr | None = None,
+        with_pk: bool = False,
+        exact_version_only: list[catalog.TableVersionHandle] | None = None,
     ) -> exec.ExecNode:
         """
         Create plan to materialize eval_ctx.
@@ -664,36 +1003,45 @@ class Planner:
         in the context of that table version (eg, if 'tbl' is a view, 'plan_target' might be the base)
         TODO: make exact_version_only a flag and use the versions from tbl
         """
+        if columns is None:
+            columns = []
         if exact_version_only is None:
             exact_version_only = []
         sql_elements = analyzer.sql_elements
-        is_python_agg = (
-            not sql_elements.contains_all(analyzer.agg_fn_calls)
-            or not sql_elements.contains_all(analyzer.window_fn_calls)
+        is_python_agg = not sql_elements.contains_all(analyzer.agg_fn_calls) or not sql_elements.contains_all(
+            analyzer.window_fn_calls
         )
         ctx = exec.ExecContext(row_builder)
-        cls._verify_ordering(analyzer, verify_agg=is_python_agg)
+
+        combined_ordering = cls._create_combined_ordering(analyzer, verify_agg=is_python_agg)
         cls._verify_join_clauses(analyzer)
 
         # materialized with SQL table scans (ie, single-table SELECT statements):
         # - select list subexprs that aren't aggregates
         # - join clause subexprs
         # - subexprs of Where clause conjuncts that can't be run in SQL
-        # - all grouping exprs, if any aggregate function call can't be run in SQL (in that case, they all have to be
-        #   run in Python)
-        candidates = list(exprs.Expr.list_subexprs(
-            analyzer.select_list,
-            filter=lambda e: (
+        # - all grouping exprs
+        # - all stratify exprs
+        candidates = list(
+            exprs.Expr.list_subexprs(
+                analyzer.select_list,
+                filter=lambda e: (
                     sql_elements.contains(e)
                     and not e._contains(cls=exprs.FunctionCall, filter=lambda e: bool(e.is_agg_fn_call))
-            ),
-            traverse_matches=False))
+                ),
+                traverse_matches=False,
+            )
+        )
         if analyzer.filter is not None:
-            candidates.extend(exprs.Expr.subexprs(
-                analyzer.filter, filter=lambda e: sql_elements.contains(e), traverse_matches=False))
-        if is_python_agg and analyzer.group_by_clause is not None:
-            candidates.extend(exprs.Expr.list_subexprs(
-                analyzer.group_by_clause, filter=lambda e: sql_elements.contains(e), traverse_matches=False))
+            candidates.extend(
+                exprs.Expr.subexprs(analyzer.filter, filter=sql_elements.contains, traverse_matches=False)
+            )
+        candidates.extend(
+            exprs.Expr.list_subexprs(analyzer.grouping_exprs, filter=sql_elements.contains, traverse_matches=False)
+        )
+        candidates.extend(
+            exprs.Expr.list_subexprs(analyzer.stratify_exprs, filter=sql_elements.contains, traverse_matches=False)
+        )
         # not isinstance(...): we don't want to materialize Literals via a Select
         sql_exprs = exprs.ExprSet(e for e in candidates if not isinstance(e, exprs.Literal))
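
The candidate collection above hinges on traverse_matches=False: once a subtree passes the filter (SQL-translatable and aggregate-free), its interior is not visited again, so the scan node is handed the largest pushable subtrees rather than their leaves. A toy model of that traversal rule (stand-in Expr class; the real filter consults sql_elements):

    from dataclasses import dataclass, field
    from typing import Callable, Iterator

    @dataclass
    class Expr:
        name: str
        sql_ok: bool  # stand-in for sql_elements.contains(e)
        children: list['Expr'] = field(default_factory=list)

    def subexprs(e: Expr, filter: Callable[[Expr], bool], traverse_matches: bool) -> Iterator[Expr]:
        if filter(e):
            yield e
            if not traverse_matches:
                return  # keep the largest matching subtree; don't also yield its parts
        for c in e.children:
            yield from subexprs(c, filter, traverse_matches)

    # udf(a + b): the Python UDF call can't run in SQL, but its argument subtree can
    tree = Expr('udf', False, [Expr('+', True, [Expr('a', True), Expr('b', True)])])
    print([e.name for e in subexprs(tree, lambda e: e.sql_ok, traverse_matches=False)])  # ['+']
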
 
@@ -701,7 +1049,8 @@ class Planner:
         join_exprs = exprs.ExprSet(
             join_clause.join_predicate
             for join_clause in analyzer.from_clause.join_clauses
-            if join_clause.join_predicate is not None)
+            if join_clause.join_predicate is not None
+        )
         scan_target_exprs = sql_exprs | join_exprs
         tbl_scan_plans: list[exec.SqlScanNode] = []
         plan: exec.ExecNode
@@ -711,16 +1060,28 @@ class Planner:
                 exprs.Expr.list_subexprs(
                     scan_target_exprs,
                     filter=lambda e: e.is_bound_by([tbl]) and not isinstance(e, exprs.Literal),
-                    traverse_matches=False))
+                    traverse_matches=False,
+                )
+            )
+
             plan = exec.SqlScanNode(
-                tbl, row_builder, select_list=tbl_scan_exprs,
-                set_pk=with_pk, exact_version_only=exact_version_only)
+                tbl,
+                row_builder,
+                select_list=tbl_scan_exprs,
+                columns=[c for c in columns if c.get_tbl().id == tbl.tbl_id],
+                set_pk=with_pk,
+                cell_md_col_refs=cls._cell_md_col_refs(tbl_scan_exprs),
+                exact_version_only=exact_version_only,
+            )
             tbl_scan_plans.append(plan)
 
         if len(analyzer.from_clause.join_clauses) > 0:
             plan = exec.SqlJoinNode(
-                row_builder, inputs=tbl_scan_plans, join_clauses=analyzer.from_clause.join_clauses,
-                select_list=sql_exprs)
+                row_builder,
+                inputs=tbl_scan_plans,
+                join_clauses=analyzer.from_clause.join_clauses,
+                select_list=sql_exprs,
+            )
         else:
             plan = tbl_scan_plans[0]
@@ -732,7 +1093,17 @@ class Planner:
             # we need to order the input for window functions
             plan.set_order_by(analyzer.get_window_fn_ob_clause())
 
-        plan = cls._insert_prefetch_node(tbl.tbl_version.id, row_builder, plan)
+        if analyzer.sample_clause is not None:
+            plan = exec.SqlSampleNode(
+                row_builder,
+                input=plan,
+                select_list=tbl_scan_exprs,
+                sample_clause=analyzer.sample_clause,
+                stratify_exprs=analyzer.stratify_exprs,
+            )
+
+        plan = cls._add_prefetch_node(tbl.tbl_version.id, row_builder.unique_exprs, plan)
+        plan = cls._add_cell_reconstruction_node(analyzer.all_exprs, plan)
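
SqlSampleNode implements the new sample_clause, including stratified sampling over the stratify exprs. Conceptually, stratified sampling partitions rows by the strata key and samples within each partition, so every stratum is represented; a pure-Python sketch of those semantics (the actual node pushes equivalent logic into SQL):

    import random
    from collections import defaultdict

    def stratified_sample(rows: list[dict], strata_key: str, n_per_stratum: int) -> list[dict]:
        groups: dict[object, list[dict]] = defaultdict(list)
        for r in rows:
            groups[r[strata_key]].append(r)
        rng = random.Random(0)  # fixed seed so the sample is reproducible
        return [r for g in groups.values() for r in rng.sample(g, min(n_per_stratum, len(g)))]

    rows = [{'label': l, 'i': i} for i, l in enumerate('aabbbccccc')]
    print(len(stratified_sample(rows, 'label', 2)))  # 6: two rows from each of the three strata
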
 
         if analyzer.group_by_clause is not None:
             # we're doing grouping aggregation; the input of the AggregateNode are the grouping exprs plus the
@@ -750,36 +1121,57 @@ class Planner:
             ctx.batch_size = 16
 
             # do aggregation in SQL if all agg exprs can be translated
-            if (sql_elements.contains_all(analyzer.select_list)
-                    and sql_elements.contains_all(analyzer.grouping_exprs)
-                    and isinstance(plan, exec.SqlNode)
-                    and plan.to_cte() is not None):
+            if (
+                sql_elements.contains_all(analyzer.select_list)
+                and sql_elements.contains_all(analyzer.grouping_exprs)
+                and isinstance(plan, exec.SqlNode)
+                and plan.to_cte() is not None
+            ):
                 plan = exec.SqlAggregationNode(
-                    row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause)
+                    row_builder, input=plan, select_list=analyzer.select_list, group_by_items=analyzer.group_by_clause
+                )
             else:
+                input_sql_node = plan.get_node(exec.SqlNode)
+                assert combined_ordering is not None
+                input_sql_node.set_order_by(combined_ordering)
                 plan = exec.AggregationNode(
-                    tbl.tbl_version, row_builder, analyzer.group_by_clause,
-                    analyzer.agg_fn_calls + analyzer.window_fn_calls, agg_input, input=plan)
+                    tbl.tbl_version,
+                    row_builder,
+                    analyzer.group_by_clause,
+                    analyzer.agg_fn_calls + analyzer.window_fn_calls,
+                    agg_input,
+                    input=plan,
+                )
                 typecheck_dummy = analyzer.grouping_exprs + analyzer.agg_fn_calls + analyzer.window_fn_calls
                 agg_output = exprs.ExprSet(typecheck_dummy)
                 if not agg_output.issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
                     # we need an ExprEvalNode to evaluate the remaining output exprs
                     plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
+            plan = cls._add_save_node(plan)
         else:
             if not exprs.ExprSet(sql_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
                 # we need an ExprEvalNode to evaluate the remaining output exprs
                 plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_exprs, input=plan)
             # we're returning everything to the user, so we might as well do it in a single batch
+            # TODO: return smaller batches in order to increase inter-ExecNode parallelism
             ctx.batch_size = 0
 
+        sql_node = plan.get_node(exec.SqlNode)
         if len(analyzer.order_by_clause) > 0:
             # we have the last SqlNode we created produce the ordering
-            sql_node = plan.get_node(exec.SqlNode)
             assert sql_node is not None
             sql_node.set_order_by(analyzer.order_by_clause)
 
+        # if we don't need an ordered result, tell the ExprEvalNode not to maintain input order (which allows us to
+        # return batches earlier)
+        if sql_node is not None and len(sql_node.order_by_clause) == 0:
+            expr_eval_node = plan.get_node(exec.ExprEvalNode)
+            if expr_eval_node is not None:
+                expr_eval_node.set_input_order(False)
+
         if limit is not None:
-            plan.set_limit(limit)
+            assert isinstance(limit, exprs.Literal)
+            plan.set_limit(limit.val)
 
         plan.set_ctx(ctx)
         return plan
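
The else branch now pins the input SqlNode to the combined ordering before streaming rows into AggregationNode. Sort-based aggregation assumes its input arrives grouped; itertools.groupby has the same contract and shows what goes wrong without the sort:

    from itertools import groupby

    rows = [('a', 1), ('b', 2), ('a', 3)]
    # unsorted input: groupby re-opens the 'a' group, producing wrong per-group sums
    print([(k, sum(v for _, v in g)) for k, g in groupby(rows, key=lambda r: r[0])])
    # [('a', 1), ('b', 2), ('a', 3)]
    print([(k, sum(v for _, v in g)) for k, g in groupby(sorted(rows), key=lambda r: r[0])])
    # [('a', 4), ('b', 2)]
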
@@ -789,25 +1181,24 @@ class Planner:
         return Analyzer(FromClause(tbls=[tbl]), [], where_clause=where_clause)
 
     @classmethod
-    def create_add_column_plan(
-        cls, tbl: catalog.TableVersionPath, col: catalog.Column
-    ) -> tuple[exec.ExecNode, Optional[int]]:
+    def create_add_column_plan(cls, tbl: catalog.TableVersionPath, col: catalog.Column) -> exec.ExecNode:
         """Creates a plan for InsertableTable.add_column()
         Returns:
             plan: the plan to execute
             value_expr slot idx for the plan output (for computed cols)
         """
         assert isinstance(tbl, catalog.TableVersionPath)
-        row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
+        row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[], tbl=tbl.tbl_version.get())
         analyzer = Analyzer(FromClause(tbls=[tbl]), row_builder.default_eval_ctx.target_exprs)
         plan = cls._create_query_plan(
-            row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True)
+            row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True
+        )
+
         plan.ctx.batch_size = 16
         plan.ctx.show_pbar = True
         plan.ctx.ignore_errors = True
+        computed_exprs = row_builder.output_exprs - row_builder.input_exprs
+        plan.ctx.num_computed_exprs = len(computed_exprs)  # we are adding a computed column, so we need to evaluate it
+        plan = cls._add_save_node(plan)
 
-        # we want to flush images
-        if col.is_computed and col.is_stored and col.col_type.is_image_type():
-            plan.set_stored_img_cols(row_builder.output_slot_idxs())
-        value_expr_slot_idx = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
-        return plan, value_expr_slot_idx
+        return plan
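
create_add_column_plan now returns only the ExecNode; the value_expr slot index is gone from the signature. For orientation, a hedged sketch of the user-facing calls that end up on this code path ('films' and the column names are illustrative; API names as in current pixeltable releases):

    import pixeltable as pxt

    t = pxt.get_table('films')
    t.add_column(rating=pxt.Float)                      # plain column
    t.add_computed_column(title_upper=t.title.upper())  # computed column: runs a plan like the one above
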