pixeltable 0.4.0rc3__py3-none-any.whl → 0.4.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the registry's advisory for this release for more details.

Files changed (202)
  1. pixeltable/__init__.py +23 -5
  2. pixeltable/_version.py +1 -0
  3. pixeltable/catalog/__init__.py +5 -3
  4. pixeltable/catalog/catalog.py +1318 -404
  5. pixeltable/catalog/column.py +186 -115
  6. pixeltable/catalog/dir.py +1 -2
  7. pixeltable/catalog/globals.py +11 -43
  8. pixeltable/catalog/insertable_table.py +167 -79
  9. pixeltable/catalog/path.py +61 -23
  10. pixeltable/catalog/schema_object.py +9 -10
  11. pixeltable/catalog/table.py +626 -308
  12. pixeltable/catalog/table_metadata.py +101 -0
  13. pixeltable/catalog/table_version.py +713 -569
  14. pixeltable/catalog/table_version_handle.py +37 -6
  15. pixeltable/catalog/table_version_path.py +42 -29
  16. pixeltable/catalog/tbl_ops.py +50 -0
  17. pixeltable/catalog/update_status.py +191 -0
  18. pixeltable/catalog/view.py +108 -94
  19. pixeltable/config.py +128 -22
  20. pixeltable/dataframe.py +188 -100
  21. pixeltable/env.py +407 -136
  22. pixeltable/exceptions.py +6 -0
  23. pixeltable/exec/__init__.py +3 -0
  24. pixeltable/exec/aggregation_node.py +7 -8
  25. pixeltable/exec/cache_prefetch_node.py +83 -110
  26. pixeltable/exec/cell_materialization_node.py +231 -0
  27. pixeltable/exec/cell_reconstruction_node.py +135 -0
  28. pixeltable/exec/component_iteration_node.py +4 -3
  29. pixeltable/exec/data_row_batch.py +8 -65
  30. pixeltable/exec/exec_context.py +16 -4
  31. pixeltable/exec/exec_node.py +13 -36
  32. pixeltable/exec/expr_eval/evaluators.py +7 -6
  33. pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
  34. pixeltable/exec/expr_eval/globals.py +8 -5
  35. pixeltable/exec/expr_eval/row_buffer.py +1 -2
  36. pixeltable/exec/expr_eval/schedulers.py +190 -30
  37. pixeltable/exec/globals.py +32 -0
  38. pixeltable/exec/in_memory_data_node.py +18 -18
  39. pixeltable/exec/object_store_save_node.py +293 -0
  40. pixeltable/exec/row_update_node.py +16 -9
  41. pixeltable/exec/sql_node.py +206 -101
  42. pixeltable/exprs/__init__.py +1 -1
  43. pixeltable/exprs/arithmetic_expr.py +27 -22
  44. pixeltable/exprs/array_slice.py +3 -3
  45. pixeltable/exprs/column_property_ref.py +34 -30
  46. pixeltable/exprs/column_ref.py +92 -96
  47. pixeltable/exprs/comparison.py +5 -5
  48. pixeltable/exprs/compound_predicate.py +5 -4
  49. pixeltable/exprs/data_row.py +152 -55
  50. pixeltable/exprs/expr.py +62 -43
  51. pixeltable/exprs/expr_dict.py +3 -3
  52. pixeltable/exprs/expr_set.py +17 -10
  53. pixeltable/exprs/function_call.py +75 -37
  54. pixeltable/exprs/globals.py +1 -2
  55. pixeltable/exprs/in_predicate.py +4 -4
  56. pixeltable/exprs/inline_expr.py +10 -27
  57. pixeltable/exprs/is_null.py +1 -3
  58. pixeltable/exprs/json_mapper.py +8 -8
  59. pixeltable/exprs/json_path.py +56 -22
  60. pixeltable/exprs/literal.py +5 -5
  61. pixeltable/exprs/method_ref.py +2 -2
  62. pixeltable/exprs/object_ref.py +2 -2
  63. pixeltable/exprs/row_builder.py +127 -53
  64. pixeltable/exprs/rowid_ref.py +8 -12
  65. pixeltable/exprs/similarity_expr.py +50 -25
  66. pixeltable/exprs/sql_element_cache.py +4 -4
  67. pixeltable/exprs/string_op.py +5 -5
  68. pixeltable/exprs/type_cast.py +3 -5
  69. pixeltable/func/__init__.py +1 -0
  70. pixeltable/func/aggregate_function.py +8 -8
  71. pixeltable/func/callable_function.py +9 -9
  72. pixeltable/func/expr_template_function.py +10 -10
  73. pixeltable/func/function.py +18 -20
  74. pixeltable/func/function_registry.py +6 -7
  75. pixeltable/func/globals.py +2 -3
  76. pixeltable/func/mcp.py +74 -0
  77. pixeltable/func/query_template_function.py +20 -18
  78. pixeltable/func/signature.py +43 -16
  79. pixeltable/func/tools.py +23 -13
  80. pixeltable/func/udf.py +18 -20
  81. pixeltable/functions/__init__.py +6 -0
  82. pixeltable/functions/anthropic.py +93 -33
  83. pixeltable/functions/audio.py +114 -10
  84. pixeltable/functions/bedrock.py +13 -6
  85. pixeltable/functions/date.py +1 -1
  86. pixeltable/functions/deepseek.py +20 -9
  87. pixeltable/functions/fireworks.py +2 -2
  88. pixeltable/functions/gemini.py +28 -11
  89. pixeltable/functions/globals.py +13 -13
  90. pixeltable/functions/groq.py +108 -0
  91. pixeltable/functions/huggingface.py +1046 -23
  92. pixeltable/functions/image.py +9 -18
  93. pixeltable/functions/llama_cpp.py +23 -8
  94. pixeltable/functions/math.py +3 -4
  95. pixeltable/functions/mistralai.py +4 -15
  96. pixeltable/functions/ollama.py +16 -9
  97. pixeltable/functions/openai.py +104 -82
  98. pixeltable/functions/openrouter.py +143 -0
  99. pixeltable/functions/replicate.py +2 -2
  100. pixeltable/functions/reve.py +250 -0
  101. pixeltable/functions/string.py +21 -28
  102. pixeltable/functions/timestamp.py +13 -14
  103. pixeltable/functions/together.py +4 -6
  104. pixeltable/functions/twelvelabs.py +92 -0
  105. pixeltable/functions/util.py +6 -1
  106. pixeltable/functions/video.py +1388 -106
  107. pixeltable/functions/vision.py +7 -7
  108. pixeltable/functions/whisper.py +15 -7
  109. pixeltable/functions/whisperx.py +179 -0
  110. pixeltable/{ext/functions → functions}/yolox.py +2 -4
  111. pixeltable/globals.py +332 -105
  112. pixeltable/index/base.py +13 -22
  113. pixeltable/index/btree.py +23 -22
  114. pixeltable/index/embedding_index.py +32 -44
  115. pixeltable/io/__init__.py +4 -2
  116. pixeltable/io/datarows.py +7 -6
  117. pixeltable/io/external_store.py +49 -77
  118. pixeltable/io/fiftyone.py +11 -11
  119. pixeltable/io/globals.py +29 -28
  120. pixeltable/io/hf_datasets.py +17 -9
  121. pixeltable/io/label_studio.py +70 -66
  122. pixeltable/io/lancedb.py +3 -0
  123. pixeltable/io/pandas.py +12 -11
  124. pixeltable/io/parquet.py +13 -93
  125. pixeltable/io/table_data_conduit.py +71 -47
  126. pixeltable/io/utils.py +3 -3
  127. pixeltable/iterators/__init__.py +2 -1
  128. pixeltable/iterators/audio.py +21 -11
  129. pixeltable/iterators/document.py +116 -55
  130. pixeltable/iterators/image.py +5 -2
  131. pixeltable/iterators/video.py +293 -13
  132. pixeltable/metadata/__init__.py +4 -2
  133. pixeltable/metadata/converters/convert_18.py +2 -2
  134. pixeltable/metadata/converters/convert_19.py +2 -2
  135. pixeltable/metadata/converters/convert_20.py +2 -2
  136. pixeltable/metadata/converters/convert_21.py +2 -2
  137. pixeltable/metadata/converters/convert_22.py +2 -2
  138. pixeltable/metadata/converters/convert_24.py +2 -2
  139. pixeltable/metadata/converters/convert_25.py +2 -2
  140. pixeltable/metadata/converters/convert_26.py +2 -2
  141. pixeltable/metadata/converters/convert_29.py +4 -4
  142. pixeltable/metadata/converters/convert_34.py +2 -2
  143. pixeltable/metadata/converters/convert_36.py +2 -2
  144. pixeltable/metadata/converters/convert_37.py +15 -0
  145. pixeltable/metadata/converters/convert_38.py +39 -0
  146. pixeltable/metadata/converters/convert_39.py +124 -0
  147. pixeltable/metadata/converters/convert_40.py +73 -0
  148. pixeltable/metadata/converters/util.py +13 -12
  149. pixeltable/metadata/notes.py +4 -0
  150. pixeltable/metadata/schema.py +79 -42
  151. pixeltable/metadata/utils.py +74 -0
  152. pixeltable/mypy/__init__.py +3 -0
  153. pixeltable/mypy/mypy_plugin.py +123 -0
  154. pixeltable/plan.py +274 -223
  155. pixeltable/share/__init__.py +1 -1
  156. pixeltable/share/packager.py +259 -129
  157. pixeltable/share/protocol/__init__.py +34 -0
  158. pixeltable/share/protocol/common.py +170 -0
  159. pixeltable/share/protocol/operation_types.py +33 -0
  160. pixeltable/share/protocol/replica.py +109 -0
  161. pixeltable/share/publish.py +213 -57
  162. pixeltable/store.py +238 -175
  163. pixeltable/type_system.py +104 -63
  164. pixeltable/utils/__init__.py +2 -3
  165. pixeltable/utils/arrow.py +108 -13
  166. pixeltable/utils/av.py +298 -0
  167. pixeltable/utils/azure_store.py +305 -0
  168. pixeltable/utils/code.py +3 -3
  169. pixeltable/utils/console_output.py +4 -1
  170. pixeltable/utils/coroutine.py +6 -23
  171. pixeltable/utils/dbms.py +31 -5
  172. pixeltable/utils/description_helper.py +4 -5
  173. pixeltable/utils/documents.py +5 -6
  174. pixeltable/utils/exception_handler.py +7 -30
  175. pixeltable/utils/filecache.py +6 -6
  176. pixeltable/utils/formatter.py +4 -6
  177. pixeltable/utils/gcs_store.py +283 -0
  178. pixeltable/utils/http_server.py +2 -3
  179. pixeltable/utils/iceberg.py +1 -2
  180. pixeltable/utils/image.py +17 -0
  181. pixeltable/utils/lancedb.py +88 -0
  182. pixeltable/utils/local_store.py +316 -0
  183. pixeltable/utils/misc.py +5 -0
  184. pixeltable/utils/object_stores.py +528 -0
  185. pixeltable/utils/pydantic.py +60 -0
  186. pixeltable/utils/pytorch.py +5 -6
  187. pixeltable/utils/s3_store.py +392 -0
  188. pixeltable-0.4.20.dist-info/METADATA +587 -0
  189. pixeltable-0.4.20.dist-info/RECORD +218 -0
  190. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.20.dist-info}/WHEEL +1 -1
  191. pixeltable-0.4.20.dist-info/entry_points.txt +2 -0
  192. pixeltable/__version__.py +0 -3
  193. pixeltable/ext/__init__.py +0 -17
  194. pixeltable/ext/functions/__init__.py +0 -11
  195. pixeltable/ext/functions/whisperx.py +0 -77
  196. pixeltable/utils/media_store.py +0 -77
  197. pixeltable/utils/s3.py +0 -17
  198. pixeltable/utils/sample.py +0 -25
  199. pixeltable-0.4.0rc3.dist-info/METADATA +0 -435
  200. pixeltable-0.4.0rc3.dist-info/RECORD +0 -189
  201. pixeltable-0.4.0rc3.dist-info/entry_points.txt +0 -3
  202. {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.20.dist-info/licenses}/LICENSE +0 -0
