pixeltable 0.3.14__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. pixeltable/__init__.py +42 -8
  2. pixeltable/{dataframe.py → _query.py} +470 -206
  3. pixeltable/_version.py +1 -0
  4. pixeltable/catalog/__init__.py +5 -4
  5. pixeltable/catalog/catalog.py +1785 -432
  6. pixeltable/catalog/column.py +190 -113
  7. pixeltable/catalog/dir.py +2 -4
  8. pixeltable/catalog/globals.py +19 -46
  9. pixeltable/catalog/insertable_table.py +191 -98
  10. pixeltable/catalog/path.py +63 -23
  11. pixeltable/catalog/schema_object.py +11 -15
  12. pixeltable/catalog/table.py +843 -436
  13. pixeltable/catalog/table_metadata.py +103 -0
  14. pixeltable/catalog/table_version.py +978 -657
  15. pixeltable/catalog/table_version_handle.py +72 -16
  16. pixeltable/catalog/table_version_path.py +112 -43
  17. pixeltable/catalog/tbl_ops.py +53 -0
  18. pixeltable/catalog/update_status.py +191 -0
  19. pixeltable/catalog/view.py +134 -90
  20. pixeltable/config.py +134 -22
  21. pixeltable/env.py +471 -157
  22. pixeltable/exceptions.py +6 -0
  23. pixeltable/exec/__init__.py +4 -1
  24. pixeltable/exec/aggregation_node.py +7 -8
  25. pixeltable/exec/cache_prefetch_node.py +83 -110
  26. pixeltable/exec/cell_materialization_node.py +268 -0
  27. pixeltable/exec/cell_reconstruction_node.py +168 -0
  28. pixeltable/exec/component_iteration_node.py +4 -3
  29. pixeltable/exec/data_row_batch.py +8 -65
  30. pixeltable/exec/exec_context.py +16 -4
  31. pixeltable/exec/exec_node.py +13 -36
  32. pixeltable/exec/expr_eval/evaluators.py +11 -7
  33. pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
  34. pixeltable/exec/expr_eval/globals.py +8 -5
  35. pixeltable/exec/expr_eval/row_buffer.py +1 -2
  36. pixeltable/exec/expr_eval/schedulers.py +106 -56
  37. pixeltable/exec/globals.py +35 -0
  38. pixeltable/exec/in_memory_data_node.py +19 -19
  39. pixeltable/exec/object_store_save_node.py +293 -0
  40. pixeltable/exec/row_update_node.py +16 -9
  41. pixeltable/exec/sql_node.py +351 -84
  42. pixeltable/exprs/__init__.py +1 -1
  43. pixeltable/exprs/arithmetic_expr.py +27 -22
  44. pixeltable/exprs/array_slice.py +3 -3
  45. pixeltable/exprs/column_property_ref.py +36 -23
  46. pixeltable/exprs/column_ref.py +213 -89
  47. pixeltable/exprs/comparison.py +5 -5
  48. pixeltable/exprs/compound_predicate.py +5 -4
  49. pixeltable/exprs/data_row.py +164 -54
  50. pixeltable/exprs/expr.py +70 -44
  51. pixeltable/exprs/expr_dict.py +3 -3
  52. pixeltable/exprs/expr_set.py +17 -10
  53. pixeltable/exprs/function_call.py +100 -40
  54. pixeltable/exprs/globals.py +2 -2
  55. pixeltable/exprs/in_predicate.py +4 -4
  56. pixeltable/exprs/inline_expr.py +18 -32
  57. pixeltable/exprs/is_null.py +7 -3
  58. pixeltable/exprs/json_mapper.py +8 -8
  59. pixeltable/exprs/json_path.py +56 -22
  60. pixeltable/exprs/literal.py +27 -5
  61. pixeltable/exprs/method_ref.py +2 -2
  62. pixeltable/exprs/object_ref.py +2 -2
  63. pixeltable/exprs/row_builder.py +167 -67
  64. pixeltable/exprs/rowid_ref.py +25 -10
  65. pixeltable/exprs/similarity_expr.py +58 -40
  66. pixeltable/exprs/sql_element_cache.py +4 -4
  67. pixeltable/exprs/string_op.py +5 -5
  68. pixeltable/exprs/type_cast.py +3 -5
  69. pixeltable/func/__init__.py +1 -0
  70. pixeltable/func/aggregate_function.py +8 -8
  71. pixeltable/func/callable_function.py +9 -9
  72. pixeltable/func/expr_template_function.py +17 -11
  73. pixeltable/func/function.py +18 -20
  74. pixeltable/func/function_registry.py +6 -7
  75. pixeltable/func/globals.py +2 -3
  76. pixeltable/func/mcp.py +74 -0
  77. pixeltable/func/query_template_function.py +29 -27
  78. pixeltable/func/signature.py +46 -19
  79. pixeltable/func/tools.py +31 -13
  80. pixeltable/func/udf.py +18 -20
  81. pixeltable/functions/__init__.py +16 -0
  82. pixeltable/functions/anthropic.py +123 -77
  83. pixeltable/functions/audio.py +147 -10
  84. pixeltable/functions/bedrock.py +13 -6
  85. pixeltable/functions/date.py +7 -4
  86. pixeltable/functions/deepseek.py +35 -43
  87. pixeltable/functions/document.py +81 -0
  88. pixeltable/functions/fal.py +76 -0
  89. pixeltable/functions/fireworks.py +11 -20
  90. pixeltable/functions/gemini.py +195 -39
  91. pixeltable/functions/globals.py +142 -14
  92. pixeltable/functions/groq.py +108 -0
  93. pixeltable/functions/huggingface.py +1056 -24
  94. pixeltable/functions/image.py +115 -57
  95. pixeltable/functions/json.py +1 -1
  96. pixeltable/functions/llama_cpp.py +28 -13
  97. pixeltable/functions/math.py +67 -5
  98. pixeltable/functions/mistralai.py +18 -55
  99. pixeltable/functions/net.py +70 -0
  100. pixeltable/functions/ollama.py +20 -13
  101. pixeltable/functions/openai.py +240 -226
  102. pixeltable/functions/openrouter.py +143 -0
  103. pixeltable/functions/replicate.py +4 -4
  104. pixeltable/functions/reve.py +250 -0
  105. pixeltable/functions/string.py +239 -69
  106. pixeltable/functions/timestamp.py +16 -16
  107. pixeltable/functions/together.py +24 -84
  108. pixeltable/functions/twelvelabs.py +188 -0
  109. pixeltable/functions/util.py +6 -1
  110. pixeltable/functions/uuid.py +30 -0
  111. pixeltable/functions/video.py +1515 -107
  112. pixeltable/functions/vision.py +8 -8
  113. pixeltable/functions/voyageai.py +289 -0
  114. pixeltable/functions/whisper.py +16 -8
  115. pixeltable/functions/whisperx.py +179 -0
  116. pixeltable/{ext/functions → functions}/yolox.py +2 -4
  117. pixeltable/globals.py +362 -115
  118. pixeltable/index/base.py +17 -21
  119. pixeltable/index/btree.py +28 -22
  120. pixeltable/index/embedding_index.py +100 -118
  121. pixeltable/io/__init__.py +4 -2
  122. pixeltable/io/datarows.py +8 -7
  123. pixeltable/io/external_store.py +56 -105
  124. pixeltable/io/fiftyone.py +13 -13
  125. pixeltable/io/globals.py +31 -30
  126. pixeltable/io/hf_datasets.py +61 -16
  127. pixeltable/io/label_studio.py +74 -70
  128. pixeltable/io/lancedb.py +3 -0
  129. pixeltable/io/pandas.py +21 -12
  130. pixeltable/io/parquet.py +25 -105
  131. pixeltable/io/table_data_conduit.py +250 -123
  132. pixeltable/io/utils.py +4 -4
  133. pixeltable/iterators/__init__.py +2 -1
  134. pixeltable/iterators/audio.py +26 -25
  135. pixeltable/iterators/base.py +9 -3
  136. pixeltable/iterators/document.py +112 -78
  137. pixeltable/iterators/image.py +12 -15
  138. pixeltable/iterators/string.py +11 -4
  139. pixeltable/iterators/video.py +523 -120
  140. pixeltable/metadata/__init__.py +14 -3
  141. pixeltable/metadata/converters/convert_13.py +2 -2
  142. pixeltable/metadata/converters/convert_18.py +2 -2
  143. pixeltable/metadata/converters/convert_19.py +2 -2
  144. pixeltable/metadata/converters/convert_20.py +2 -2
  145. pixeltable/metadata/converters/convert_21.py +2 -2
  146. pixeltable/metadata/converters/convert_22.py +2 -2
  147. pixeltable/metadata/converters/convert_24.py +2 -2
  148. pixeltable/metadata/converters/convert_25.py +2 -2
  149. pixeltable/metadata/converters/convert_26.py +2 -2
  150. pixeltable/metadata/converters/convert_29.py +4 -4
  151. pixeltable/metadata/converters/convert_30.py +34 -21
  152. pixeltable/metadata/converters/convert_34.py +2 -2
  153. pixeltable/metadata/converters/convert_35.py +9 -0
  154. pixeltable/metadata/converters/convert_36.py +38 -0
  155. pixeltable/metadata/converters/convert_37.py +15 -0
  156. pixeltable/metadata/converters/convert_38.py +39 -0
  157. pixeltable/metadata/converters/convert_39.py +124 -0
  158. pixeltable/metadata/converters/convert_40.py +73 -0
  159. pixeltable/metadata/converters/convert_41.py +12 -0
  160. pixeltable/metadata/converters/convert_42.py +9 -0
  161. pixeltable/metadata/converters/convert_43.py +44 -0
  162. pixeltable/metadata/converters/util.py +20 -31
  163. pixeltable/metadata/notes.py +9 -0
  164. pixeltable/metadata/schema.py +140 -53
  165. pixeltable/metadata/utils.py +74 -0
  166. pixeltable/mypy/__init__.py +3 -0
  167. pixeltable/mypy/mypy_plugin.py +123 -0
  168. pixeltable/plan.py +382 -115
  169. pixeltable/share/__init__.py +1 -1
  170. pixeltable/share/packager.py +547 -83
  171. pixeltable/share/protocol/__init__.py +33 -0
  172. pixeltable/share/protocol/common.py +165 -0
  173. pixeltable/share/protocol/operation_types.py +33 -0
  174. pixeltable/share/protocol/replica.py +119 -0
  175. pixeltable/share/publish.py +257 -59
  176. pixeltable/store.py +311 -194
  177. pixeltable/type_system.py +373 -211
  178. pixeltable/utils/__init__.py +2 -3
  179. pixeltable/utils/arrow.py +131 -17
  180. pixeltable/utils/av.py +298 -0
  181. pixeltable/utils/azure_store.py +346 -0
  182. pixeltable/utils/coco.py +6 -6
  183. pixeltable/utils/code.py +3 -3
  184. pixeltable/utils/console_output.py +4 -1
  185. pixeltable/utils/coroutine.py +6 -23
  186. pixeltable/utils/dbms.py +32 -6
  187. pixeltable/utils/description_helper.py +4 -5
  188. pixeltable/utils/documents.py +7 -18
  189. pixeltable/utils/exception_handler.py +7 -30
  190. pixeltable/utils/filecache.py +6 -6
  191. pixeltable/utils/formatter.py +86 -48
  192. pixeltable/utils/gcs_store.py +295 -0
  193. pixeltable/utils/http.py +133 -0
  194. pixeltable/utils/http_server.py +2 -3
  195. pixeltable/utils/iceberg.py +1 -2
  196. pixeltable/utils/image.py +17 -0
  197. pixeltable/utils/lancedb.py +90 -0
  198. pixeltable/utils/local_store.py +322 -0
  199. pixeltable/utils/misc.py +5 -0
  200. pixeltable/utils/object_stores.py +573 -0
  201. pixeltable/utils/pydantic.py +60 -0
  202. pixeltable/utils/pytorch.py +5 -6
  203. pixeltable/utils/s3_store.py +527 -0
  204. pixeltable/utils/sql.py +26 -0
  205. pixeltable/utils/system.py +30 -0
  206. pixeltable-0.5.7.dist-info/METADATA +579 -0
  207. pixeltable-0.5.7.dist-info/RECORD +227 -0
  208. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
  209. pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
  210. pixeltable/__version__.py +0 -3
  211. pixeltable/catalog/named_function.py +0 -40
  212. pixeltable/ext/__init__.py +0 -17
  213. pixeltable/ext/functions/__init__.py +0 -11
  214. pixeltable/ext/functions/whisperx.py +0 -77
  215. pixeltable/utils/media_store.py +0 -77
  216. pixeltable/utils/s3.py +0 -17
  217. pixeltable-0.3.14.dist-info/METADATA +0 -434
  218. pixeltable-0.3.14.dist-info/RECORD +0 -186
  219. pixeltable-0.3.14.dist-info/entry_points.txt +0 -3
  220. {pixeltable-0.3.14.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0
@@ -1,7 +1,8 @@
1
+ import datetime
1
2
  import logging
2
3
  import warnings
3
4
  from decimal import Decimal
4
- from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Optional, Sequence
5
+ from typing import TYPE_CHECKING, AsyncIterator, Iterable, NamedTuple, Sequence
5
6
  from uuid import UUID
6
7
 
7
8
  import sqlalchemy as sql
@@ -14,19 +15,20 @@ from .exec_node import ExecNode
14
15
 
15
16
  if TYPE_CHECKING:
16
17
  import pixeltable.plan
18
+ from pixeltable.plan import SampleClause
17
19
 
18
20
  _logger = logging.getLogger('pixeltable')
19
21
 
20
22
 
21
23
  class OrderByItem(NamedTuple):
22
24
  expr: exprs.Expr
23
- asc: Optional[bool]
25
+ asc: bool | None
24
26
 
25
27
 
26
28
  OrderByClause = list[OrderByItem]
27
29
 
28
30
 
29
- def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
31
+ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> OrderByClause | None:
30
32
  """Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
31
33
  Two clauses are compatible if for each of their respective items c1[i] and c2[i]
32
34
  a) the exprs are identical and
@@ -55,60 +57,90 @@ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[Order
55
57
 
56
58
  def print_order_by_clause(clause: OrderByClause) -> str:
57
59
  return ', '.join(
58
- [
59
- f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
60
- for item in clause
61
- ]
60
+ f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
61
+ for item in clause
62
62
  )
63
63
 
64
64
 
65
65
  class SqlNode(ExecNode):
66
66
  """
67
- Materializes data from the store via a Select stmt.
67
+ Materializes data from the store via a SQL statement.
68
68
  This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
69
+ The pk columns are not included in the select list.
70
+ If set_pk is True, they are added to the end of the result set when creating the SQL statement
71
+ so they can always be referenced as cols[-num_pk_cols:] in the result set.
72
+ The pk_columns consist of the rowid columns of the target table followed by the version number.
73
+
74
+ If row_builder contains references to unstored iter columns, expands the select list to include their
75
+ SQL-materializable subexpressions.
76
+
77
+ Args:
78
+ select_list: output of the query
79
+ set_pk: if True, sets the primary for each DataRow
69
80
  """
70
81
 
71
- tbl: Optional[catalog.TableVersionPath]
82
+ tbl: catalog.TableVersionPath | None
72
83
  select_list: exprs.ExprSet
84
+ columns: list[catalog.Column] # for which columns to populate DataRow.cell_vals/cell_md
85
+ cell_md_refs: list[exprs.ColumnPropertyRef] # of ColumnRefs which also need DataRow.slot_cellmd for evaluation
73
86
  set_pk: bool
74
87
  num_pk_cols: int
75
- py_filter: Optional[exprs.Expr] # a predicate that can only be run in Python
76
- py_filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
77
- cte: Optional[sql.CTE]
88
+ py_filter: exprs.Expr | None # a predicate that can only be run in Python
89
+ py_filter_eval_ctx: exprs.RowBuilder.EvalCtx | None
90
+ cte: sql.CTE | None
78
91
  sql_elements: exprs.SqlElementCache
79
92
 
93
+ # execution state
94
+ sql_select_list_exprs: exprs.ExprSet
95
+ cellmd_item_idxs: exprs.ExprDict[int] # cellmd expr -> idx in sql select list
96
+ column_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
97
+ column_cellmd_item_idxs: dict[catalog.Column, int] # column -> idx in sql select list
98
+ result_cursor: sql.engine.CursorResult | None
99
+
80
100
  # where_clause/-_element: allow subclass to set one or the other (but not both)
81
- where_clause: Optional[exprs.Expr]
82
- where_clause_element: Optional[sql.ColumnElement]
101
+ where_clause: exprs.Expr | None
102
+ where_clause_element: sql.ColumnElement | None
83
103
 
84
104
  order_by_clause: OrderByClause
85
- limit: Optional[int]
105
+ limit: int | None
86
106
 
87
107
  def __init__(
88
108
  self,
89
- tbl: Optional[catalog.TableVersionPath],
109
+ tbl: catalog.TableVersionPath | None,
90
110
  row_builder: exprs.RowBuilder,
91
111
  select_list: Iterable[exprs.Expr],
112
+ columns: list[catalog.Column],
92
113
  sql_elements: exprs.SqlElementCache,
114
+ cell_md_col_refs: list[exprs.ColumnRef] | None = None,
93
115
  set_pk: bool = False,
94
116
  ):
95
- """
96
- If row_builder contains references to unstored iter columns, expands the select list to include their
97
- SQL-materializable subexpressions.
98
-
99
- Args:
100
- select_list: output of the query
101
- set_pk: if True, sets the primary for each DataRow
102
- """
103
117
  # create Select stmt
104
118
  self.sql_elements = sql_elements
105
119
  self.tbl = tbl
120
+ self.columns = columns
121
+ if cell_md_col_refs is not None:
122
+ assert all(ref.col.stores_cellmd for ref in cell_md_col_refs)
123
+ self.cell_md_refs = [
124
+ exprs.ColumnPropertyRef(ref, exprs.ColumnPropertyRef.Property.CELLMD) for ref in cell_md_col_refs
125
+ ]
126
+ else:
127
+ self.cell_md_refs = []
106
128
  self.select_list = exprs.ExprSet(select_list)
107
- # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
129
+ # unstored iter columns: we also need to retrieve whatever is needed to materialize the
130
+ # iter args and stored outputs
108
131
  for iter_arg in row_builder.unstored_iter_args.values():
109
132
  sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
110
- for e in sql_subexprs:
111
- self.select_list.add(e)
133
+ self.select_list.update(sql_subexprs)
134
+ # We query for unstored outputs only if we're not loading a view; when we're loading a view, we are populating
135
+ # those columns, so we need to keep them out of the select list. This isn't a problem, because view loads never
136
+ # need to call set_pos().
137
+ # TODO: This is necessary because create_view_load_plan passes stored output columns to `RowBuilder` via the
138
+ # `columns` parameter (even though they don't appear in `output_exprs`). This causes them to be recorded as
139
+ # expressions in `RowBuilder`, which creates a conflict if we add them here. If `RowBuilder` is restructured
140
+ # to keep them out of `unique_exprs`, then this conditional can be removed.
141
+ if not row_builder.for_view_load:
142
+ for outputs in row_builder.unstored_iter_outputs.values():
143
+ self.select_list.update(outputs)
112
144
  super().__init__(row_builder, self.select_list, [], None) # we materialize self.select_list
113
145
 
114
146
  if tbl is not None:
@@ -122,8 +154,12 @@ class SqlNode(ExecNode):
122
154
  # we also need to retrieve the pk columns
123
155
  assert tbl is not None
124
156
  self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
157
+ assert self.num_pk_cols > 1
125
158
 
126
159
  # additional state
160
+ self.cellmd_item_idxs = exprs.ExprDict()
161
+ self.column_item_idxs = {}
162
+ self.column_cellmd_item_idxs = {}
127
163
  self.result_cursor = None
128
164
  # the filter is provided by the subclass
129
165
  self.py_filter = None
@@ -134,14 +170,38 @@ class SqlNode(ExecNode):
134
170
  self.where_clause_element = None
135
171
  self.order_by_clause = []
136
172
 
137
- def _create_stmt(self) -> sql.Select:
138
- """Create Select from local state"""
173
+ if self.tbl is not None:
174
+ tv = self.tbl.tbl_version._tbl_version
175
+ if tv is not None:
176
+ assert tv.is_validated
139
177
 
140
- assert self.sql_elements.contains_all(self.select_list)
141
- sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
178
+ def _pk_col_items(self) -> list[sql.Column]:
142
179
  if self.set_pk:
180
+ # we need to retrieve the pk columns
143
181
  assert self.tbl is not None
144
- sql_select_list += self.tbl.tbl_version.get().store_tbl.pk_columns()
182
+ assert self.tbl.tbl_version.get().is_validated
183
+ return self.tbl.tbl_version.get().store_tbl.pk_columns()
184
+ return []
185
+
186
+ def _init_exec_state(self) -> None:
187
+ assert self.sql_elements.contains_all(self.select_list)
188
+ self.sql_select_list_exprs = exprs.ExprSet(self.select_list)
189
+ self.cellmd_item_idxs = exprs.ExprDict((ref, self.sql_select_list_exprs.add(ref)) for ref in self.cell_md_refs)
190
+ column_refs = [exprs.ColumnRef(col) for col in self.columns]
191
+ self.column_item_idxs = {col_ref.col: self.sql_select_list_exprs.add(col_ref) for col_ref in column_refs}
192
+ column_cellmd_refs = [
193
+ exprs.ColumnPropertyRef(col_ref, exprs.ColumnPropertyRef.Property.CELLMD)
194
+ for col_ref in column_refs
195
+ if col_ref.col.stores_cellmd
196
+ ]
197
+ self.column_cellmd_item_idxs = {
198
+ cellmd_ref.col_ref.col: self.sql_select_list_exprs.add(cellmd_ref) for cellmd_ref in column_cellmd_refs
199
+ }
200
+
201
+ def _create_stmt(self) -> sql.Select:
202
+ """Create Select from local state"""
203
+ self._init_exec_state()
204
+ sql_select_list = [self.sql_elements.get(e) for e in self.sql_select_list_exprs] + self._pk_col_items()
145
205
  stmt = sql.select(*sql_select_list)
146
206
 
147
207
  where_clause_element = (
@@ -167,9 +227,10 @@ class SqlNode(ExecNode):
167
227
  def _ordering_tbl_ids(self) -> set[UUID]:
168
228
  return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
169
229
 
170
- def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
230
+ def to_cte(self, keep_pk: bool = False) -> tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]] | None:
171
231
  """
172
- Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
232
+ Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
233
+ keep_pk: if True, the PK columns are included in the CTE Select statement
173
234
 
174
235
  Returns:
175
236
  (CTE, dict from Expr to output column)
@@ -177,11 +238,11 @@ class SqlNode(ExecNode):
177
238
  if self.py_filter is not None:
178
239
  # the filter needs to run in Python
179
240
  return None
180
- self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
181
241
  if self.cte is None:
242
+ if not keep_pk:
243
+ self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
182
244
  self.cte = self._create_stmt().cte()
183
- assert len(self.cte.c) == len(self.select_list)
184
- return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
245
+ return self.cte, exprs.ExprDict(zip(list(self.select_list) + self.cell_md_refs, self.cte.c)) # skip pk cols
185
246
 
186
247
  @classmethod
187
248
  def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
@@ -195,8 +256,8 @@ class SqlNode(ExecNode):
195
256
  cls,
196
257
  tbl: catalog.TableVersionPath,
197
258
  stmt: sql.Select,
198
- refd_tbl_ids: Optional[set[UUID]] = None,
199
- exact_version_only: Optional[set[UUID]] = None,
259
+ refd_tbl_ids: set[UUID] | None = None,
260
+ exact_version_only: set[UUID] | None = None,
200
261
  ) -> sql.Select:
