pixeltable-0.2.20-py3-none-any.whl → pixeltable-0.2.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (120)
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/column.py +37 -11
  5. pixeltable/catalog/globals.py +21 -0
  6. pixeltable/catalog/insertable_table.py +6 -4
  7. pixeltable/catalog/table.py +227 -148
  8. pixeltable/catalog/table_version.py +66 -28
  9. pixeltable/catalog/table_version_path.py +0 -8
  10. pixeltable/catalog/view.py +18 -19
  11. pixeltable/dataframe.py +16 -32
  12. pixeltable/env.py +6 -1
  13. pixeltable/exec/__init__.py +1 -2
  14. pixeltable/exec/aggregation_node.py +27 -17
  15. pixeltable/exec/cache_prefetch_node.py +1 -1
  16. pixeltable/exec/data_row_batch.py +9 -26
  17. pixeltable/exec/exec_node.py +36 -7
  18. pixeltable/exec/expr_eval_node.py +19 -11
  19. pixeltable/exec/in_memory_data_node.py +14 -11
  20. pixeltable/exec/sql_node.py +266 -138
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/arithmetic_expr.py +3 -1
  23. pixeltable/exprs/array_slice.py +7 -7
  24. pixeltable/exprs/column_property_ref.py +37 -10
  25. pixeltable/exprs/column_ref.py +93 -14
  26. pixeltable/exprs/comparison.py +5 -5
  27. pixeltable/exprs/compound_predicate.py +8 -7
  28. pixeltable/exprs/data_row.py +56 -36
  29. pixeltable/exprs/expr.py +65 -63
  30. pixeltable/exprs/expr_dict.py +55 -0
  31. pixeltable/exprs/expr_set.py +26 -15
  32. pixeltable/exprs/function_call.py +53 -24
  33. pixeltable/exprs/globals.py +4 -1
  34. pixeltable/exprs/in_predicate.py +8 -7
  35. pixeltable/exprs/inline_expr.py +4 -4
  36. pixeltable/exprs/is_null.py +4 -4
  37. pixeltable/exprs/json_mapper.py +11 -12
  38. pixeltable/exprs/json_path.py +5 -10
  39. pixeltable/exprs/literal.py +5 -5
  40. pixeltable/exprs/method_ref.py +5 -4
  41. pixeltable/exprs/object_ref.py +2 -1
  42. pixeltable/exprs/row_builder.py +88 -36
  43. pixeltable/exprs/rowid_ref.py +14 -13
  44. pixeltable/exprs/similarity_expr.py +12 -7
  45. pixeltable/exprs/sql_element_cache.py +12 -6
  46. pixeltable/exprs/type_cast.py +8 -6
  47. pixeltable/exprs/variable.py +5 -4
  48. pixeltable/ext/functions/whisperx.py +7 -2
  49. pixeltable/func/aggregate_function.py +1 -1
  50. pixeltable/func/callable_function.py +2 -2
  51. pixeltable/func/function.py +11 -10
  52. pixeltable/func/function_registry.py +6 -7
  53. pixeltable/func/query_template_function.py +11 -12
  54. pixeltable/func/signature.py +17 -15
  55. pixeltable/func/udf.py +0 -4
  56. pixeltable/functions/__init__.py +2 -2
  57. pixeltable/functions/audio.py +4 -6
  58. pixeltable/functions/globals.py +84 -42
  59. pixeltable/functions/huggingface.py +31 -34
  60. pixeltable/functions/image.py +59 -45
  61. pixeltable/functions/json.py +0 -1
  62. pixeltable/functions/llama_cpp.py +106 -0
  63. pixeltable/functions/mistralai.py +2 -2
  64. pixeltable/functions/ollama.py +147 -0
  65. pixeltable/functions/openai.py +22 -25
  66. pixeltable/functions/replicate.py +72 -0
  67. pixeltable/functions/string.py +59 -50
  68. pixeltable/functions/timestamp.py +20 -20
  69. pixeltable/functions/together.py +2 -2
  70. pixeltable/functions/video.py +11 -20
  71. pixeltable/functions/whisper.py +2 -20
  72. pixeltable/globals.py +65 -74
  73. pixeltable/index/base.py +2 -2
  74. pixeltable/index/btree.py +20 -7
  75. pixeltable/index/embedding_index.py +12 -14
  76. pixeltable/io/__init__.py +1 -2
  77. pixeltable/io/external_store.py +11 -5
  78. pixeltable/io/fiftyone.py +178 -0
  79. pixeltable/io/globals.py +98 -2
  80. pixeltable/io/hf_datasets.py +1 -1
  81. pixeltable/io/label_studio.py +6 -6
  82. pixeltable/io/parquet.py +14 -13
  83. pixeltable/iterators/base.py +3 -2
  84. pixeltable/iterators/document.py +10 -8
  85. pixeltable/iterators/video.py +126 -60
  86. pixeltable/metadata/__init__.py +4 -3
  87. pixeltable/metadata/converters/convert_14.py +4 -2
  88. pixeltable/metadata/converters/convert_15.py +1 -1
  89. pixeltable/metadata/converters/convert_19.py +1 -0
  90. pixeltable/metadata/converters/convert_20.py +1 -1
  91. pixeltable/metadata/converters/convert_21.py +34 -0
  92. pixeltable/metadata/converters/util.py +54 -12
  93. pixeltable/metadata/notes.py +1 -0
  94. pixeltable/metadata/schema.py +40 -21
  95. pixeltable/plan.py +149 -165
  96. pixeltable/py.typed +0 -0
  97. pixeltable/store.py +57 -37
  98. pixeltable/tool/create_test_db_dump.py +6 -6
  99. pixeltable/tool/create_test_video.py +1 -1
  100. pixeltable/tool/doc_plugins/griffe.py +3 -34
  101. pixeltable/tool/embed_udf.py +1 -1
  102. pixeltable/tool/mypy_plugin.py +55 -0
  103. pixeltable/type_system.py +260 -61
  104. pixeltable/utils/arrow.py +10 -9
  105. pixeltable/utils/coco.py +4 -4
  106. pixeltable/utils/documents.py +16 -2
  107. pixeltable/utils/filecache.py +9 -9
  108. pixeltable/utils/formatter.py +10 -11
  109. pixeltable/utils/http_server.py +2 -5
  110. pixeltable/utils/media_store.py +6 -6
  111. pixeltable/utils/pytorch.py +10 -11
  112. pixeltable/utils/sql.py +2 -1
  113. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
  114. pixeltable-0.2.22.dist-info/RECORD +153 -0
  115. pixeltable/exec/media_validation_node.py +0 -43
  116. pixeltable/utils/help.py +0 -11
  117. pixeltable-0.2.20.dist-info/RECORD +0 -147
  118. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
  119. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
  120. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