@@ -1,7 +1,8 @@
1
+ import datetime
1
2
  import logging
2
3
  import warnings
3
4
  from decimal import Decimal
4
- from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Optional, Sequence
5
+ from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Sequence
5
6
  from uuid import UUID
6
7
 
7
8
  import sqlalchemy as sql
@@ -21,13 +22,13 @@ _logger = logging.getLogger('pixeltable')
21
22
 
22
23
  class OrderByItem(NamedTuple):
23
24
  expr: exprs.Expr
24
- asc: Optional[bool]
25
+ asc: bool | None
25
26
 
26
27
 
27
28
  OrderByClause = list[OrderByItem]
28
29
 
29
30
 
30
- def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
31
+ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> OrderByClause | None:
31
32
  """Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
32
33
  Two clauses are compatible if for each of their respective items c1[i] and c2[i]
33
34
  a) the exprs are identical and
@@ -65,49 +66,66 @@ def print_order_by_clause(clause: OrderByClause) -> str:
65
66
 
66
67
  class SqlNode(ExecNode):
67
68
  """
68
- Materializes data from the store via an SQL statement.
69
+ Materializes data from the store via a SQL statement.
69
70
  This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
70
71
  The pk columns are not included in the select list.
71
72
  If set_pk is True, they are added to the end of the result set when creating the SQL statement