201
262
  """Add From clause to stmt for tables/views referenced by materialized_exprs
202
263
  Args:
@@ -220,26 +281,29 @@ class SqlNode(ExecNode):
220
281
  joined_tbls.append(t)
221
282
 
222
283
  first = True
223
- prev_tbl: Optional[catalog.TableVersionHandle] = None
284
+ prev_tv: catalog.TableVersion | None = None
224
285
  for t in joined_tbls[::-1]:
286
+ tv = t.get()
287
+ # _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
225
288
  if first:
226
- stmt = stmt.select_from(t.get().store_tbl.sa_tbl)
289
+ stmt = stmt.select_from(tv.store_tbl.sa_tbl)
227
290
  first = False
228
291
  else:
229
- # join tbl to prev_tbl on prev_tbl's rowid cols
230
- prev_tbl_rowid_cols = prev_tbl.get().store_tbl.rowid_columns()
231
- tbl_rowid_cols = t.get().store_tbl.rowid_columns()
292
+ # join tv to prev_tv on prev_tv's rowid cols
293
+ prev_tbl_rowid_cols = prev_tv.store_tbl.rowid_columns()
294
+ tbl_rowid_cols = tv.store_tbl.rowid_columns()
232
295
  rowid_clauses = [
233
296
  c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
234
297
  ]
235
- stmt = stmt.join(t.get().store_tbl.sa_tbl, sql.and_(*rowid_clauses))
298
+ stmt = stmt.join(tv.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
299
+
236
300
  if t.id in exact_version_only:
237
- stmt = stmt.where(t.get().store_tbl.v_min_col == t.get().version)
301
+ stmt = stmt.where(tv.store_tbl.v_min_col == tv.version)
238
302
  else:
239
- stmt = stmt.where(t.get().store_tbl.v_min_col <= t.get().version).where(
240
- t.get().store_tbl.v_max_col > t.get().version
241
- )
242
- prev_tbl = t
303
+ stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_min <= tv.version)
304
+ stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_max > tv.version)
305
+ prev_tv = tv
306
+
243
307
  return stmt
244
308
 
245
309
  def set_where(self, where_clause: exprs.Expr) -> None:
@@ -284,7 +348,8 @@ class SqlNode(ExecNode):
284
348
  stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
285
349
  _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
286
350
  except Exception:
287
- pass
351
+ # log something if we can't log the compiled stmt
352
+ _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
288
353
  self._log_explain(stmt)
289
354
 
290
355
  conn = Env.get().conn
@@ -292,28 +357,56 @@ class SqlNode(ExecNode):
292
357
  for _ in w:
293
358
  pass
294
359
 
295
- tbl_version = self.tbl.tbl_version if self.tbl is not None else None
296
- output_batch = DataRowBatch(tbl_version, self.row_builder)
297
- output_row: Optional[exprs.DataRow] = None
360
+ output_batch = DataRowBatch(self.row_builder)
361
+ output_row: exprs.DataRow | None = None
298
362
  num_rows_returned = 0
363
+ is_using_cockroachdb = Env.get().is_using_cockroachdb
364
+ tzinfo = Env.get().default_time_zone
299
365
 
300
366
  for sql_row in result_cursor:
301
367
  output_row = output_batch.add_row(output_row)
302
368
 
303
369
  # populate output_row
370
+
304
371
  if self.num_pk_cols > 0:
305
372
  output_row.set_pk(tuple(sql_row[-self.num_pk_cols :]))
373
+
374
+ # column copies
375
+ for col, item_idx in self.column_item_idxs.items():
376
+ output_row.cell_vals[col.id] = sql_row[item_idx]
377
+ for col, item_idx in self.column_cellmd_item_idxs.items():
378
+ cell_md_dict = sql_row[item_idx]
379
+ output_row.cell_md[col.id] = exprs.CellMd(**cell_md_dict) if cell_md_dict is not None else None
380
+
381
+ # populate DataRow.slot_cellmd, where requested
382
+ for cellmd_ref, item_idx in self.cellmd_item_idxs.items():
383
+ cell_md_dict = sql_row[item_idx]
384
+ output_row.slot_md[cellmd_ref.col_ref.slot_idx] = (
385
+ exprs.CellMd.from_dict(cell_md_dict) if cell_md_dict is not None else None
386
+ )
387
+
306
388
  # copy the output of the SQL query into the output row
307
389
  for i, e in enumerate(self.select_list):
308
390
  slot_idx = e.slot_idx
309
- # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
310
391
  if isinstance(sql_row[i], Decimal):
392
+ # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
311
393
  if e.col_type.is_int_type():
312
394
  output_row[slot_idx] = int(sql_row[i])
313
395
  elif e.col_type.is_float_type():
314
396
  output_row[slot_idx] = float(sql_row[i])
315
397
  else:
316
398
  raise RuntimeError(f'Unexpected Decimal value for {e}')
399
+ elif is_using_cockroachdb and isinstance(sql_row[i], datetime.datetime):
400
+ # Ensure that the datetime is timezone-aware and in the session time zone
401
+ # cockroachDB returns timestamps in the session time zone, with numeric offset,
402
+ # convert to the session time zone with the requested tzinfo for DST handling
403
+ if e.col_type.is_timestamp_type():
404
+ if isinstance(sql_row[i].tzinfo, datetime.timezone):
405
+ output_row[slot_idx] = sql_row[i].astimezone(tz=tzinfo)
406
+ else:
407
+ output_row[slot_idx] = sql_row[i]
408
+ else:
409
+ raise RuntimeError(f'Unexpected datetime value for {e}')
317
410
  else:
318
411
  output_row[slot_idx] = sql_row[i]
319
412
 
@@ -335,7 +428,7 @@ class SqlNode(ExecNode):
335
428
  if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
336
429
  _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
337
430
  yield output_batch
338
- output_batch = DataRowBatch(tbl_version, self.row_builder)
431
+ output_batch = DataRowBatch(self.row_builder)
339
432
 
340
433
  if len(output_batch) > 0:
341
434
  _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
@@ -351,6 +444,11 @@ class SqlScanNode(SqlNode):
351
444
  Materializes data from the store via a Select stmt.
352
445
 
353
446
  Supports filtering and ordering.
447
+
448
+ Args:
449
+ select_list: output of the query
450
+ set_pk: if True, sets the primary for each DataRow
451
+ exact_version_only: tables for which we only want to see rows created at the current version
354
452
  """