pixeltable/exec/sql_node.py
@@ -1,73 +1,176 @@
1
- from typing import List, Optional, Tuple, Iterable, Set
2
- from uuid import UUID
3
1
  import logging
4
2
  import warnings
3
+ from decimal import Decimal
4
+ from typing import Iterable, Iterator, NamedTuple, Optional
5
+ from uuid import UUID
5
6
 
6
7
  import sqlalchemy as sql
7
8
 
8
- from .data_row_batch import DataRowBatch
9
- from .exec_node import ExecNode
10
- import pixeltable.exprs as exprs
11
9
  import pixeltable.catalog as catalog
10
+ import pixeltable.exprs as exprs
12
11
 
12
+ from .data_row_batch import DataRowBatch
13
+ from .exec_node import ExecNode
13
14
 
14
15
  _logger = logging.getLogger('pixeltable')
15
16
 
17
+
18
+ class OrderByItem(NamedTuple):
19
+ expr: exprs.Expr
20
+ asc: Optional[bool]
21
+
22
+
23
+ OrderByClause = list[OrderByItem]
24
+
25
+
26
+ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
27
+ """Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
28
+ Two clauses are compatible if for each of their respective items c1[i] and c2[i]
29
+ a) the exprs are identical and
30
+ b) the asc values are identical or at least one is None (None serves as a wildcard)
31
+ """
32
+ result: OrderByClause = []
33
+ for clause in clauses:
34
+ combined: OrderByClause = []
35
+ for item1, item2 in zip(result, clause):
36
+ if item1.expr.id != item2.expr.id:
37
+ return None
38
+ if item1.asc is not None and item2.asc is not None and item1.asc != item2.asc:
39
+ return None
40
+ asc = item1.asc if item1.asc is not None else item2.asc
41
+ combined.append(OrderByItem(item1.expr, asc))
42
+
43
+ # add remaining ordering of the longer list
44
+ prefix_len = min(len(result), len(clause))
45
+ if len(result) > prefix_len:
46
+ combined.extend(result[prefix_len:])
47
+ elif len(clause) > prefix_len:
48
+ combined.extend(clause[prefix_len:])
49
+ result = combined
50
+ return result
51
+
52
+
53
+ def print_order_by_clause(clause: OrderByClause) -> str:
54
+ return ', '.join([
55
+ f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
56
+ for item in clause
57
+ ])
58
+
59
+
16
60
  class SqlNode(ExecNode):