72
73
  so they can always be referenced as cols[-num_pk_cols:] in the result set.
73
74
  The pk_columns consist of the rowid columns of the target table followed by the version number.
75
+
76
+ If row_builder contains references to unstored iter columns, expands the select list to include their
77
+ SQL-materializable subexpressions.
78
+
79
+ Args:
80
+ select_list: output of the query
81
+ set_pk: if True, sets the primary for each DataRow
74
82
  """
75
83
 
76
- tbl: Optional[catalog.TableVersionPath]
84
+ tbl: catalog.TableVersionPath | None
77
85
  select_list: exprs.ExprSet
86
+ columns: list[catalog.Column] # for which columns to populate DataRow.cell_vals/cell_md
87
+ cell_md_refs: list[exprs.ColumnPropertyRef] # of ColumnRefs which also need DataRow.slot_cellmd for evaluation
78
88
  set_pk: bool
79
89
  num_pk_cols: int
80
- py_filter: Optional[exprs.Expr] # a predicate that can only be run in Python
81
- py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
82
- cte: Optional[sql.CTE]
90
+ py_filter: exprs.Expr | None # a predicate that can only be run in Python
91
+ py_filter_eval_ctx: exprs.RowBuilder.EvalCtx | None
92
+ cte: sql.CTE | None
83
93
  sql_elements: exprs.SqlElementCache
84
94
 
95
+ # execution state
96
+ cellmd_item_idxs: exprs.ExprDict[int] # cellmd expr -> idx in sql select list
97
+ column_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
98
+ column_cellmd_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
99
+ result_cursor: sql.engine.CursorResult | None
100
+
85
101
  # where_clause/-_element: allow subclass to set one or the other (but not both)
86
- where_clause: Optional[exprs.Expr]
87
- where_clause_element: Optional[sql.ColumnElement]
102
+ where_clause: exprs.Expr | None
103
+ where_clause_element: sql.ColumnElement | None
88
104
 
89
105
  order_by_clause: OrderByClause
90
- limit: Optional[int]
106
+ limit: int | None
91
107
 
92
108
  def __init__(
93
109
  self,
94
- tbl: Optional[catalog.TableVersionPath],
110
+ tbl: catalog.TableVersionPath | None,
95
111
  row_builder: exprs.RowBuilder,
96
112
  select_list: Iterable[exprs.Expr],
113
+ columns: list[catalog.Column],
97
114
  sql_elements: exprs.SqlElementCache,
115
+ cell_md_col_refs: list[exprs.ColumnRef] | None = None,
98
116
  set_pk: bool = False,
99
117
  ):
100
- """
101
- If row_builder contains references to unstored iter columns, expands the select list to include their
102
- SQL-materializable subexpressions.
103
-
104
- Args:
105
- select_list: output of the query
106
- set_pk: if True, sets the primary for each DataRow
107
- """
108
118
  # create Select stmt
109
119
  self.sql_elements = sql_elements
110
120
  self.tbl = tbl
121
+ self.columns = columns
122
+ if cell_md_col_refs is not None:
123
+ assert all(ref.col.stores_cellmd for ref in cell_md_col_refs)
124
+ self.cell_md_refs = [
125
+ exprs.ColumnPropertyRef(ref, exprs.ColumnPropertyRef.Property.CELLMD) for ref in cell_md_col_refs
126
+ ]
127
+ else:
128
+ self.cell_md_refs = []
111
129
  self.select_list = exprs.ExprSet(select_list)
112
130
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
113
131
  for iter_arg in row_builder.unstored_iter_args.values():
@@ -130,6 +148,9 @@ class SqlNode(ExecNode):
130
148
  assert self.num_pk_cols > 1
131
149
 
132
150
  # additional state
151
+ self.cellmd_item_idxs = exprs.ExprDict()
152
+ self.column_item_idxs = {}
153
+ self.column_cellmd_item_idxs = {}
133
154
  self.result_cursor = None
134
155
  # the filter is provided by the subclass
135
156
  self.py_filter = None
@@ -145,10 +166,9 @@ class SqlNode(ExecNode):
145
166
  if tv is not None:
146
167
  assert tv.is_validated
147
168
 
148
- def _create_pk_cols(self) -> list[sql.Column]:
149
- """Create a list of pk columns"""
150
- # we need to retrieve the pk columns
169
+ def _pk_col_items(self) -> list[sql.Column]:
151
170
  if self.set_pk:
171
+ # we need to retrieve the pk columns
152
172
  assert self.tbl is not None
153
173
  assert self.tbl.tbl_version.get().is_validated
154
174
  return self.tbl.tbl_version.get().store_tbl.pk_columns()
@@ -158,7 +178,19 @@ class SqlNode(ExecNode):
158
178
  """Create Select from local state"""
159
179
 
160
180
  assert self.sql_elements.contains_all(self.select_list)
161
- sql_select_list = [self.sql_elements.get(e) for e in self.select_list] + self._create_pk_cols()
181
+ sql_select_list_exprs = exprs.ExprSet(self.select_list)
182
+ self.cellmd_item_idxs = exprs.ExprDict((ref, sql_select_list_exprs.add(ref)) for ref in self.cell_md_refs)
183
+ column_refs = [exprs.ColumnRef(col) for col in self.columns]
184
+ self.column_item_idxs = {col_ref.col: sql_select_list_exprs.add(col_ref) for col_ref in column_refs}
185
+ column_cellmd_refs = [
186
+ exprs.ColumnPropertyRef(col_ref, exprs.ColumnPropertyRef.Property.CELLMD)
187
+ for col_ref in column_refs
188
+ if col_ref.col.stores_cellmd
189
+ ]
190
+ self.column_cellmd_item_idxs = {
191
+ cellmd_ref.col_ref.col: sql_select_list_exprs.add(cellmd_ref) for cellmd_ref in column_cellmd_refs
192
+ }
193
+ sql_select_list = [self.sql_elements.get(e) for e in sql_select_list_exprs] + self._pk_col_items()
162
194
  stmt = sql.select(*sql_select_list)
163
195
 
164
196
  where_clause_element = (
@@ -184,7 +216,7 @@ class SqlNode(ExecNode):
184
216
  def _ordering_tbl_ids(self) -> set[UUID]:
185
217
  return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
186
218
 
187
- def to_cte(self, keep_pk: bool = False) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
219
+ def to_cte(self, keep_pk: bool = False) -> tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]] | None:
188
220
  """