355
453
 
356
454
  exact_version_only: list[catalog.TableVersionHandle]
@@ -360,17 +458,21 @@ class SqlScanNode(SqlNode):
360
458
  tbl: catalog.TableVersionPath,
361
459
  row_builder: exprs.RowBuilder,
362
460
  select_list: Iterable[exprs.Expr],
461
+ columns: list[catalog.Column],
462
+ cell_md_col_refs: list[exprs.ColumnRef] | None = None,
363
463
  set_pk: bool = False,
364
- exact_version_only: Optional[list[catalog.TableVersionHandle]] = None,
464
+ exact_version_only: list[catalog.TableVersionHandle] | None = None,
365
465
  ):
366
- """
367
- Args:
368
- select_list: output of the query
369
- set_pk: if True, sets the primary for each DataRow
370
- exact_version_only: tables for which we only want to see rows created at the current version
371
- """
372
466
  sql_elements = exprs.SqlElementCache()
373
- super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
467
+ super().__init__(
468
+ tbl,
469
+ row_builder,
470
+ select_list,
471
+ columns=columns,
472
+ sql_elements=sql_elements,
473
+ set_pk=set_pk,
474
+ cell_md_col_refs=cell_md_col_refs,
475
+ )
374
476
  # create Select stmt
375
477
  if exact_version_only is None:
376
478
  exact_version_only = []
@@ -390,6 +492,11 @@ class SqlScanNode(SqlNode):
390
492
  class SqlLookupNode(SqlNode):
391
493
  """
392
494
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
495
+
496
+ Args:
497
+ select_list: output of the query
498
+ sa_key_cols: list of key columns in the store table
499
+ key_vals: list of key values to look up
393
500
  """
394
501
 
395
502
  def __init__(
@@ -397,17 +504,21 @@ class SqlLookupNode(SqlNode):
397
504
  tbl: catalog.TableVersionPath,
398
505
  row_builder: exprs.RowBuilder,
399
506
  select_list: Iterable[exprs.Expr],
507
+ columns: list[catalog.Column],
400
508
  sa_key_cols: list[sql.Column],
401
509
  key_vals: list[tuple],
510
+ cell_md_col_refs: list[exprs.ColumnRef] | None = None,
402
511
  ):
403
- """
404
- Args:
405
- select_list: output of the query
406
- sa_key_cols: list of key columns in the store table
407
- key_vals: list of key values to look up
408
- """
409
512
  sql_elements = exprs.SqlElementCache()
410
- super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
513
+ super().__init__(
514
+ tbl,
515
+ row_builder,
516
+ select_list,
517
+ columns=columns,
518
+ sql_elements=sql_elements,
519
+ set_pk=True,
520
+ cell_md_col_refs=cell_md_col_refs,
521
+ )
411
522
  # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