17
- """Materializes data from the store via a Select stmt."""
61
+ """
62
+ Materializes data from the store via a Select stmt.
63
+ This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
64
+ """
65
+
66
+ tbl: Optional[catalog.TableVersionPath]
67
+ select_list: exprs.ExprSet
68
+ set_pk: bool
69
+ num_pk_cols: int
70
+ filter: Optional[exprs.Expr]
71
+ filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
72
+ cte: Optional[sql.CTE]
73
+ sql_elements: exprs.SqlElementCache
74
+ limit: Optional[int]
75
+ order_by_clause: OrderByClause
18
76
 
19
77
  def __init__(
20
- self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
78
+ self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
21
79
  select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
22
80
  ):
23
81
  """
24
- Initialize self.stmt with expressions derived from select_list.
25
-
26
- This only provides the select list. The subclass is responsible for the From clause and any additional clauses.
82
+ If row_builder contains references to unstored iter columns, expands the select list to include their
83
+ SQL-materializable subexpressions.
27
84
 
28
85
  Args:
29
86
  select_list: output of the query
30
87
  set_pk: if True, sets the primary for each DataRow
31
88
  """
32
89
  # create Select stmt
90
+ self.sql_elements = sql_elements
33
91
  self.tbl = tbl
34
- target = tbl.tbl_version # the stored table we're scanning
35
- self.sql_exprs = exprs.ExprSet(select_list)
92
+ self.select_list = exprs.ExprSet(select_list)
36
93
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
37
94
  for iter_arg in row_builder.unstored_iter_args.values():
38
- sql_subexprs = iter_arg.subexprs(filter=sql_elements.contains, traverse_matches=False)
95
+ sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
39
96
  for e in sql_subexprs:
40
- self.sql_exprs.add(e)
41
- super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
97
+ self.select_list.add(e)
98
+ super().__init__(row_builder, self.select_list, [], None) # we materialize self.select_list
42
99
 
43
- # change rowid refs against a base table to rowid refs against the target table, so that we minimize
44
- # the number of tables that need to be joined to the target table
45
- for rowid_ref in [e for e in self.sql_exprs if isinstance(e, exprs.RowidRef)]:
46
- rowid_ref.set_tbl(tbl)
100
+ if tbl is not None:
101
+ # minimize the number of tables that need to be joined to the target table
102
+ self.retarget_rowid_refs(tbl, self.select_list)
47
103
 
48
- sql_select_list = [sql_elements.get(e) for e in self.sql_exprs]
49
- assert len(sql_select_list) == len(self.sql_exprs)
50
- assert all(e is not None for e in sql_select_list)
104
+ assert self.sql_elements.contains_all(self.select_list)
51
105
  self.set_pk = set_pk