189
221
  Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
190
222
  keep_pk: if True, the PK columns are included in the CTE Select statement
@@ -199,9 +231,7 @@ class SqlNode(ExecNode):
199
231
  if not keep_pk:
200
232
  self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
201
233
  self.cte = self._create_stmt().cte()
202
- pk_count = self.num_pk_cols if self.set_pk else 0
203
- assert len(self.select_list) + pk_count == len(self.cte.c)
204
- return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c)) # skip pk cols
234
+ return self.cte, exprs.ExprDict(zip(list(self.select_list) + self.cell_md_refs, self.cte.c)) # skip pk cols
205
235
 
206
236
  @classmethod
207
237
  def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
@@ -215,8 +245,8 @@ class SqlNode(ExecNode):
215
245
  cls,
216
246
  tbl: catalog.TableVersionPath,
217
247
  stmt: sql.Select,
218
- refd_tbl_ids: Optional[set[UUID]] = None,
219
- exact_version_only: Optional[set[UUID]] = None,
248
+ refd_tbl_ids: set[UUID] | None = None,
249
+ exact_version_only: set[UUID] | None = None,
220
250
  ) -> sql.Select:
221
251
  """Add From clause to stmt for tables/views referenced by materialized_exprs
222
252
  Args:
@@ -240,7 +270,7 @@ class SqlNode(ExecNode):
240
270
  joined_tbls.append(t)
241
271
 
242
272
  first = True
243
- prev_tv: Optional[catalog.TableVersion] = None
273
+ prev_tv: catalog.TableVersion | None = None
244
274
  for t in joined_tbls[::-1]:
245
275
  tv = t.get()
246
276
  # _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
@@ -308,8 +338,7 @@ class SqlNode(ExecNode):
308
338
  _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
309
339
  except Exception:
310
340
  # log something if we can't log the compiled stmt
311
- stmt_str = repr(stmt)
312
- _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt_str}')
341
+ _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
313
342
  self._log_explain(stmt)
314
343
 
315
344
  conn = Env.get().conn
@@ -317,28 +346,56 @@ class SqlNode(ExecNode):
317
346
  for _ in w:
318
347
  pass
319
348
 
320
- tbl_version = self.tbl.tbl_version if self.tbl is not None else None
321
- output_batch = DataRowBatch(tbl_version, self.row_builder)
322
- output_row: Optional[exprs.DataRow] = None
349
+ output_batch = DataRowBatch(self.row_builder)
350
+ output_row: exprs.DataRow | None = None
323
351
  num_rows_returned = 0
352
+ is_using_cockroachdb = Env.get().is_using_cockroachdb
353
+ tzinfo = Env.get().default_time_zone
324
354
 
325
355
  for sql_row in result_cursor:
326
356
  output_row = output_batch.add_row(output_row)
327
357
 
328
358
  # populate output_row
359
+
329
360
  if self.num_pk_cols > 0:
330
361
  output_row.set_pk(tuple(sql_row[-self.num_pk_cols :]))
362
+
363
+ # column copies
364
+ for col, item_idx in self.column_item_idxs.items():
365
+ output_row.cell_vals[col.id] = sql_row[item_idx]
366
+ for col, item_idx in self.column_cellmd_item_idxs.items():
367
+ cell_md_dict = sql_row[item_idx]
368
+ output_row.cell_md[col.id] = exprs.CellMd(**cell_md_dict) if cell_md_dict is not None else None
369
+
370
+ # populate DataRow.slot_cellmd, where requested
371
+ for cellmd_ref, item_idx in self.cellmd_item_idxs.items():
372
+ cell_md_dict = sql_row[item_idx]
373
+ output_row.slot_md[cellmd_ref.col_ref.slot_idx] = (
374
+ exprs.CellMd.from_dict(cell_md_dict) if cell_md_dict is not None else None
375
+ )
376
+
331
377
  # copy the output of the SQL query into the output row
332
378
  for i, e in enumerate(self.select_list):
333
379
  slot_idx = e.slot_idx
334
- # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
335
380
  if isinstance(sql_row[i], Decimal):
381
+ # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
336
382
  if e.col_type.is_int_type():
337
383
  output_row[slot_idx] = int(sql_row[i])
338
384
  elif e.col_type.is_float_type():
339
385
  output_row[slot_idx] = float(sql_row[i])
340
386
  else:
341
387
  raise RuntimeError(f'Unexpected Decimal value for {e}')
388
+ elif is_using_cockroachdb and isinstance(sql_row[i], datetime.datetime):
389
+ # Ensure that the datetime is timezone-aware and in the session time zone
390
+ # cockroachDB returns timestamps in the session time zone, with numeric offset,
391
+ # convert to the session time zone with the requested tzinfo for DST handling
392
+ if e.col_type.is_timestamp_type():
393
+ if isinstance(sql_row[i].tzinfo, datetime.timezone):
394
+ output_row[slot_idx] = sql_row[i].astimezone(tz=tzinfo)
395
+ else:
396
+ output_row[slot_idx] = sql_row[i]
397
+ else:
398
+ raise RuntimeError(f'Unexpected datetime value for {e}')
342
399
  else:
343
400
  output_row[slot_idx] = sql_row[i]
344
401
 
@@ -360,7 +417,7 @@ class SqlNode(ExecNode):
360
417
  if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
361
418
  _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
362
419
  yield output_batch
363
- output_batch = DataRowBatch(tbl_version, self.row_builder)
420
+ output_batch = DataRowBatch(self.row_builder)
364
421
 
365
422
  if len(output_batch) > 0:
366
423
  _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
@@ -376,6 +433,11 @@ class SqlScanNode(SqlNode):
376
433
  Materializes data from the store via a Select stmt.
377
434
 
378
435
  Supports filtering and ordering.
436
+
437
+ Args:
438
+ select_list: output of the query
439
+ set_pk: if True, sets the primary for each DataRow
440
+ exact_version_only: tables for which we only want to see rows created at the current version
379
441
  """