412
523
  self.where_clause_element = sql.tuple_(*sa_key_cols).in_(key_vals)
413
524
 
@@ -421,29 +532,29 @@ class SqlLookupNode(SqlNode):
421
532
  class SqlAggregationNode(SqlNode):
422
533
  """
423
534
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
535
+
536
+ Args:
537
+ select_list: can contain calls to AggregateFunctions
538
+ group_by_items: list of expressions to group by
539
+ limit: max number of rows to return: None = no limit
424
540
  """
425
541
 
426
- group_by_items: Optional[list[exprs.Expr]]
427
- input_cte: Optional[sql.CTE]
542
+ group_by_items: list[exprs.Expr] | None
543
+ input_cte: sql.CTE | None
428
544
 
429
545
  def __init__(
430
546
  self,
431
547
  row_builder: exprs.RowBuilder,
432
548
  input: SqlNode,
433
549
  select_list: Iterable[exprs.Expr],
434
- group_by_items: Optional[list[exprs.Expr]] = None,
435
- limit: Optional[int] = None,
436
- exact_version_only: Optional[list[catalog.TableVersion]] = None,
550
+ group_by_items: list[exprs.Expr] | None = None,
551
+ limit: int | None = None,
552
+ exact_version_only: list[catalog.TableVersion] | None = None,
437
553
  ):