52
106
  self.num_pk_cols = 0
53
107
  if set_pk:
54
108
  # we also need to retrieve the pk columns
55
- pk_columns = target.store_tbl.pk_columns()
56
- self.num_pk_cols = len(pk_columns)
57
- sql_select_list += pk_columns
58
-
59
- self.stmt = sql.select(*sql_select_list)
109
+ assert tbl is not None
110
+ self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
60
111
 
61
112
  # additional state
62
- self.result_cursor: Optional[sql.engine.CursorResult] = None
113
+ self.result_cursor = None
63
114
  # the filter is provided by the subclass
64
- self.filter: Optional[exprs.Expr] = None
65
- self.filter_eval_ctx: Optional[exprs.EvalContext] = None
115
+ self.filter = None
116
+ self.filter_eval_ctx = None
117
+ self.cte = None
118
+ self.limit = None
119
+ self.order_by_clause = []
120
+
121
+ def _create_stmt(self) -> sql.Select:
122
+ """Create Select from local state"""
123
+
124
+ assert self.sql_elements.contains_all(self.select_list)
125
+ sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
126
+ if self.set_pk:
127
+ sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
128
+ stmt = sql.select(*sql_select_list)
129
+
130
+ order_by_clause: list[sql.ColumnElement] = []
131
+ for e, asc in self.order_by_clause:
132
+ if isinstance(e, exprs.SimilarityExpr):
133
+ order_by_clause.append(e.as_order_by_clause(asc))
134
+ else:
135
+ order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
136
+ stmt = stmt.order_by(*order_by_clause)
137
+
138
+ if self.filter is None and self.limit is not None:
139
+ # if we don't have a Python filter, we can apply the limit to stmt
140
+ stmt = stmt.limit(self.limit)
141
+
142
+ return stmt
143
+
144
+ def _ordering_tbl_ids(self) -> set[UUID]:
145
+ return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
146
+
147
+ def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
148
+ """
149
+ Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
150
+
151
+ Returns:
152
+ (CTE, dict from Expr to output column)
153
+ """
154
+ if self.filter is not None:
155
+ # the filter needs to run in Python
156
+ return None
157
+ self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
158
+ if self.cte is None:
159
+ self.cte = self._create_stmt().cte()
160
+ assert len(self.cte.c) == len(self.select_list)
161
+ return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
162
+
163
+ @classmethod
164
+ def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
165
+ """Change rowid refs to point to target"""
166
+ for e in expr_seq:
167
+ if isinstance(e, exprs.RowidRef):
168
+ e.set_tbl(target)
66
169
 
67
170
  @classmethod
68
171
  def create_from_clause(
69
- cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[Set[UUID]] = None,
70
- exact_version_only: Optional[Set[UUID]] = None
172
+ cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[set[UUID]] = None,
173
+ exact_version_only: Optional[set[UUID]] = None
71
174
  ) -> sql.Select:
72
175
  """Add From clause to stmt for tables/views referenced by materialized_exprs
73
176
  Args:
@@ -80,17 +183,18 @@ class SqlNode(ExecNode):
80
183
  """
81
184
  # we need to include at least the root
82
185
  if refd_tbl_ids is None:
83
- refd_tbl_ids = {}
186
+ refd_tbl_ids = set()
84
187
  if exact_version_only is None:
85
- exact_version_only = {}
188
+ exact_version_only = set()
86
189
  candidates = tbl.get_tbl_versions()
87
190
  assert len(candidates) > 0
88
- joined_tbls: List[catalog.TableVersion] = [candidates[0]]
191
+ joined_tbls: list[catalog.TableVersion] = [candidates[0]]
89
192
  for tbl in candidates[1:]:
90
193
  if tbl.id in refd_tbl_ids:
91
194
  joined_tbls.append(tbl)
92
195
 
93
196
  first = True
197
+ prev_tbl: catalog.TableVersion
94
198
  for tbl in joined_tbls[::-1]:
95
199
  if first:
96
200
  stmt = stmt.select_from(tbl.store_tbl.sa_tbl)
@@ -111,72 +215,94 @@ class SqlNode(ExecNode):
111
215
  prev_tbl = tbl
112
216
  return stmt
113
217
 
114
- def _log_explain(self, conn: sql.engine.Connection) -> None:
218
+ def add_order_by(self, ordering: OrderByClause) -> None:
219
+ """Add Order By clause to stmt"""
220
+ if self.tbl is not None:
221
+ # change rowid refs against a base table to rowid refs against the target table, so that we minimize
222
+ # the number of tables that need to be joined to the target table
223
+ self.retarget_rowid_refs(self.tbl, [e for e, _ in ordering])
224
+ combined = combine_order_by_clauses([self.order_by_clause, ordering])
225
+ assert combined is not None
226
+ self.order_by_clause = combined
227
+
228
+ def set_limit(self, limit: int) -> None:
229
+ self.limit = limit
230
+
231
+ def _log_explain(self, stmt: sql.Select) -> None:
115
232
  try:
116
233
  # don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
117
- stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
234
+ stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
118
235
  explain_result = self.ctx.conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
119
236
  explain_str = '\n'.join([str(row) for row in explain_result])
120
237
  _logger.debug(f'SqlScanNode explain:\n{explain_str}')
121
238
  except Exception as e:
122
239
  _logger.warning(f'EXPLAIN failed')
123
240
 
124
- def __next__(self) -> DataRowBatch:
125
- if self.result_cursor is None:
126
- # run the query; do this here rather than in _open(), exceptions are only expected during iteration
127
- assert self.ctx.conn is not None
241
+ def __iter__(self) -> Iterator[DataRowBatch]:
242
+ # run the query; do this here rather than in _open(), exceptions are only expected during iteration
243
+ assert self.ctx.conn is not None
244
+ with warnings.catch_warnings(record=True) as w:
245
+ stmt = self._create_stmt()
128
246
  try:
129
- self._log_explain(self.ctx.conn)
130
- with warnings.catch_warnings(record=True) as w:
131
- self.result_cursor = self.ctx.conn.execute(self.stmt)
132
- for warning in w:
133
- pass
134
- self.has_more_rows = True
135
- except Exception as e:
136
- self.has_more_rows = False
137
- raise e
138
-
139
- if not self.has_more_rows:
140
- raise StopIteration
141
-
142
- output_batch = DataRowBatch(self.tbl.tbl_version, self.row_builder)
143
- needs_row = True
144
- while self.ctx.batch_size == 0 or len(output_batch) < self.ctx.batch_size:
145
- try:
146
- sql_row = next(self.result_cursor)
147
- except StopIteration:
148
- self.has_more_rows = False
149
- break
150
-
151
- if needs_row:
152
- output_row = output_batch.add_row()
247
+ # log stmt, if possible
248
+ stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
249
+ _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
250
+ except Exception:
251
+ pass
252
+ self._log_explain(stmt)
253
+
254
+ result_cursor = self.ctx.conn.execute(stmt)
255
+ for warning in w:
256
+ pass
257
+
258
+ tbl_version = self.tbl.tbl_version if self.tbl is not None else None
259
+ output_batch = DataRowBatch(tbl_version, self.row_builder)
260
+ output_row: Optional[exprs.DataRow] = None
261
+ num_rows_returned = 0
262
+
263
+ for sql_row in result_cursor:
264
+ output_row = output_batch.add_row(output_row)
265
+
266
+ # populate output_row
153
267
  if self.num_pk_cols > 0:
154
268
  output_row.set_pk(tuple(sql_row[-self.num_pk_cols:]))
155
269
  # copy the output of the SQL query into the output row
156
- for i, e in enumerate(self.sql_exprs):
270
+ for i, e in enumerate(self.select_list):
157
271
  slot_idx = e.slot_idx
