pixeltable 0.2.26__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (245)
  1. pixeltable/__init__.py +83 -19
  2. pixeltable/_query.py +1444 -0
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +7 -4
  5. pixeltable/catalog/catalog.py +2394 -119
  6. pixeltable/catalog/column.py +225 -104
  7. pixeltable/catalog/dir.py +38 -9
  8. pixeltable/catalog/globals.py +53 -34
  9. pixeltable/catalog/insertable_table.py +265 -115
  10. pixeltable/catalog/path.py +80 -17
  11. pixeltable/catalog/schema_object.py +28 -43
  12. pixeltable/catalog/table.py +1270 -677
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +1270 -751
  15. pixeltable/catalog/table_version_handle.py +109 -0
  16. pixeltable/catalog/table_version_path.py +137 -42
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +251 -134
  20. pixeltable/config.py +215 -0
  21. pixeltable/env.py +736 -285
  22. pixeltable/exceptions.py +26 -2
  23. pixeltable/exec/__init__.py +7 -2
  24. pixeltable/exec/aggregation_node.py +39 -21
  25. pixeltable/exec/cache_prefetch_node.py +87 -109
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +25 -28
  29. pixeltable/exec/data_row_batch.py +11 -46
  30. pixeltable/exec/exec_context.py +26 -11
  31. pixeltable/exec/exec_node.py +35 -27
  32. pixeltable/exec/expr_eval/__init__.py +3 -0
  33. pixeltable/exec/expr_eval/evaluators.py +365 -0
  34. pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
  35. pixeltable/exec/expr_eval/globals.py +200 -0
  36. pixeltable/exec/expr_eval/row_buffer.py +74 -0
  37. pixeltable/exec/expr_eval/schedulers.py +413 -0
  38. pixeltable/exec/globals.py +35 -0
  39. pixeltable/exec/in_memory_data_node.py +35 -27
  40. pixeltable/exec/object_store_save_node.py +293 -0
  41. pixeltable/exec/row_update_node.py +44 -29
  42. pixeltable/exec/sql_node.py +414 -115
  43. pixeltable/exprs/__init__.py +8 -5
  44. pixeltable/exprs/arithmetic_expr.py +79 -45
  45. pixeltable/exprs/array_slice.py +5 -5
  46. pixeltable/exprs/column_property_ref.py +40 -26
  47. pixeltable/exprs/column_ref.py +254 -61
  48. pixeltable/exprs/comparison.py +14 -9
  49. pixeltable/exprs/compound_predicate.py +9 -10
  50. pixeltable/exprs/data_row.py +213 -72
  51. pixeltable/exprs/expr.py +270 -104
  52. pixeltable/exprs/expr_dict.py +6 -5
  53. pixeltable/exprs/expr_set.py +20 -11
  54. pixeltable/exprs/function_call.py +383 -284
  55. pixeltable/exprs/globals.py +18 -5
  56. pixeltable/exprs/in_predicate.py +7 -7
  57. pixeltable/exprs/inline_expr.py +37 -37
  58. pixeltable/exprs/is_null.py +8 -4
  59. pixeltable/exprs/json_mapper.py +120 -54
  60. pixeltable/exprs/json_path.py +90 -60
  61. pixeltable/exprs/literal.py +61 -16
  62. pixeltable/exprs/method_ref.py +7 -6
  63. pixeltable/exprs/object_ref.py +19 -8
  64. pixeltable/exprs/row_builder.py +238 -75
  65. pixeltable/exprs/rowid_ref.py +53 -15
  66. pixeltable/exprs/similarity_expr.py +65 -50
  67. pixeltable/exprs/sql_element_cache.py +5 -5
  68. pixeltable/exprs/string_op.py +107 -0
  69. pixeltable/exprs/type_cast.py +25 -13
  70. pixeltable/exprs/variable.py +2 -2
  71. pixeltable/func/__init__.py +9 -5
  72. pixeltable/func/aggregate_function.py +197 -92
  73. pixeltable/func/callable_function.py +119 -35
  74. pixeltable/func/expr_template_function.py +101 -48
  75. pixeltable/func/function.py +375 -62
  76. pixeltable/func/function_registry.py +20 -19
  77. pixeltable/func/globals.py +6 -5
  78. pixeltable/func/mcp.py +74 -0
  79. pixeltable/func/query_template_function.py +151 -35
  80. pixeltable/func/signature.py +178 -49
  81. pixeltable/func/tools.py +164 -0
  82. pixeltable/func/udf.py +176 -53
  83. pixeltable/functions/__init__.py +44 -4
  84. pixeltable/functions/anthropic.py +226 -47
  85. pixeltable/functions/audio.py +148 -11
  86. pixeltable/functions/bedrock.py +137 -0
  87. pixeltable/functions/date.py +188 -0
  88. pixeltable/functions/deepseek.py +113 -0
  89. pixeltable/functions/document.py +81 -0
  90. pixeltable/functions/fal.py +76 -0
  91. pixeltable/functions/fireworks.py +72 -20
  92. pixeltable/functions/gemini.py +249 -0
  93. pixeltable/functions/globals.py +208 -53
  94. pixeltable/functions/groq.py +108 -0
  95. pixeltable/functions/huggingface.py +1088 -95
  96. pixeltable/functions/image.py +155 -84
  97. pixeltable/functions/json.py +8 -11
  98. pixeltable/functions/llama_cpp.py +31 -19
  99. pixeltable/functions/math.py +169 -0
  100. pixeltable/functions/mistralai.py +50 -75
  101. pixeltable/functions/net.py +70 -0
  102. pixeltable/functions/ollama.py +29 -36
  103. pixeltable/functions/openai.py +548 -160
  104. pixeltable/functions/openrouter.py +143 -0
  105. pixeltable/functions/replicate.py +15 -14
  106. pixeltable/functions/reve.py +250 -0
  107. pixeltable/functions/string.py +310 -85
  108. pixeltable/functions/timestamp.py +37 -19
  109. pixeltable/functions/together.py +77 -120
  110. pixeltable/functions/twelvelabs.py +188 -0
  111. pixeltable/functions/util.py +7 -2
  112. pixeltable/functions/uuid.py +30 -0
  113. pixeltable/functions/video.py +1528 -117
  114. pixeltable/functions/vision.py +26 -26
  115. pixeltable/functions/voyageai.py +289 -0
  116. pixeltable/functions/whisper.py +19 -10
  117. pixeltable/functions/whisperx.py +179 -0
  118. pixeltable/functions/yolox.py +112 -0
  119. pixeltable/globals.py +716 -236
  120. pixeltable/index/__init__.py +3 -1
  121. pixeltable/index/base.py +17 -21
  122. pixeltable/index/btree.py +32 -22
  123. pixeltable/index/embedding_index.py +155 -92
  124. pixeltable/io/__init__.py +12 -7
  125. pixeltable/io/datarows.py +140 -0
  126. pixeltable/io/external_store.py +83 -125
  127. pixeltable/io/fiftyone.py +24 -33
  128. pixeltable/io/globals.py +47 -182
  129. pixeltable/io/hf_datasets.py +96 -127
  130. pixeltable/io/label_studio.py +171 -156
  131. pixeltable/io/lancedb.py +3 -0
  132. pixeltable/io/pandas.py +136 -115
  133. pixeltable/io/parquet.py +40 -153
  134. pixeltable/io/table_data_conduit.py +702 -0
  135. pixeltable/io/utils.py +100 -0
  136. pixeltable/iterators/__init__.py +8 -4
  137. pixeltable/iterators/audio.py +207 -0
  138. pixeltable/iterators/base.py +9 -3
  139. pixeltable/iterators/document.py +144 -87
  140. pixeltable/iterators/image.py +17 -38
  141. pixeltable/iterators/string.py +15 -12
  142. pixeltable/iterators/video.py +523 -127
  143. pixeltable/metadata/__init__.py +33 -8
  144. pixeltable/metadata/converters/convert_10.py +2 -3
  145. pixeltable/metadata/converters/convert_13.py +2 -2
  146. pixeltable/metadata/converters/convert_15.py +15 -11
  147. pixeltable/metadata/converters/convert_16.py +4 -5
  148. pixeltable/metadata/converters/convert_17.py +4 -5
  149. pixeltable/metadata/converters/convert_18.py +4 -6
  150. pixeltable/metadata/converters/convert_19.py +6 -9
  151. pixeltable/metadata/converters/convert_20.py +3 -6
  152. pixeltable/metadata/converters/convert_21.py +6 -8
  153. pixeltable/metadata/converters/convert_22.py +3 -2
  154. pixeltable/metadata/converters/convert_23.py +33 -0
  155. pixeltable/metadata/converters/convert_24.py +55 -0
  156. pixeltable/metadata/converters/convert_25.py +19 -0
  157. pixeltable/metadata/converters/convert_26.py +23 -0
  158. pixeltable/metadata/converters/convert_27.py +29 -0
  159. pixeltable/metadata/converters/convert_28.py +13 -0
  160. pixeltable/metadata/converters/convert_29.py +110 -0
  161. pixeltable/metadata/converters/convert_30.py +63 -0
  162. pixeltable/metadata/converters/convert_31.py +11 -0
  163. pixeltable/metadata/converters/convert_32.py +15 -0
  164. pixeltable/metadata/converters/convert_33.py +17 -0
  165. pixeltable/metadata/converters/convert_34.py +21 -0
  166. pixeltable/metadata/converters/convert_35.py +9 -0
  167. pixeltable/metadata/converters/convert_36.py +38 -0
  168. pixeltable/metadata/converters/convert_37.py +15 -0
  169. pixeltable/metadata/converters/convert_38.py +39 -0
  170. pixeltable/metadata/converters/convert_39.py +124 -0
  171. pixeltable/metadata/converters/convert_40.py +73 -0
  172. pixeltable/metadata/converters/convert_41.py +12 -0
  173. pixeltable/metadata/converters/convert_42.py +9 -0
  174. pixeltable/metadata/converters/convert_43.py +44 -0
  175. pixeltable/metadata/converters/util.py +44 -18
  176. pixeltable/metadata/notes.py +21 -0
  177. pixeltable/metadata/schema.py +185 -42
  178. pixeltable/metadata/utils.py +74 -0
  179. pixeltable/mypy/__init__.py +3 -0
  180. pixeltable/mypy/mypy_plugin.py +123 -0
  181. pixeltable/plan.py +616 -225
  182. pixeltable/share/__init__.py +3 -0
  183. pixeltable/share/packager.py +797 -0
  184. pixeltable/share/protocol/__init__.py +33 -0
  185. pixeltable/share/protocol/common.py +165 -0
  186. pixeltable/share/protocol/operation_types.py +33 -0
  187. pixeltable/share/protocol/replica.py +119 -0
  188. pixeltable/share/publish.py +349 -0
  189. pixeltable/store.py +398 -232
  190. pixeltable/type_system.py +730 -267
  191. pixeltable/utils/__init__.py +40 -0
  192. pixeltable/utils/arrow.py +201 -29
  193. pixeltable/utils/av.py +298 -0
  194. pixeltable/utils/azure_store.py +346 -0
  195. pixeltable/utils/coco.py +26 -27
  196. pixeltable/utils/code.py +4 -4
  197. pixeltable/utils/console_output.py +46 -0
  198. pixeltable/utils/coroutine.py +24 -0
  199. pixeltable/utils/dbms.py +92 -0
  200. pixeltable/utils/description_helper.py +11 -12
  201. pixeltable/utils/documents.py +60 -61
  202. pixeltable/utils/exception_handler.py +36 -0
  203. pixeltable/utils/filecache.py +38 -22
  204. pixeltable/utils/formatter.py +88 -51
  205. pixeltable/utils/gcs_store.py +295 -0
  206. pixeltable/utils/http.py +133 -0
  207. pixeltable/utils/http_server.py +14 -13
  208. pixeltable/utils/iceberg.py +13 -0
  209. pixeltable/utils/image.py +17 -0
  210. pixeltable/utils/lancedb.py +90 -0
  211. pixeltable/utils/local_store.py +322 -0
  212. pixeltable/utils/misc.py +5 -0
  213. pixeltable/utils/object_stores.py +573 -0
  214. pixeltable/utils/pydantic.py +60 -0
  215. pixeltable/utils/pytorch.py +20 -20
  216. pixeltable/utils/s3_store.py +527 -0
  217. pixeltable/utils/sql.py +32 -5
  218. pixeltable/utils/system.py +30 -0
  219. pixeltable/utils/transactional_directory.py +4 -3
  220. pixeltable-0.5.7.dist-info/METADATA +579 -0
  221. pixeltable-0.5.7.dist-info/RECORD +227 -0
  222. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  223. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  224. pixeltable/__version__.py +0 -3
  225. pixeltable/catalog/named_function.py +0 -36
  226. pixeltable/catalog/path_dict.py +0 -141
  227. pixeltable/dataframe.py +0 -894
  228. pixeltable/exec/expr_eval_node.py +0 -232
  229. pixeltable/ext/__init__.py +0 -14
  230. pixeltable/ext/functions/__init__.py +0 -8
  231. pixeltable/ext/functions/whisperx.py +0 -77
  232. pixeltable/ext/functions/yolox.py +0 -157
  233. pixeltable/tool/create_test_db_dump.py +0 -311
  234. pixeltable/tool/create_test_video.py +0 -81
  235. pixeltable/tool/doc_plugins/griffe.py +0 -50
  236. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  237. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  238. pixeltable/tool/embed_udf.py +0 -9
  239. pixeltable/tool/mypy_plugin.py +0 -55
  240. pixeltable/utils/media_store.py +0 -76
  241. pixeltable/utils/s3.py +0 -16
  242. pixeltable-0.2.26.dist-info/METADATA +0 -400
  243. pixeltable-0.2.26.dist-info/RECORD +0 -156
  244. pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
  245. {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
@@ -1,31 +1,34 @@
+ import datetime
  import logging
  import warnings
  from decimal import Decimal
- from typing import Iterable, Iterator, NamedTuple, Optional, TYPE_CHECKING, Sequence
+ from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Sequence
  from uuid import UUID

  import sqlalchemy as sql

- import pixeltable.catalog as catalog
- import pixeltable.exprs as exprs
+ from pixeltable import catalog, exprs
+ from pixeltable.env import Env
+
  from .data_row_batch import DataRowBatch
  from .exec_node import ExecNode

  if TYPE_CHECKING:
      import pixeltable.plan
+     from pixeltable.plan import SampleClause

  _logger = logging.getLogger('pixeltable')


  class OrderByItem(NamedTuple):
      expr: exprs.Expr
-     asc: Optional[bool]
+     asc: bool | None


  OrderByClause = list[OrderByItem]


- def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
+ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> OrderByClause | None:
      """Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
      Two clauses are compatible if for each of their respective items c1[i] and c2[i]
      a) the exprs are identical and
@@ -53,56 +56,91 @@ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[Order


  def print_order_by_clause(clause: OrderByClause) -> str:
-     return ', '.join([
+     return ', '.join(
          f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
          for item in clause
-     ])
+     )


  class SqlNode(ExecNode):
      """
-     Materializes data from the store via a Select stmt.
+     Materializes data from the store via a SQL statement.
      This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
+     The pk columns are not included in the select list.
+     If set_pk is True, they are added to the end of the result set when creating the SQL statement
+     so they can always be referenced as cols[-num_pk_cols:] in the result set.
+     The pk_columns consist of the rowid columns of the target table followed by the version number.
+
+     If row_builder contains references to unstored iter columns, expands the select list to include their
+     SQL-materializable subexpressions.
+
+     Args:
+         select_list: output of the query
+         set_pk: if True, sets the primary for each DataRow
      """

-     tbl: Optional[catalog.TableVersionPath]
+     tbl: catalog.TableVersionPath | None
      select_list: exprs.ExprSet
+     columns: list[catalog.Column]  # for which columns to populate DataRow.cell_vals/cell_md
+     cell_md_refs: list[exprs.ColumnPropertyRef]  # of ColumnRefs which also need DataRow.slot_cellmd for evaluation
      set_pk: bool
      num_pk_cols: int
-     py_filter: Optional[exprs.Expr]  # a predicate that can only be run in Python
-     py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
-     cte: Optional[sql.CTE]
+     py_filter: exprs.Expr | None  # a predicate that can only be run in Python
+     py_filter_eval_ctx: exprs.RowBuilder.EvalCtx | None
+     cte: sql.CTE | None
      sql_elements: exprs.SqlElementCache

+     # execution state
+     sql_select_list_exprs: exprs.ExprSet
+     cellmd_item_idxs: exprs.ExprDict[int]  # cellmd expr -> idx in sql select list
+     column_item_idxs: dict[catalog.Column, int]  # column -> idx in sql select list
+     column_cellmd_item_idxs: dict[catalog.Column, int]  # column -> idx in sql select list
+     result_cursor: sql.engine.CursorResult | None
+
      # where_clause/-_element: allow subclass to set one or the other (but not both)
-     where_clause: Optional[exprs.Expr]
-     where_clause_element: Optional[sql.ColumnElement]
+     where_clause: exprs.Expr | None
+     where_clause_element: sql.ColumnElement | None

      order_by_clause: OrderByClause
-     limit: Optional[int]
+     limit: int | None

      def __init__(
-         self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
-         select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
+         self,
+         tbl: catalog.TableVersionPath | None,
+         row_builder: exprs.RowBuilder,
+         select_list: Iterable[exprs.Expr],
+         columns: list[catalog.Column],
+         sql_elements: exprs.SqlElementCache,
+         cell_md_col_refs: list[exprs.ColumnRef] | None = None,
+         set_pk: bool = False,
      ):
-         """
-         If row_builder contains references to unstored iter columns, expands the select list to include their
-         SQL-materializable subexpressions.
-
-         Args:
-             select_list: output of the query
-             set_pk: if True, sets the primary for each DataRow
-         """
          # create Select stmt
          self.sql_elements = sql_elements
          self.tbl = tbl
-         assert all(not isinstance(e, exprs.Literal) for e in select_list)  # we're never asked to materialize literals
+         self.columns = columns
+         if cell_md_col_refs is not None:
+             assert all(ref.col.stores_cellmd for ref in cell_md_col_refs)
+             self.cell_md_refs = [
+                 exprs.ColumnPropertyRef(ref, exprs.ColumnPropertyRef.Property.CELLMD) for ref in cell_md_col_refs
+             ]
+         else:
+             self.cell_md_refs = []
          self.select_list = exprs.ExprSet(select_list)
-         # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
+         # unstored iter columns: we also need to retrieve whatever is needed to materialize the
+         # iter args and stored outputs
          for iter_arg in row_builder.unstored_iter_args.values():
              sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
-             for e in sql_subexprs:
-                 self.select_list.add(e)
+             self.select_list.update(sql_subexprs)
+         # We query for unstored outputs only if we're not loading a view; when we're loading a view, we are populating
+         # those columns, so we need to keep them out of the select list. This isn't a problem, because view loads never
+         # need to call set_pos().
+         # TODO: This is necessary because create_view_load_plan passes stored output columns to `RowBuilder` via the
+         # `columns` parameter (even though they don't appear in `output_exprs`). This causes them to be recorded as
+         # expressions in `RowBuilder`, which creates a conflict if we add them here. If `RowBuilder` is restructured
+         # to keep them out of `unique_exprs`, then this conditional can be removed.
+         if not row_builder.for_view_load:
+             for outputs in row_builder.unstored_iter_outputs.values():
+                 self.select_list.update(outputs)
          super().__init__(row_builder, self.select_list, [], None)  # we materialize self.select_list

          if tbl is not None:
@@ -115,9 +153,13 @@ class SqlNode(ExecNode):
          if set_pk:
              # we also need to retrieve the pk columns
              assert tbl is not None
-             self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
+             self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
+             assert self.num_pk_cols > 1

          # additional state
+         self.cellmd_item_idxs = exprs.ExprDict()
+         self.column_item_idxs = {}
+         self.column_cellmd_item_idxs = {}
          self.result_cursor = None
          # the filter is provided by the subclass
          self.py_filter = None
@@ -128,14 +170,38 @@
          self.where_clause_element = None
          self.order_by_clause = []

-     def _create_stmt(self) -> sql.Select:
-         """Create Select from local state"""
+         if self.tbl is not None:
+             tv = self.tbl.tbl_version._tbl_version
+             if tv is not None:
+                 assert tv.is_validated

-         assert self.sql_elements.contains_all(self.select_list)
-         sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
+     def _pk_col_items(self) -> list[sql.Column]:
          if self.set_pk:
+             # we need to retrieve the pk columns
              assert self.tbl is not None
-             sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
+             assert self.tbl.tbl_version.get().is_validated
+             return self.tbl.tbl_version.get().store_tbl.pk_columns()
+         return []
+
+     def _init_exec_state(self) -> None:
+         assert self.sql_elements.contains_all(self.select_list)
+         self.sql_select_list_exprs = exprs.ExprSet(self.select_list)
+         self.cellmd_item_idxs = exprs.ExprDict((ref, self.sql_select_list_exprs.add(ref)) for ref in self.cell_md_refs)
+         column_refs = [exprs.ColumnRef(col) for col in self.columns]
+         self.column_item_idxs = {col_ref.col: self.sql_select_list_exprs.add(col_ref) for col_ref in column_refs}
+         column_cellmd_refs = [
+             exprs.ColumnPropertyRef(col_ref, exprs.ColumnPropertyRef.Property.CELLMD)
+             for col_ref in column_refs
+             if col_ref.col.stores_cellmd
+         ]
+         self.column_cellmd_item_idxs = {
+             cellmd_ref.col_ref.col: self.sql_select_list_exprs.add(cellmd_ref) for cellmd_ref in column_cellmd_refs
+         }
+
+     def _create_stmt(self) -> sql.Select:
+         """Create Select from local state"""
+         self._init_exec_state()
+         sql_select_list = [self.sql_elements.get(e) for e in self.sql_select_list_exprs] + self._pk_col_items()
          stmt = sql.select(*sql_select_list)

          where_clause_element = (
@@ -161,9 +227,10 @@
      def _ordering_tbl_ids(self) -> set[UUID]:
          return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)

-     def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
+     def to_cte(self, keep_pk: bool = False) -> tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]] | None:
          """
-         Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
+         Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
+         keep_pk: if True, the PK columns are included in the CTE Select statement

          Returns:
              (CTE, dict from Expr to output column)
@@ -171,11 +238,11 @@
          if self.py_filter is not None:
              # the filter needs to run in Python
              return None
-         self.set_pk = False  # we don't need the PK if we use this SqlNode as a CTE
          if self.cte is None:
+             if not keep_pk:
+                 self.set_pk = False  # we don't need the PK if we use this SqlNode as a CTE
              self.cte = self._create_stmt().cte()
-         assert len(self.cte.c) == len(self.select_list)
-         return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
+         return self.cte, exprs.ExprDict(zip(list(self.select_list) + self.cell_md_refs, self.cte.c))  # skip pk cols

      @classmethod
      def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
@@ -186,8 +253,11 @@

      @classmethod
      def create_from_clause(
-         cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[set[UUID]] = None,
-         exact_version_only: Optional[set[UUID]] = None
+         cls,
+         tbl: catalog.TableVersionPath,
+         stmt: sql.Select,
+         refd_tbl_ids: set[UUID] | None = None,
+         exact_version_only: set[UUID] | None = None,
      ) -> sql.Select:
          """Add From clause to stmt for tables/views referenced by materialized_exprs
          Args:
@@ -205,31 +275,35 @@
              exact_version_only = set()
          candidates = tbl.get_tbl_versions()
          assert len(candidates) > 0
-         joined_tbls: list[catalog.TableVersion] = [candidates[0]]
-         for tbl in candidates[1:]:
-             if tbl.id in refd_tbl_ids:
-                 joined_tbls.append(tbl)
+         joined_tbls: list[catalog.TableVersionHandle] = [candidates[0]]
+         for t in candidates[1:]:
+             if t.id in refd_tbl_ids:
+                 joined_tbls.append(t)

          first = True
-         prev_tbl: catalog.TableVersion
-         for tbl in joined_tbls[::-1]:
+         prev_tv: catalog.TableVersion | None = None
+         for t in joined_tbls[::-1]:
+             tv = t.get()
+             # _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
              if first:
-                 stmt = stmt.select_from(tbl.store_tbl.sa_tbl)
+                 stmt = stmt.select_from(tv.store_tbl.sa_tbl)
                  first = False
              else:
-                 # join tbl to prev_tbl on prev_tbl's rowid cols
-                 prev_tbl_rowid_cols = prev_tbl.store_tbl.rowid_columns()
-                 tbl_rowid_cols = tbl.store_tbl.rowid_columns()
-                 rowid_clauses = \
-                     [c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[:len(prev_tbl_rowid_cols)])]
-                 stmt = stmt.join(tbl.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
-             if tbl.id in exact_version_only:
-                 stmt = stmt.where(tbl.store_tbl.v_min_col == tbl.version)
+                 # join tv to prev_tv on prev_tv's rowid cols
+                 prev_tbl_rowid_cols = prev_tv.store_tbl.rowid_columns()
+                 tbl_rowid_cols = tv.store_tbl.rowid_columns()
+                 rowid_clauses = [
+                     c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
+                 ]
+                 stmt = stmt.join(tv.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
+
+             if t.id in exact_version_only:
+                 stmt = stmt.where(tv.store_tbl.v_min_col == tv.version)
              else:
-                 stmt = stmt \
-                     .where(tbl.store_tbl.v_min_col <= tbl.version) \
-                     .where(tbl.store_tbl.v_max_col > tbl.version)
-             prev_tbl = tbl
+                 stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_min <= tv.version)
+                 stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_max > tv.version)
+             prev_tv = tv
+
          return stmt

      def set_where(self, where_clause: exprs.Expr) -> None:
@@ -255,18 +329,18 @@
          self.limit = limit

      def _log_explain(self, stmt: sql.Select) -> None:
+         conn = Env.get().conn
          try:
              # don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
              stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
-             explain_result = self.ctx.conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
+             explain_result = conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
              explain_str = '\n'.join([str(row) for row in explain_result])
              _logger.debug(f'SqlScanNode explain:\n{explain_str}')
          except Exception as e:
-             _logger.warning(f'EXPLAIN failed')
+             _logger.warning(f'EXPLAIN failed with error: {e}')

-     def __iter__(self) -> Iterator[DataRowBatch]:
+     async def __aiter__(self) -> AsyncIterator[DataRowBatch]:
          # run the query; do this here rather than in _open(), exceptions are only expected during iteration
-         assert self.ctx.conn is not None
          with warnings.catch_warnings(record=True) as w:
              stmt = self._create_stmt()
              try:
@@ -274,35 +348,65 @@
                  stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
                  _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
              except Exception:
-                 pass
+                 # log something if we can't log the compiled stmt
+                 _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
              self._log_explain(stmt)

-             result_cursor = self.ctx.conn.execute(stmt)
-             for warning in w:
+             conn = Env.get().conn
+             result_cursor = conn.execute(stmt)
+             for _ in w:
                  pass

-         tbl_version = self.tbl.tbl_version if self.tbl is not None else None
-         output_batch = DataRowBatch(tbl_version, self.row_builder)
-         output_row: Optional[exprs.DataRow] = None
+         output_batch = DataRowBatch(self.row_builder)
+         output_row: exprs.DataRow | None = None
          num_rows_returned = 0
+         is_using_cockroachdb = Env.get().is_using_cockroachdb
+         tzinfo = Env.get().default_time_zone

          for sql_row in result_cursor:
              output_row = output_batch.add_row(output_row)

              # populate output_row
+
              if self.num_pk_cols > 0:
-                 output_row.set_pk(tuple(sql_row[-self.num_pk_cols:]))
+                 output_row.set_pk(tuple(sql_row[-self.num_pk_cols :]))
+
+             # column copies
+             for col, item_idx in self.column_item_idxs.items():
+                 output_row.cell_vals[col.id] = sql_row[item_idx]
+             for col, item_idx in self.column_cellmd_item_idxs.items():
+                 cell_md_dict = sql_row[item_idx]
+                 output_row.cell_md[col.id] = exprs.CellMd(**cell_md_dict) if cell_md_dict is not None else None
+
+             # populate DataRow.slot_cellmd, where requested
+             for cellmd_ref, item_idx in self.cellmd_item_idxs.items():
+                 cell_md_dict = sql_row[item_idx]
+                 output_row.slot_md[cellmd_ref.col_ref.slot_idx] = (
+                     exprs.CellMd.from_dict(cell_md_dict) if cell_md_dict is not None else None
+                 )
+
              # copy the output of the SQL query into the output row
              for i, e in enumerate(self.select_list):
                  slot_idx = e.slot_idx
-                 # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
                  if isinstance(sql_row[i], Decimal):
+                     # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
                      if e.col_type.is_int_type():
                          output_row[slot_idx] = int(sql_row[i])
                      elif e.col_type.is_float_type():
                          output_row[slot_idx] = float(sql_row[i])
                      else:
                          raise RuntimeError(f'Unexpected Decimal value for {e}')
+                 elif is_using_cockroachdb and isinstance(sql_row[i], datetime.datetime):
+                     # Ensure that the datetime is timezone-aware and in the session time zone
+                     # cockroachDB returns timestamps in the session time zone, with numeric offset,
+                     # convert to the session time zone with the requested tzinfo for DST handling
+                     if e.col_type.is_timestamp_type():
+                         if isinstance(sql_row[i].tzinfo, datetime.timezone):
+                             output_row[slot_idx] = sql_row[i].astimezone(tz=tzinfo)
+                         else:
+                             output_row[slot_idx] = sql_row[i]
+                     else:
+                         raise RuntimeError(f'Unexpected datetime value for {e}')
                  else:
                      output_row[slot_idx] = sql_row[i]

@@ -324,7 +428,7 @@
              if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
                  _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
                  yield output_batch
-                 output_batch = DataRowBatch(tbl_version, self.row_builder)
+                 output_batch = DataRowBatch(self.row_builder)

          if len(output_batch) > 0:
              _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
@@ -340,22 +444,35 @@ class SqlScanNode(SqlNode):
      Materializes data from the store via a Select stmt.

      Supports filtering and ordering.
+
+     Args:
+         select_list: output of the query
+         set_pk: if True, sets the primary for each DataRow
+         exact_version_only: tables for which we only want to see rows created at the current version
      """
-     exact_version_only: list[catalog.TableVersion]
+
+     exact_version_only: list[catalog.TableVersionHandle]

      def __init__(
-         self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
+         self,
+         tbl: catalog.TableVersionPath,
+         row_builder: exprs.RowBuilder,
          select_list: Iterable[exprs.Expr],
-         set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
+         columns: list[catalog.Column],
+         cell_md_col_refs: list[exprs.ColumnRef] | None = None,
+         set_pk: bool = False,
+         exact_version_only: list[catalog.TableVersionHandle] | None = None,
      ):
-         """
-         Args:
-             select_list: output of the query
-             set_pk: if True, sets the primary for each DataRow
-             exact_version_only: tables for which we only want to see rows created at the current version
-         """
          sql_elements = exprs.SqlElementCache()
-         super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
+         super().__init__(
+             tbl,
+             row_builder,
+             select_list,
+             columns=columns,
+             sql_elements=sql_elements,
+             set_pk=set_pk,
+             cell_md_col_refs=cell_md_col_refs,
+         )
          # create Select stmt
          if exact_version_only is None:
              exact_version_only = []
@@ -367,27 +484,41 @@ class SqlScanNode(SqlNode):
          where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
          refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
          stmt = self.create_from_clause(
-             self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
+             self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only}
+         )
          return stmt


  class SqlLookupNode(SqlNode):
      """
      Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
+
+     Args:
+         select_list: output of the query
+         sa_key_cols: list of key columns in the store table
+         key_vals: list of key values to look up
      """

      def __init__(
-         self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
-         select_list: Iterable[exprs.Expr], sa_key_cols: list[sql.Column], key_vals: list[tuple],
+         self,
+         tbl: catalog.TableVersionPath,
+         row_builder: exprs.RowBuilder,
+         select_list: Iterable[exprs.Expr],
+         columns: list[catalog.Column],
+         sa_key_cols: list[sql.Column],
+         key_vals: list[tuple],
+         cell_md_col_refs: list[exprs.ColumnRef] | None = None,
      ):
-         """
-         Args:
-             select_list: output of the query
-             sa_key_cols: list of key columns in the store table
-             key_vals: list of key values to look up
-         """
          sql_elements = exprs.SqlElementCache()
-         super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
+         super().__init__(
+             tbl,
+             row_builder,
+             select_list,
+             columns=columns,
+             sql_elements=sql_elements,
+             set_pk=True,
+             cell_md_col_refs=cell_md_col_refs,
+         )
          # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
          self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)

@@ -401,30 +532,33 @@ class SqlLookupNode(SqlNode):
  class SqlAggregationNode(SqlNode):
      """
      Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
+
+     Args:
+         select_list: can contain calls to AggregateFunctions
+         group_by_items: list of expressions to group by
+         limit: max number of rows to return: None = no limit
      """

-     group_by_items: Optional[list[exprs.Expr]]
+     group_by_items: list[exprs.Expr] | None
+     input_cte: sql.CTE | None

      def __init__(
-         self, row_builder: exprs.RowBuilder,
+         self,
+         row_builder: exprs.RowBuilder,
          input: SqlNode,
          select_list: Iterable[exprs.Expr],
-         group_by_items: Optional[list[exprs.Expr]] = None,
-         limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
+         group_by_items: list[exprs.Expr] | None = None,
+         limit: int | None = None,
+         exact_version_only: list[catalog.TableVersion] | None = None,
      ):
-         """
-         Args:
-             select_list: can contain calls to AggregateFunctions
-             group_by_items: list of expressions to group by
-             limit: max number of rows to return: None = no limit
-         """
-         _, input_col_map = input.to_cte()
+         assert len(input.cell_md_refs) == 0  # there's no aggregation over json or arrays in SQL
+         self.input_cte, input_col_map = input.to_cte()
          sql_elements = exprs.SqlElementCache(input_col_map)
-         super().__init__(None, row_builder, select_list, sql_elements)
+         super().__init__(None, row_builder, select_list, columns=[], sql_elements=sql_elements)
          self.group_by_items = group_by_items

      def _create_stmt(self) -> sql.Select:
-         stmt = super()._create_stmt()
+         stmt = super()._create_stmt().select_from(self.input_cte)
          if self.group_by_items is not None:
              sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
              assert all(e is not None for e in sql_group_by_items)
@@ -436,12 +570,16 @@ class SqlJoinNode(SqlNode):
      """
      Materializes data from the store via a Select ... From ... that contains joins
      """
+
      input_ctes: list[sql.CTE]
      join_clauses: list['pixeltable.plan.JoinClause']

      def __init__(
-         self, row_builder: exprs.RowBuilder,
-         inputs: Sequence[SqlNode], join_clauses: list['pixeltable.plan.JoinClause'], select_list: Iterable[exprs.Expr]
+         self,
+         row_builder: exprs.RowBuilder,
+         inputs: Sequence[SqlNode],
+         join_clauses: list['pixeltable.plan.JoinClause'],
+         select_list: Iterable[exprs.Expr],
      ):
          assert len(inputs) > 1
          assert len(inputs) == len(join_clauses) + 1
@@ -452,20 +590,181 @@ class SqlJoinNode(SqlNode):
              input_cte, input_col_map = input_node.to_cte()
              self.input_ctes.append(input_cte)
              sql_elements.extend(input_col_map)
-         super().__init__(None, row_builder, select_list, sql_elements)
+         cell_md_col_refs = [cell_md_ref.col_ref for input in inputs for cell_md_ref in input.cell_md_refs]
+         super().__init__(
+             None, row_builder, select_list, columns=[], sql_elements=sql_elements, cell_md_col_refs=cell_md_col_refs
+         )

      def _create_stmt(self) -> sql.Select:
          from pixeltable import plan
+
          stmt = super()._create_stmt()
          stmt = stmt.select_from(self.input_ctes[0])
          for i in range(len(self.join_clauses)):
              join_clause = self.join_clauses[i]
              on_clause = (
-                 self.sql_elements.get(join_clause.join_predicate) if join_clause.join_type != plan.JoinType.CROSS
+                 self.sql_elements.get(join_clause.join_predicate)
+                 if join_clause.join_type != plan.JoinType.CROSS
                  else sql.sql.expression.literal(True)
              )
-             is_outer = join_clause.join_type == plan.JoinType.LEFT or join_clause.join_type == plan.JoinType.FULL_OUTER
+             is_outer = join_clause.join_type in (plan.JoinType.LEFT, plan.JoinType.FULL_OUTER)
              stmt = stmt.join(
-                 self.input_ctes[i + 1], onclause=on_clause, isouter=is_outer,
-                 full=join_clause == plan.JoinType.FULL_OUTER)
-             return stmt
+                 self.input_ctes[i + 1],
+                 onclause=on_clause,
+                 isouter=is_outer,
+                 full=join_clause == plan.JoinType.FULL_OUTER,
+             )
+         return stmt
+
+
+ class SqlSampleNode(SqlNode):
+     """
+     Returns rows sampled from the input node.
+
+     Args:
+         input: SqlNode to sample from
+         select_list: can contain calls to AggregateFunctions
+         sample_clause: specifies the sampling method
+         stratify_exprs: Analyzer processed list of expressions to stratify by.
+     """
+
+     input_cte: sql.CTE | None
+     pk_count: int
+     stratify_exprs: list[exprs.Expr] | None
+     sample_clause: 'SampleClause'
+
+     def __init__(
+         self,
+         row_builder: exprs.RowBuilder,
+         input: SqlNode,
+         select_list: Iterable[exprs.Expr],
+         sample_clause: 'SampleClause',
+         stratify_exprs: list[exprs.Expr],
+     ):
+         assert isinstance(input, SqlNode)
+         self.input_cte, input_col_map = input.to_cte(keep_pk=True)
+         self.pk_count = input.num_pk_cols
+         assert self.pk_count > 1
+         sql_elements = exprs.SqlElementCache(input_col_map)
+         assert sql_elements.contains_all(stratify_exprs)
+         cell_md_col_refs = [cell_md_ref.col_ref for cell_md_ref in input.cell_md_refs]
+         super().__init__(
+             input.tbl,
+             row_builder,
+             select_list,
+             columns=[],
+             sql_elements=sql_elements,
+             cell_md_col_refs=cell_md_col_refs,
+             set_pk=True,
+         )
+         self.stratify_exprs = stratify_exprs
+         self.sample_clause = sample_clause
+
+     @classmethod
+     def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
+         """Construct expression which is the ordering key for rows to be sampled
+         General SQL form is:
+         - MD5(<seed::text> [ + '___' + <rowid_col_val>::text]+
+         """
+         sql_expr: sql.ColumnElement = seed.cast(sql.String)
+         for e in sql_cols:
+             # Quotes are required below to guarantee that the string is properly presented in SQL
+             sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + e.cast(sql.String)
+         sql_expr = sql.func.md5(sql_expr)
+         return sql_expr
+
+     def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
+         """Create an expression for randomly ordering rows with a given seed"""
+         rowid_cols = [*cte.c[-self.pk_count : -1]]  # exclude the version column
+         assert len(rowid_cols) > 0
+         # If seed is not set in the sample clause, use the random seed given by the execution context
+         seed = self.sample_clause.seed if self.sample_clause.seed is not None else self.ctx.random_seed
+         return self.key_sql_expr(sql.literal_column(str(seed)), rowid_cols)
+
+     def _create_stmt(self) -> sql.Select:
+         from pixeltable.plan import SampleClause
+
+         self._init_exec_state()
+
+         if self.sample_clause.fraction is not None:
+             if len(self.stratify_exprs) == 0:
+                 # If non-stratified sampling, construct a where clause, order_by, and limit clauses
+                 s_key = self._create_key_sql(self.input_cte)
+
+                 # Construct a suitable where clause
+                 fraction_md5 = SampleClause.fraction_to_md5_hex(self.sample_clause.fraction)
+                 order_by = self._create_key_sql(self.input_cte)
+                 return sql.select(*self.input_cte.c).where(s_key < fraction_md5).order_by(order_by)
+
+             return self._create_stmt_stratified_fraction(self.sample_clause.fraction)
+         else:
+             if len(self.stratify_exprs) == 0:
+                 # No stratification, just return n samples from the input CTE
+                 order_by = self._create_key_sql(self.input_cte)
+                 return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)
+
+             return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)
+
+     def _create_stmt_stratified_n(self, n: int | None, n_per_stratum: int | None) -> sql.Select:
+         """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""
+
+         sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
+         order_by = self._create_key_sql(self.input_cte)
+
+         # Create a list of all columns plus the rank
+         # Get all columns from the input CTE dynamically
+         select_columns = [*self.input_cte.c]
+         select_columns.append(
+             sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
+         )
+         row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
+
+         final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
+         if n_per_stratum is not None:
+             return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
+         else:
+             secondary_order = self._create_key_sql(row_rank_cte)
+             return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)
+
+     def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
+         """Create a Select stmt that returns a fraction of the rows per strata"""
+
+         # Build the strata count CTE
+         # Produces a table of the form:
+         #     (*stratify_exprs, s_s_size)
+         # where s_s_size is the number of samples to take from each stratum
+         sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
+         per_strata_count_cte = (
+             sql.select(
+                 *sql_strata_exprs,
+                 sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
+             )
+             .select_from(self.input_cte)
+             .group_by(*sql_strata_exprs)
+             .cte('per_strata_count_cte')
+         )
+
+         # Build a CTE that ranks the rows within each stratum
+         # Include all columns from the input CTE dynamically
+         order_by = self._create_key_sql(self.input_cte)
+         select_columns = [*self.input_cte.c]
+         select_columns.append(
+             sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
+         )
+         row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
+
+         # Build the join criterion dynamically to accommodate any number of stratify_by expressions
+         join_c = sql.true()
+         for col in per_strata_count_cte.c[:-1]:
+             join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)
+
+         # Join with per_strata_count_cte to limit returns to the requested fraction of rows
+         final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
+         stmt = (
+             sql.select(*final_columns)
+             .select_from(row_rank_cte)
+             .join(per_strata_count_cte, join_c)
+             .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
+         )
+
+         return stmt
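
Editor's note on the new SqlSampleNode above: it orders and filters rows by a deterministic key of the form MD5(seed || '___' || rowid ...), comparing it against a threshold derived from the requested fraction (SampleClause.fraction_to_md5_hex). The following standalone sketch is not Pixeltable code; it only illustrates the idea in plain Python, and fraction_threshold is an assumed stand-in for the behavior of fraction_to_md5_hex.

# Sketch only: deterministic fraction-based sampling in the spirit of
# SqlSampleNode.key_sql_expr above; names below are illustrative, not the package API.
import hashlib

def sample_key(seed: int, rowid_vals: list[int]) -> str:
    # mirrors the SQL key: MD5(seed::text || '___' || rowid::text || ...)
    return hashlib.md5('___'.join([str(seed), *map(str, rowid_vals)]).encode()).hexdigest()

def fraction_threshold(fraction: float) -> str:
    # assumed behavior: the hex value that `fraction` of uniform 128-bit MD5 keys fall below
    return f'{int(fraction * (1 << 128)):032x}'

rows = [(rowid,) for rowid in range(1000)]
kept = [r for r in rows if sample_key(42, list(r)) < fraction_threshold(0.3)]
print(len(kept))  # roughly 300 rows, and identical on every run for seed=42

Because the key depends only on the seed and the rowid, the same rows are selected every time for a given seed, which is what makes the SQL-side sampling repeatable.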