438
- """
439
- Args:
440
- select_list: can contain calls to AggregateFunctions
441
- group_by_items: list of expressions to group by
442
- limit: max number of rows to return: None = no limit
443
- """
554
+ assert len(input.cell_md_refs) == 0 # there's no aggregation over json or arrays in SQL
444
555
  self.input_cte, input_col_map = input.to_cte()
445
556
  sql_elements = exprs.SqlElementCache(input_col_map)
446
- super().__init__(None, row_builder, select_list, sql_elements)
557
+ super().__init__(None, row_builder, select_list, columns=[], sql_elements=sql_elements)
447
558
  self.group_by_items = group_by_items
448
559
 
449
560
  def _create_stmt(self) -> sql.Select:
@@ -479,7 +590,10 @@ class SqlJoinNode(SqlNode):
479
590
  input_cte, input_col_map = input_node.to_cte()
480
591
  self.input_ctes.append(input_cte)
481
592
  sql_elements.extend(input_col_map)
482
- super().__init__(None, row_builder, select_list, sql_elements)
593
+ cell_md_col_refs = [cell_md_ref.col_ref for input in inputs for cell_md_ref in input.cell_md_refs]
594
+ super().__init__(
595
+ None, row_builder, select_list, columns=[], sql_elements=sql_elements, cell_md_col_refs=cell_md_col_refs
596
+ )
483
597
 