158
- output_row[slot_idx] = sql_row[i]
272
+ # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
273
+ if isinstance(sql_row[i], Decimal):
274
+ if e.col_type.is_int_type():
275
+ output_row[slot_idx] = int(sql_row[i])
276
+ elif e.col_type.is_float_type():
277
+ output_row[slot_idx] = float(sql_row[i])
278
+ else:
279
+ raise RuntimeError(f'Unexpected Decimal value for {e}')
280
+ else:
281
+ output_row[slot_idx] = sql_row[i]
282
+
159
283
  if self.filter is not None:
284
+ # evaluate filter
160
285
  self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
161
- if output_row[self.filter.slot_idx]:
162
- needs_row = True
163
- if self.limit > 0 and len(output_batch) >= self.limit:
164
- self.has_more_rows = False
165
- break
166
- else:
167
- # we re-use this row for the next sql row if it didn't pass the filter
168
- needs_row = False
169
- output_row.clear()
286
+ if self.filter is not None and not output_row[self.filter.slot_idx]:
287
+ # we re-use this row for the next sql row since it didn't pass the filter
288
+ output_row = output_batch.pop_row()
289
+ output_row.clear()
290
+ else:
291
+ # reset output_row in order to add new one
292
+ output_row = None
293
+ num_rows_returned += 1
170
294
 
171
- if not needs_row:
172
- # the last row didn't pass the filter
173
- assert self.filter is not None
174
- output_batch.pop_row()
295
+ if self.limit is not None and num_rows_returned == self.limit:
296
+ break
297
+
298
+ if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
299
+ _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
300
+ yield output_batch
301
+ output_batch = DataRowBatch(tbl_version, self.row_builder)
175
302
 
176
- _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
177
- if len(output_batch) == 0:
178
- raise StopIteration
179
- return output_batch
303
+ if len(output_batch) > 0:
304
+ _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
305
+ yield output_batch
180
306
 
181
307
  def _close(self) -> None:
182
308
  if self.result_cursor is not None:
@@ -189,12 +315,14 @@ class SqlScanNode(SqlNode):
189
315
 
190
316
  Supports filtering and ordering.
191
317
  """
318
+ where_clause: Optional[exprs.Expr]
319
+ exact_version_only: list[catalog.TableVersion]
320
+
192
321
  def __init__(
193
322
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
194
323
  select_list: Iterable[exprs.Expr],
195
324
  where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
196
- order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
197
- limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
325
+ set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
198
326
  ):
199
327
  """
200
328
  Args:
@@ -208,52 +336,29 @@ class SqlScanNode(SqlNode):
208
336
  sql_elements = exprs.SqlElementCache()
209
337
  super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
210
338
  # create Select stmt
211
- if order_by_items is None:
212
- order_by_items = []
213
339
  if exact_version_only is None:
214
340
  exact_version_only = []
215
341
  target = tbl.tbl_version # the stored table we're scanning
216
342
  self.filter = filter
217
343
  self.filter_eval_ctx = \
218
344
  row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
219
- self.limit = limit
220
345
 
221
- where_clause_tbl_ids = where_clause.tbl_ids() if where_clause is not None else set()
222
- refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs) | where_clause_tbl_ids
223
- self.stmt = self.create_from_clause(
224
- tbl, self.stmt, refd_tbl_ids, exact_version_only={t.id for t in exact_version_only})
225
-
226
- # change rowid refs against a base table to rowid refs against the target table, so that we minimize
227
- # the number of tables that need to be joined to the target table
228
- for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
229
- rowid_ref.set_tbl(tbl)
230
- order_by_clause: List[sql.ClauseElement] = []
231
- for e, asc in order_by_items:
232
- if isinstance(e, exprs.SimilarityExpr):
233
- order_by_clause.append(e.as_order_by_clause(asc))
234
- else:
235
- order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
346
+ self.where_clause = where_clause
347
+ self.exact_version_only = exact_version_only
236
348
 