380
442
 
381
443
  exact_version_only: list[catalog.TableVersionHandle]
@@ -385,17 +447,21 @@ class SqlScanNode(SqlNode):
385
447
  tbl: catalog.TableVersionPath,
386
448
  row_builder: exprs.RowBuilder,
387
449
  select_list: Iterable[exprs.Expr],
450
+ columns: list[catalog.Column],
451
+ cell_md_col_refs: list[exprs.ColumnRef] | None = None,
388
452
  set_pk: bool = False,
389
- exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
453
+ exact_version_only: list[catalog.TableVersionHandle] | None = None,
390
454
  ):
391
- """
392
- Args:
393
- select_list: output of the query
394
- set_pk: if True, sets the primary for each DataRow
395
- exact_version_only: tables for which we only want to see rows created at the current version
396
- """
397
455
  sql_elements = exprs.SqlElementCache()
398
- super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
456
+ super().__init__(
457
+ tbl,
458
+ row_builder,
459
+ select_list,
460
+ columns=columns,
461
+ sql_elements=sql_elements,
462
+ set_pk=set_pk,
463
+ cell_md_col_refs=cell_md_col_refs,
464
+ )
399
465
  # create Select stmt
400
466
  if exact_version_only is None:
401
467
  exact_version_only = []
@@ -415,6 +481,11 @@ class SqlScanNode(SqlNode):
415
481
  class SqlLookupNode(SqlNode):
416
482
  """
417
483
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
484
+
485
+ Args:
486
+ select_list: output of the query
487
+ sa_key_cols: list of key columns in the store table
488
+ key_vals: list of key values to look up
418
489
  """
419
490
 
420
491
  def __init__(
@@ -422,17 +493,21 @@ class SqlLookupNode(SqlNode):
422
493
  tbl: catalog.TableVersionPath,
423
494
  row_builder: exprs.RowBuilder,
424
495
  select_list: Iterable[exprs.Expr],
496
+ columns: list[catalog.Column],
425
497
  sa_key_cols: list[sql.Column],
426
498
  key_vals: list[tuple],
499
+ cell_md_col_refs: list[exprs.ColumnRef] | None = None,
427
500
  ):
428
- """
429
- Args:
430
- select_list: output of the query
431
- sa_key_cols: list of key columns in the store table
432
- key_vals: list of key values to look up
433
- """
434
501
  sql_elements = exprs.SqlElementCache()
435
- super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
502
+ super().__init__(
503
+ tbl,
504
+ row_builder,
505
+ select_list,
506
+ columns=columns,
507
+ sql_elements=sql_elements,
508
+ set_pk=True,
509
+ cell_md_col_refs=cell_md_col_refs,
510
+ )
436
511
  # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
437
512
  self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
438
513
 
@@ -446,29 +521,29 @@ class SqlLookupNode(SqlNode):
446
521
  class SqlAggregationNode(SqlNode):
447
522
  """
448
523
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
524
+
525
+ Args:
526
+ select_list: can contain calls to AggregateFunctions
527
+ group_by_items: list of expressions to group by
528
+ limit: max number of rows to return: None = no limit
449
529
  """
450
530
 
451
- group_by_items: Optional[list[exprs.Expr]]
452
- input_cte: Optional[sql.CTE]
531
+ group_by_items: list[exprs.Expr] | None
532
+ input_cte: sql.CTE | None
453
533
 
454
534
  def __init__(
455
535
  self,
456
536
  row_builder: exprs.RowBuilder,
457
537
  input: SqlNode,
458
538
  select_list: Iterable[exprs.Expr],
459
- group_by_items: Optional[list[exprs.Expr]] = None,
460
- limit: Optional[int] = None,
461
- exact_version_only: Optional[list[catalog.TableVersion]] = None,
539
+ group_by_items: list[exprs.Expr] | None = None,
540
+ limit: int | None = None,
541
+ exact_version_only: list[catalog.TableVersion] | None = None,
462
542
  ):
463
- """
464
- Args:
465
- select_list: can contain calls to AggregateFunctions
466
- group_by_items: list of expressions to group by
467
- limit: max number of rows to return: None = no limit
468
- """
543
+ assert len(input.cell_md_refs) == 0 # there's no aggregation over json or arrays in SQL
469
544
  self.input_cte, input_col_map = input.to_cte()
470
545
  sql_elements = exprs.SqlElementCache(input_col_map)
471
- super().__init__(None, row_builder, select_list, sql_elements)
546
+ super().__init__(None, row_builder, select_list, columns=[], sql_elements=sql_elements)
472
547
  self.group_by_items = group_by_items
473
548
 
474
549
  def _create_stmt(self) -> sql.Select:
@@ -504,7 +579,10 @@ class SqlJoinNode(SqlNode):
504
579
  input_cte, input_col_map = input_node.to_cte()
505
580
  self.input_ctes.append(input_cte)
506
581
  sql_elements.extend(input_col_map)
507
- super().__init__(None, row_builder, select_list, sql_elements)
582
+ cell_md_col_refs = [cell_md_ref.col_ref for input in inputs for cell_md_ref in input.cell_md_refs]
583
+ super().__init__(
584
+ None, row_builder, select_list, columns=[], sql_elements=sql_elements, cell_md_col_refs=cell_md_col_refs
585
+ )
508
586
 
509
587
  def _create_stmt(self) -> sql.Select:
510
588
  from pixeltable import plan
@@ -530,40 +608,46 @@ class SqlJoinNode(SqlNode):
530
608
 
531
609
  class SqlSampleNode(SqlNode):
532
610
  """
533
- Returns rows from a stratified sample with N samples per strata.
611
+ Returns rows sampled from the input node.
612
+
613
+ Args:
614
+ input: SqlNode to sample from
615
+ select_list: can contain calls to AggregateFunctions
616
+ sample_clause: specifies the sampling method
617
+ stratify_exprs: Analyzer processed list of expressions to stratify by.
534
618
  """
535
619
 
536
- stratify_exprs: Optional[list[exprs.Expr]]
537
- n_samples: Optional[int]
538
- fraction_samples: Optional[float]
539
- seed: int
540
- input_cte: Optional[sql.CTE]
620
+ input_cte: sql.CTE | None
541
621
  pk_count: int
622
+ stratify_exprs: list[exprs.Expr] | None
623
+ sample_clause: 'SampleClause'
542
624
 
543
625
  def __init__(
544
626
  self,
545
627
  row_builder: exprs.RowBuilder,
546
628
  input: SqlNode,
547
629
  select_list: Iterable[exprs.Expr],
548
- stratify_exprs: Optional[list[exprs.Expr]] = None,
549
- sample_clause: Optional['SampleClause'] = None,
630
+ sample_clause: 'SampleClause',
631
+ stratify_exprs: list[exprs.Expr],
550
632
  ):
551
- """
552
- Args:
553
- select_list: can contain calls to AggregateFunctions
554
- stratify_exprs: list of expressions to group by
555
- n: number of samples per strata
556
- """
633
+ assert isinstance(input, SqlNode)
557
634
  self.input_cte, input_col_map = input.to_cte(keep_pk=True)
558
635
  self.pk_count = input.num_pk_cols
559
636
  assert self.pk_count > 1
560
637
  sql_elements = exprs.SqlElementCache(input_col_map)
561
- super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
638
+ assert sql_elements.contains_all(stratify_exprs)
639
+ cell_md_col_refs = [cell_md_ref.col_ref for cell_md_ref in input.cell_md_refs]
640
+ super().__init__(
641
+ input.tbl,
642
+ row_builder,
643
+ select_list,
644
+ columns=[],
645
+ sql_elements=sql_elements,
646
+ cell_md_col_refs=cell_md_col_refs,
647
+ set_pk=True,
648
+ )
562
649
  self.stratify_exprs = stratify_exprs
563
- self.n_samples = sample_clause.n
564
- self.n_per_stratum = sample_clause.n_per_stratum
565
- self.fraction_samples = sample_clause.fraction
566
- self.seed = sample_clause.seed if sample_clause.seed is not None else 0
650
+ self.sample_clause = sample_clause
567
651
 
568
652
  @classmethod
569
653
  def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
@@ -571,27 +655,48 @@ class SqlSampleNode(SqlNode):
571
655
  General SQL form is:
572
656
  - MD5(<seed::text> [ + '___' + <rowid_col_val>::text]+
573
657
  """
574
- sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
658
+ sql_expr: sql.ColumnElement = seed.cast(sql.String)
575
659
  for e in sql_cols:
576
- sql_expr = sql_expr + sql.literal_column("'___'") + sql.cast(e, sql.Text)
660
+ # Quotes are required below to guarantee that the string is properly presented in SQL
661
+ sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + e.cast(sql.String)
577
662
  sql_expr = sql.func.md5(sql_expr)
578
663
  return sql_expr
579
664
 
580
- def _create_order_by(self, cte: sql.CTE) -> sql.ColumnElement:
665
+ def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
581
666
  """Create an expression for randomly ordering rows with a given seed"""
582
667
  rowid_cols = [*cte.c[-self.pk_count : -1]] # exclude the version column
583
668
  assert len(rowid_cols) > 0
584
- return self.key_sql_expr(sql.literal_column(str(self.seed)), rowid_cols)
669
+ # If seed is not set in the sample clause, use the random seed given by the execution context
670
+ seed = self.sample_clause.seed if self.sample_clause.seed is not None else self.ctx.random_seed
671
+ return self.key_sql_expr(sql.literal_column(str(seed)), rowid_cols)
585
672
 
586
673
  def _create_stmt(self) -> sql.Select:
587
- if self.fraction_samples is not None:
588
- return self._create_stmt_fraction(self.fraction_samples)
589
- return self._create_stmt_n(self.n_samples, self.n_per_stratum)
674
+ from pixeltable.plan import SampleClause
675
+
676
+ if self.sample_clause.fraction is not None:
677
+ if len(self.stratify_exprs) == 0:
678
+ # If non-stratified sampling, construct a where clause, order_by, and limit clauses
679
+ s_key = self._create_key_sql(self.input_cte)
680
+
681
+ # Construct a suitable where clause
682
+ fraction_md5 = SampleClause.fraction_to_md5_hex(self.sample_clause.fraction)
683
+ order_by = self._create_key_sql(self.input_cte)
684
+ return sql.select(*self.input_cte.c).where(s_key < fraction_md5).order_by(order_by)
685
+
686
+ return self._create_stmt_stratified_fraction(self.sample_clause.fraction)
687
+ else:
688
+ if len(self.stratify_exprs) == 0:
689
+ # No stratification, just return n samples from the input CTE
690
+ order_by = self._create_key_sql(self.input_cte)
691
+ return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)
692
+
693
+ return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)
694
+
695
+ def _create_stmt_stratified_n(self, n: int | None, n_per_stratum: int | None) -> sql.Select:
696
+ """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""
590
697
 
591
- def _create_stmt_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
592
- """Create a Select stmt that returns n samples across all strata"""
593
698
  sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
594
- order_by = self._create_order_by(self.input_cte)
699
+ order_by = self._create_key_sql(self.input_cte)
595
700
 
596
701
  # Create a list of all columns plus the rank
597
702
  # Get all columns from the input CTE dynamically
@@ -605,15 +710,15 @@ class SqlSampleNode(SqlNode):
605
710
  if n_per_stratum is not None:
606
711
  return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
607
712
  else:
608
- secondary_order = self._create_order_by(row_rank_cte)
713
+ secondary_order = self._create_key_sql(row_rank_cte)
609
714
  return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)
610
715
 
611
- def _create_stmt_fraction(self, fraction_samples: float) -> sql.Select:
716
+ def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
612
717
  """Create a Select stmt that returns a fraction of the rows per strata"""
613
718
 
614
719
  # Build the strata count CTE
615
720
  # Produces a table of the form:
616
- # ([stratify_exprs], s_s_size)
721
+ # (*stratify_exprs, s_s_size)
617
722
  # where s_s_size is the number of samples to take from each stratum
618
723
  sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
619
724
  per_strata_count_cte = (
@@ -628,19 +733,19 @@ class SqlSampleNode(SqlNode):
628
733
 
629
734
  # Build a CTE that ranks the rows within each stratum
630
735
  # Include all columns from the input CTE dynamically
631
- order_by = self._create_order_by(self.input_cte)
736
+ order_by = self._create_key_sql(self.input_cte)
632
737
  select_columns = [*self.input_cte.c]
633
738
  select_columns.append(
634
739
  sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
635
740
  )
636
741
  row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
637
742
 
638
- # Build the join criterion dynamically to accommodate any number of group by columns
743
+ # Build the join criterion dynamically to accommodate any number of stratify_by expressions
639
744
  join_c = sql.true()
640
745
  for col in per_strata_count_cte.c[:-1]:
641
746
  join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)
642
747
 
643
- # Join srcp with per_strata_count_cte to limit returns to the requested fraction of rows
748
+ # Join with per_strata_count_cte to limit returns to the requested fraction of rows
644
749
  final_columns = [*row_rank_cte.c[:-1]] # exclude the rank column
645
750
  stmt = (
646
751
  sql.select(*final_columns)
@@ -6,7 +6,7 @@ from .column_property_ref import ColumnPropertyRef
6
6
  from .column_ref import ColumnRef
7
7
  from .comparison import Comparison
8
8
  from .compound_predicate import CompoundPredicate
9
- from .data_row import DataRow
9
+ from .data_row import ArrayMd, CellMd, DataRow
10
10
  from .expr import Expr
11
11
  from .expr_dict import ExprDict
12
12
  from .expr_set import ExprSet