484
598
  def _create_stmt(self) -> sql.Select:
485
599
  from pixeltable import plan
@@ -501,3 +615,156 @@ class SqlJoinNode(SqlNode):
501
615
  full=join_clause == plan.JoinType.FULL_OUTER,
502
616
  )
503
617
  return stmt
618
+
619
+
620
class SqlSampleNode(SqlNode):
    """
    Returns rows sampled from the input node.

    Rows are ordered by a seeded pseudo-random key, MD5(seed || rowid columns), so that sampling with an
    explicit seed is intended to be reproducible.

    Args:
        input: SqlNode to sample from
        select_list: expressions to be selected from the sampled rows
        sample_clause: specifies the sampling method (fraction, or n / n_per_stratum) and an optional seed
        stratify_exprs: Analyzer processed list of expressions to stratify by; empty for non-stratified sampling
    """

    input_cte: sql.CTE | None
    # number of trailing primary-key columns of input_cte: the rowid columns followed by the version column
    pk_count: int
    stratify_exprs: list[exprs.Expr] | None
    sample_clause: 'SampleClause'

    def __init__(
        self,
        row_builder: exprs.RowBuilder,
        input: SqlNode,
        select_list: Iterable[exprs.Expr],
        sample_clause: 'SampleClause',
        stratify_exprs: list[exprs.Expr],
    ):
        assert isinstance(input, SqlNode)
        # keep_pk=True: the pk columns are needed to build the sampling key
        self.input_cte, input_col_map = input.to_cte(keep_pk=True)
        self.pk_count = input.num_pk_cols
        assert self.pk_count > 1  # at least one rowid column plus the version column
        sql_elements = exprs.SqlElementCache(input_col_map)
        assert sql_elements.contains_all(stratify_exprs)
        cell_md_col_refs = [cell_md_ref.col_ref for cell_md_ref in input.cell_md_refs]
        super().__init__(
            input.tbl,
            row_builder,
            select_list,
            columns=[],
            sql_elements=sql_elements,
            cell_md_col_refs=cell_md_col_refs,
            set_pk=True,
        )
        self.stratify_exprs = stratify_exprs
        self.sample_clause = sample_clause

    @classmethod
    def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
        """Construct the expression that is the ordering key for rows to be sampled.

        General SQL form is:
            MD5(<seed>::text [|| '___' || <col>::text]*)
        i.e. the seed followed by each column value, all cast to text and joined with '___'.
        """
        sql_expr: sql.ColumnElement = seed.cast(sql.String)
        for e in sql_cols:
            # Quotes are required below to guarantee that the string is properly presented in SQL
            sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + e.cast(sql.String)
        sql_expr = sql.func.md5(sql_expr)
        return sql_expr

    def _create_key_sql(self, cte: sql.CTE, num_trailing_cols: int = 0) -> sql.ColumnElement:
        """Create an expression for randomly ordering rows with a given seed.

        Args:
            cte: a CTE whose columns end with the input's primary-key columns (rowid columns, then version),
                optionally followed by `num_trailing_cols` additional columns
            num_trailing_cols: number of columns appended after the primary-key columns (e.g. a rank column);
                these are skipped so that the key is always built from the rowid columns
        """
        cols = [*cte.c]
        if num_trailing_cols > 0:
            cols = cols[:-num_trailing_cols]
        rowid_cols = cols[-self.pk_count : -1]  # exclude the version column
        assert len(rowid_cols) > 0
        # If seed is not set in the sample clause, use the random seed given by the execution context
        seed = self.sample_clause.seed if self.sample_clause.seed is not None else self.ctx.random_seed
        return self.key_sql_expr(sql.literal_column(str(seed)), rowid_cols)

    def _create_stmt(self) -> sql.Select:
        from pixeltable.plan import SampleClause

        self._init_exec_state()

        if self.sample_clause.fraction is not None:
            if not self.stratify_exprs:
                # Non-stratified fraction: keep rows whose key is below the threshold derived from the fraction,
                # ordered by that same key
                s_key = self._create_key_sql(self.input_cte)
                fraction_md5 = SampleClause.fraction_to_md5_hex(self.sample_clause.fraction)
                return sql.select(*self.input_cte.c).where(s_key < fraction_md5).order_by(s_key)
            return self._create_stmt_stratified_fraction(self.sample_clause.fraction)

        if not self.stratify_exprs:
            # No stratification, just return n samples from the input CTE
            order_by = self._create_key_sql(self.input_cte)
            return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)
        return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)

    def _create_stmt_stratified_n(self, n: int | None, n_per_stratum: int | None) -> sql.Select:
        """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""

        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        order_by = self._create_key_sql(self.input_cte)

        # All input columns plus the rank of each row within its stratum
        select_columns = [*self.input_cte.c]
        select_columns.append(
            sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
        )
        row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')

        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        if n_per_stratum is not None:
            return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)

        # Take rows round-robin across strata (by rank); ties within a rank are broken by the sampling key.
        # row_rank_cte has the rank column appended after the pk columns, so it must be skipped when
        # locating the rowid columns.
        secondary_order = self._create_key_sql(row_rank_cte, num_trailing_cols=1)
        return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)

    def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
        """Create a Select stmt that returns a fraction of the rows per stratum"""

        # Build the strata count CTE
        # Produces a table of the form:
        #   (*stratify_exprs, s_s_size)
        # where s_s_size is the number of samples to take from each stratum
        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        per_strata_count_cte = (
            sql.select(
                *sql_strata_exprs,
                sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
            )
            .select_from(self.input_cte)
            .group_by(*sql_strata_exprs)
            .cte('per_strata_count_cte')
        )

        # Build a CTE that ranks the rows within each stratum
        # Include all columns from the input CTE dynamically
        order_by = self._create_key_sql(self.input_cte)
        select_columns = [*self.input_cte.c]
        select_columns.append(
            sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
        )
        row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')

        # Build the join criterion dynamically to accommodate any number of stratify_by expressions;
        # IS NOT DISTINCT FROM makes NULL stratum values match each other
        join_c = sql.true()
        for col in per_strata_count_cte.c[:-1]:
            join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)

        # Join with per_strata_count_cte to limit returns to the requested fraction of rows
        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        stmt = (
            sql.select(*final_columns)
            .select_from(row_rank_cte)
            .join(per_strata_count_cte, join_c)
            .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
        )

        return stmt
@@ -6,7 +6,7 @@ from .column_property_ref import ColumnPropertyRef
6
6
  from .column_ref import ColumnRef
7
7
  from .comparison import Comparison
8
8
  from .compound_predicate import CompoundPredicate
9
- from .data_row import DataRow
9
+ from .data_row import ArrayMd, BinaryMd, CellMd, DataRow
10
10
  from .expr import Expr
11
11
  from .expr_dict import ExprDict
12
12
  from .expr_set import ExprSet