237
- if where_clause is not None:
238
- sql_where_clause = sql_elements.get(where_clause)
349
+ def _create_stmt(self) -> sql.Select:
350
+ stmt = super()._create_stmt()
351
+ where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
352
+ refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
353
+ stmt = self.create_from_clause(
354
+ self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
355
+
356
+ if self.where_clause is not None:
357
+ sql_where_clause = self.sql_elements.get(self.where_clause)
239
358
  assert sql_where_clause is not None
240
- self.stmt = self.stmt.where(sql_where_clause)
241
- if len(order_by_clause) > 0:
242
- self.stmt = self.stmt.order_by(*order_by_clause)
243
- elif target.id in row_builder.unstored_iter_args:
244
- # we are referencing unstored iter columns from this view and try to order by our primary key,
245
- # which ensures that iterators will see monotonically increasing pos values
246
- self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
247
- if limit != 0 and self.filter is None:
248
- # if we need to do post-SQL filtering, we can't use LIMIT
249
- self.stmt = self.stmt.limit(limit)
359
+ stmt = stmt.where(sql_where_clause)
250
360
 
251
- try:
252
- # log stmt, if possible
253
- stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
254
- _logger.debug(f'SqlScanNode stmt:\n{stmt_str}')
255
- except Exception as e:
256
- pass
361
+ return stmt
257
362
 
258
363
 
259
364
  class SqlLookupNode(SqlNode):
@@ -261,8 +366,7 @@ class SqlLookupNode(SqlNode):
261
366
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
262
367
  """
263
368
 
264
- stmt: sql.Select
265
- where_clause: sql.ColumnElement[bool]
369
+ where_clause: sql.ColumnElement
266
370
 
267
371
  def __init__(
268
372
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
@@ -276,21 +380,45 @@ class SqlLookupNode(SqlNode):
276
380
  """
277
381
  sql_elements = exprs.SqlElementCache()
278
382
  super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
279
- target = tbl.tbl_version # the stored table we're scanning
280
- refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
281
- self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
282
383
  # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
283
384
  self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
284
- self.stmt = self.stmt.where(self.where_clause)
285
385
 
286
- if target.id in row_builder.unstored_iter_args:
287
- # we are referencing unstored iter columns from this view and try to order by our primary key,
288
- # which ensures that iterators will see monotonically increasing pos values
289
- self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
386
+ def _create_stmt(self) -> sql.Select:
387
+ stmt = super()._create_stmt()
388
+ refd_tbl_ids = exprs.Expr.all_tbl_ids(self.select_list) | self._ordering_tbl_ids()
389
+ stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
390
+ stmt = stmt.where(self.where_clause)
391
+ return stmt
290
392
 
291
- try:
292
- # log stmt, if possible
293
- stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
294
- _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
295
- except Exception as e:
296
- pass
393
+ class SqlAggregationNode(SqlNode):
394
+ """
395
+ Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
396
+ """
397
+
398
+ group_by_items: Optional[list[exprs.Expr]]
399
+
400
+ def __init__(
401
+ self, row_builder: exprs.RowBuilder,
402
+ input: SqlNode,
403
+ select_list: Iterable[exprs.Expr],
404
+ group_by_items: Optional[list[exprs.Expr]] = None,
405
+ limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
406
+ ):
407
+ """
408
+ Args:
409
+ select_list: can contain calls to AggregateFunctions
410
+ group_by_items: list of expressions to group by
411
+ limit: max number of rows to return: None = no limit
412
+ """
413
+ _, input_col_map = input.to_cte()
414
+ sql_elements = exprs.SqlElementCache(input_col_map)
415
+ super().__init__(None, row_builder, select_list, sql_elements)
416
+ self.group_by_items = group_by_items
417
+
418
+ def _create_stmt(self) -> sql.Select:
419
+ stmt = super()._create_stmt()
420
+ if self.group_by_items is not None:
421
+ sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
422
+ assert all(e is not None for e in sql_group_by_items)
423
+ stmt = stmt.group_by(*sql_group_by_items)
424
+ return stmt
pixeltable/exprs/__init__.py
@@ -6,6 +6,7 @@ from .comparison import Comparison
6
6
  from .compound_predicate import CompoundPredicate
7
7
  from .data_row import DataRow
8
8
  from .expr import Expr
9
+ from .expr_dict import ExprDict
9
10
  from .expr_set import ExprSet
10
11
  from .function_call import FunctionCall
11
12
  from .in_predicate import InPredicate
pixeltable/exprs/arithmetic_expr.py
@@ -6,6 +6,7 @@ import sqlalchemy as sql
6
6
 
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
9
+
9
10
  from .data_row import DataRow
10
11
  from .expr import Expr
11
12
  from .globals import ArithmeticOperator
@@ -86,6 +87,7 @@ class ArithmeticExpr(Expr):
86
87
  return sql.sql.expression.cast(sql.func.floor(left / right), sql.Integer)
87
88
  if self.col_type.is_float_type():
88
89
  return sql.sql.expression.cast(sql.func.floor(left / right), sql.Float)
90
+ assert False
89
91
 
90
92
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
91
93
  op1_val = data_row[self._op1.slot_idx]
@@ -121,7 +123,7 @@ class ArithmeticExpr(Expr):
121
123
  return {'operator': self.operator.value, **super()._as_dict()}
122
124
 
123
125
  @classmethod
124
- def _from_dict(cls, d: dict, components: list[Expr]) -> Expr:
126
+ def _from_dict(cls, d: dict, components: list[Expr]) -> ArithmeticExpr:
125
127
  assert 'operator' in d
126
128
  assert len(components) == 2
127
129
  return cls(ArithmeticOperator(d['operator']), components[0], components[1])
pixeltable/exprs/array_slice.py
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, List, Optional, Tuple
3
+ from typing import Any, Optional, Union
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
@@ -15,7 +15,7 @@ class ArraySlice(Expr):
15
15
  """
16
16
  Slice operation on an array, eg, t.array_col[:, 1:2].
17
17
  """
18
- def __init__(self, arr: Expr, index: Tuple):
18
+ def __init__(self, arr: Expr, index: tuple[Union[int, slice], ...]):
19
19
  assert arr.col_type.is_array_type()
20
20
  # determine result type
21
21
  super().__init__(arr.col_type)
@@ -24,7 +24,7 @@ class ArraySlice(Expr):
24
24
  self.id = self._create_id()
25
25
 
26
26
  def __str__(self) -> str:
27
- index_strs: List[str] = []
27
+ index_strs: list[str] = []
28
28
  for el in self.index:
29
29
  if isinstance(el, int):
30
30
  index_strs.append(str(el))
@@ -39,7 +39,7 @@ class ArraySlice(Expr):
39
39
  def _equals(self, other: ArraySlice) -> bool:
40
40
  return self.index == other.index
41
41
 
42
- def _id_attrs(self) -> List[Tuple[str, Any]]:
42
+ def _id_attrs(self) -> list[tuple[str, Any]]:
43
43
  return super()._id_attrs() + [('index', self.index)]
44
44
 
45
45
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
@@ -49,8 +49,8 @@ class ArraySlice(Expr):
49
49
  val = data_row[self._array.slot_idx]
50
50
  data_row[self.slot_idx] = val[self.index]
51
51
 
52
- def _as_dict(self) -> Dict:
53
- index = []
52
+ def _as_dict(self) -> dict:
53
+ index: list[Any] = []
54
54
  for el in self.index:
55
55
  if isinstance(el, slice):
56
56
  index.append([el.start, el.stop, el.step])
@@ -59,7 +59,7 @@ class ArraySlice(Expr):
59
59
  return {'index': index, **super()._as_dict()}
60
60
 
61
61
  @classmethod
62
- def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
62
+ def _from_dict(cls, d: dict, components: list[Expr]) -> ArraySlice:
63
63
  assert 'index' in d
64
64
  index = []
65
65
  for el in d['index']: