pixeltable 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (88) hide show
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/globals.py +3 -0
  5. pixeltable/catalog/insertable_table.py +9 -7
  6. pixeltable/catalog/table.py +220 -143
  7. pixeltable/catalog/table_version.py +36 -18
  8. pixeltable/catalog/table_version_path.py +0 -8
  9. pixeltable/catalog/view.py +3 -3
  10. pixeltable/dataframe.py +9 -24
  11. pixeltable/env.py +107 -36
  12. pixeltable/exceptions.py +7 -4
  13. pixeltable/exec/__init__.py +1 -1
  14. pixeltable/exec/aggregation_node.py +22 -15
  15. pixeltable/exec/component_iteration_node.py +62 -41
  16. pixeltable/exec/data_row_batch.py +7 -7
  17. pixeltable/exec/exec_node.py +35 -7
  18. pixeltable/exec/expr_eval_node.py +2 -1
  19. pixeltable/exec/in_memory_data_node.py +9 -9
  20. pixeltable/exec/sql_node.py +265 -136
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/data_row.py +30 -19
  23. pixeltable/exprs/expr.py +15 -14
  24. pixeltable/exprs/expr_dict.py +55 -0
  25. pixeltable/exprs/expr_set.py +21 -15
  26. pixeltable/exprs/function_call.py +21 -8
  27. pixeltable/exprs/json_path.py +3 -6
  28. pixeltable/exprs/rowid_ref.py +2 -2
  29. pixeltable/exprs/sql_element_cache.py +5 -1
  30. pixeltable/ext/functions/whisperx.py +7 -2
  31. pixeltable/func/callable_function.py +2 -2
  32. pixeltable/func/function_registry.py +6 -7
  33. pixeltable/func/query_template_function.py +11 -12
  34. pixeltable/func/signature.py +17 -15
  35. pixeltable/func/udf.py +0 -4
  36. pixeltable/functions/__init__.py +1 -1
  37. pixeltable/functions/audio.py +4 -6
  38. pixeltable/functions/globals.py +86 -42
  39. pixeltable/functions/huggingface.py +12 -14
  40. pixeltable/functions/image.py +59 -45
  41. pixeltable/functions/json.py +0 -1
  42. pixeltable/functions/mistralai.py +2 -2
  43. pixeltable/functions/openai.py +22 -25
  44. pixeltable/functions/string.py +50 -50
  45. pixeltable/functions/timestamp.py +20 -20
  46. pixeltable/functions/together.py +26 -12
  47. pixeltable/functions/video.py +11 -20
  48. pixeltable/functions/whisper.py +2 -20
  49. pixeltable/globals.py +57 -56
  50. pixeltable/index/base.py +2 -2
  51. pixeltable/index/btree.py +7 -7
  52. pixeltable/index/embedding_index.py +8 -10
  53. pixeltable/io/external_store.py +11 -5
  54. pixeltable/io/globals.py +3 -1
  55. pixeltable/io/hf_datasets.py +4 -4
  56. pixeltable/io/label_studio.py +6 -6
  57. pixeltable/io/parquet.py +14 -13
  58. pixeltable/iterators/document.py +10 -8
  59. pixeltable/iterators/video.py +10 -1
  60. pixeltable/metadata/__init__.py +3 -2
  61. pixeltable/metadata/converters/convert_14.py +4 -2
  62. pixeltable/metadata/converters/convert_15.py +1 -1
  63. pixeltable/metadata/converters/convert_19.py +1 -0
  64. pixeltable/metadata/converters/convert_20.py +1 -1
  65. pixeltable/metadata/converters/util.py +9 -8
  66. pixeltable/metadata/schema.py +32 -21
  67. pixeltable/plan.py +136 -154
  68. pixeltable/store.py +51 -36
  69. pixeltable/tool/create_test_db_dump.py +7 -7
  70. pixeltable/tool/doc_plugins/griffe.py +3 -34
  71. pixeltable/tool/mypy_plugin.py +32 -0
  72. pixeltable/type_system.py +243 -60
  73. pixeltable/utils/arrow.py +10 -9
  74. pixeltable/utils/coco.py +4 -4
  75. pixeltable/utils/documents.py +1 -1
  76. pixeltable/utils/filecache.py +131 -84
  77. pixeltable/utils/formatter.py +1 -1
  78. pixeltable/utils/http_server.py +2 -5
  79. pixeltable/utils/media_store.py +6 -6
  80. pixeltable/utils/pytorch.py +10 -11
  81. pixeltable/utils/sql.py +2 -1
  82. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/METADATA +16 -7
  83. pixeltable-0.2.21.dist-info/RECORD +148 -0
  84. pixeltable/utils/help.py +0 -11
  85. pixeltable-0.2.19.dist-info/RECORD +0 -147
  86. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
  87. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
  88. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
@@ -1,73 +1,175 @@
1
- from typing import List, Optional, Tuple, Iterable, Set
2
- from uuid import UUID
3
1
  import logging
4
2
  import warnings
3
+ from decimal import Decimal
4
+ from typing import Optional, Iterable, Iterator, NamedTuple
5
+ from uuid import UUID
5
6
 
6
7
  import sqlalchemy as sql
7
8
 
9
+ import pixeltable.catalog as catalog
10
+ import pixeltable.exprs as exprs
8
11
  from .data_row_batch import DataRowBatch
9
12
  from .exec_node import ExecNode
10
- import pixeltable.exprs as exprs
11
- import pixeltable.catalog as catalog
12
-
13
13
 
14
14
  _logger = logging.getLogger('pixeltable')
15
15
 
16
+
17
+ class OrderByItem(NamedTuple):
18
+ expr: exprs.Expr
19
+ asc: Optional[bool]
20
+
21
+
22
+ OrderByClause = list[OrderByItem]
23
+
24
+
25
+ def combine_order_by_clauses(clauses: Iterable[OrderByClause]) -> Optional[OrderByClause]:
26
+ """Returns a clause that's compatible with 'clauses', or None if that doesn't exist.
27
+ Two clauses are compatible if for each of their respective items c1[i] and c2[i]
28
+ a) the exprs are identical and
29
+ b) the asc values are identical or at least one is None (None serves as a wildcard)
30
+ """
31
+ result: OrderByClause = []
32
+ for clause in clauses:
33
+ combined: OrderByClause = []
34
+ for item1, item2 in zip(result, clause):
35
+ if item1.expr.id != item2.expr.id:
36
+ return None
37
+ if item1.asc is not None and item2.asc is not None and item1.asc != item2.asc:
38
+ return None
39
+ asc = item1.asc if item1.asc is not None else item2.asc
40
+ combined.append(OrderByItem(item1.expr, asc))
41
+
42
+ # add remaining ordering of the longer list
43
+ prefix_len = min(len(result), len(clause))
44
+ if len(result) > prefix_len:
45
+ combined.extend(result[prefix_len:])
46
+ elif len(clause) > prefix_len:
47
+ combined.extend(clause[prefix_len:])
48
+ result = combined
49
+ return result
50
+
51
+
52
+ def print_order_by_clause(clause: OrderByClause) -> str:
53
+ return ', '.join([
54
+ f'({item.expr}{", asc=True" if item.asc is True else ""}{", asc=False" if item.asc is False else ""})'
55
+ for item in clause
56
+ ])
57
+
58
+
16
59
  class SqlNode(ExecNode):
17
- """Materializes data from the store via a Select stmt."""
60
+ """
61
+ Materializes data from the store via a Select stmt.
62
+ This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
63
+ """
64
+
65
+ tbl: Optional[catalog.TableVersionPath]
66
+ select_list: exprs.ExprSet
67
+ set_pk: bool
68
+ num_pk_cols: int
69
+ filter: Optional[exprs.Expr]
70
+ filter_eval_ctx: Optional[exprs.RowBuilder.EvalCtx]
71
+ cte: Optional[sql.CTE]
72
+ sql_elements: exprs.SqlElementCache
73
+ limit: Optional[int]
74
+ order_by_clause: OrderByClause
18
75
 
19
76
  def __init__(
20
- self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
77
+ self, tbl: Optional[catalog.TableVersionPath], row_builder: exprs.RowBuilder,
21
78
  select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
22
79
  ):
23
80
  """
24
- Initialize self.stmt with expressions derived from select_list.
25
-
26
- This only provides the select list. The subclass is responsible for the From clause and any additional clauses.
81
+ If row_builder contains references to unstored iter columns, expands the select list to include their
82
+ SQL-materializable subexpressions.
27
83
 
28
84
  Args:
29
85
  select_list: output of the query
30
86
  set_pk: if True, sets the primary for each DataRow
31
87
  """
32
88
  # create Select stmt
89
+ self.sql_elements = sql_elements
33
90
  self.tbl = tbl
34
- target = tbl.tbl_version # the stored table we're scanning
35
- self.sql_exprs = exprs.ExprSet(select_list)
91
+ self.select_list = exprs.ExprSet(select_list)
36
92
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
37
93
  for iter_arg in row_builder.unstored_iter_args.values():
38
- sql_subexprs = iter_arg.subexprs(filter=sql_elements.contains, traverse_matches=False)
94
+ sql_subexprs = iter_arg.subexprs(filter=self.sql_elements.contains, traverse_matches=False)
39
95
  for e in sql_subexprs:
40
- self.sql_exprs.add(e)
41
- super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
96
+ self.select_list.add(e)
97
+ super().__init__(row_builder, self.select_list, [], None) # we materialize self.select_list
42
98
 
43
- # change rowid refs against a base table to rowid refs against the target table, so that we minimize
44
- # the number of tables that need to be joined to the target table
45
- for rowid_ref in [e for e in self.sql_exprs if isinstance(e, exprs.RowidRef)]:
46
- rowid_ref.set_tbl(tbl)
99
+ if tbl is not None:
100
+ # minimize the number of tables that need to be joined to the target table
101
+ self.retarget_rowid_refs(tbl, self.select_list)
47
102
 
48
- sql_select_list = [sql_elements.get(e) for e in self.sql_exprs]
49
- assert len(sql_select_list) == len(self.sql_exprs)
50
- assert all(e is not None for e in sql_select_list)
103
+ assert self.sql_elements.contains(self.select_list)
51
104
  self.set_pk = set_pk
52
105
  self.num_pk_cols = 0
53
106
  if set_pk:
54
107
  # we also need to retrieve the pk columns
55
- pk_columns = target.store_tbl.pk_columns()
56
- self.num_pk_cols = len(pk_columns)
57
- sql_select_list += pk_columns
58
-
59
- self.stmt = sql.select(*sql_select_list)
108
+ assert tbl is not None
109
+ self.num_pk_cols = len(tbl.tbl_version.store_tbl.pk_columns())
60
110
 
61
111
  # additional state
62
- self.result_cursor: Optional[sql.engine.CursorResult] = None
112
+ self.result_cursor = None
63
113
  # the filter is provided by the subclass
64
- self.filter: Optional[exprs.Expr] = None
65
- self.filter_eval_ctx: Optional[exprs.EvalContext] = None
114
+ self.filter = None
115
+ self.filter_eval_ctx = None
116
+ self.cte = None
117
+ self.limit = None
118
+ self.order_by_clause = []
119
+
120
+ def _create_stmt(self) -> sql.Select:
121
+ """Create Select from local state"""
122
+
123
+ assert self.sql_elements.contains(self.select_list)
124
+ sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
125
+ if self.set_pk:
126
+ sql_select_list += self.tbl.tbl_version.store_tbl.pk_columns()
127
+ stmt = sql.select(*sql_select_list)
128
+
129
+ order_by_clause: list[sql.ClauseElement] = []
130
+ for e, asc in self.order_by_clause:
131
+ if isinstance(e, exprs.SimilarityExpr):
132
+ order_by_clause.append(e.as_order_by_clause(asc))
133
+ else:
134
+ order_by_clause.append(self.sql_elements.get(e).desc() if asc is False else self.sql_elements.get(e))
135
+ stmt = stmt.order_by(*order_by_clause)
136
+
137
+ if self.filter is None and self.limit is not None:
138
+ # if we don't have a Python filter, we can apply the limit to stmt
139
+ stmt = stmt.limit(self.limit)
140
+
141
+ return stmt
142
+
143
+ def _ordering_tbl_ids(self) -> set[UUID]:
144
+ return exprs.Expr.list_tbl_ids(e for e, _ in self.order_by_clause)
145
+
146
+ def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
147
+ """
148
+ Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
149
+
150
+ Returns:
151
+ (CTE, dict from Expr to output column)
152
+ """
153
+ if self.filter is not None:
154
+ # the filter needs to run in Python
155
+ return None
156
+ self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
157
+ if self.cte is None:
158
+ self.cte = self._create_stmt().cte()
159
+ assert len(self.cte.c) == len(self.select_list)
160
+ return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
161
+
162
+ @classmethod
163
+ def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
164
+ """Change rowid refs to point to target"""
165
+ for e in expr_seq:
166
+ if isinstance(e, exprs.RowidRef):
167
+ e.set_tbl(target)
66
168
 
67
169
  @classmethod
68
170
  def create_from_clause(
69
- cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[Set[UUID]] = None,
70
- exact_version_only: Optional[Set[UUID]] = None
171
+ cls, tbl: catalog.TableVersionPath, stmt: sql.Select, refd_tbl_ids: Optional[set[UUID]] = None,
172
+ exact_version_only: Optional[set[UUID]] = None
71
173
  ) -> sql.Select:
72
174
  """Add From clause to stmt for tables/views referenced by materialized_exprs
73
175
  Args:
@@ -85,7 +187,7 @@ class SqlNode(ExecNode):
85
187
  exact_version_only = {}
86
188
  candidates = tbl.get_tbl_versions()
87
189
  assert len(candidates) > 0
88
- joined_tbls: List[catalog.TableVersion] = [candidates[0]]
190
+ joined_tbls: list[catalog.TableVersion] = [candidates[0]]
89
191
  for tbl in candidates[1:]:
90
192
  if tbl.id in refd_tbl_ids:
91
193
  joined_tbls.append(tbl)
@@ -111,72 +213,97 @@ class SqlNode(ExecNode):
111
213
  prev_tbl = tbl
112
214
  return stmt
113
215
 
114
- def _log_explain(self, conn: sql.engine.Connection) -> None:
216
+ def add_order_by(self, ordering: OrderByClause) -> None:
217
+ """Add Order By clause to stmt"""
218
+ if self.tbl is not None:
219
+ # change rowid refs against a base table to rowid refs against the target table, so that we minimize
220
+ # the number of tables that need to be joined to the target table
221
+ self.retarget_rowid_refs(self.tbl, [e for e, _ in ordering])
222
+ combined = combine_order_by_clauses([self.order_by_clause, ordering])
223
+ assert combined is not None
224
+ self.order_by_clause = combined
225
+
226
+ def set_limit(self, limit: int) -> None:
227
+ self.limit = limit
228
+
229
+ def _log_explain(self, stmt: sql.Select) -> None:
115
230
  try:
116
231
  # don't set dialect=Env.get().engine.dialect: x % y turns into x %% y, which results in a syntax error
117
- stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
232
+ stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
118
233
  explain_result = self.ctx.conn.execute(sql.text(f'EXPLAIN {stmt_str}'))
119
234
  explain_str = '\n'.join([str(row) for row in explain_result])
120
235
  _logger.debug(f'SqlScanNode explain:\n{explain_str}')
121
236
  except Exception as e:
122
237
  _logger.warning(f'EXPLAIN failed')
123
238
 
124
- def __next__(self) -> DataRowBatch:
125
- if self.result_cursor is None:
126
- # run the query; do this here rather than in _open(), exceptions are only expected during iteration
127
- assert self.ctx.conn is not None
128
- try:
129
- self._log_explain(self.ctx.conn)
130
- with warnings.catch_warnings(record=True) as w:
131
- self.result_cursor = self.ctx.conn.execute(self.stmt)
132
- for warning in w:
133
- pass
134
- self.has_more_rows = True
135
- except Exception as e:
136
- self.has_more_rows = False
137
- raise e
138
-
139
- if not self.has_more_rows:
140
- raise StopIteration
141
-
142
- output_batch = DataRowBatch(self.tbl.tbl_version, self.row_builder)
143
- needs_row = True
144
- while self.ctx.batch_size == 0 or len(output_batch) < self.ctx.batch_size:
145
- try:
146
- sql_row = next(self.result_cursor)
147
- except StopIteration:
148
- self.has_more_rows = False
149
- break
239
+ def __iter__(self) -> Iterator[DataRowBatch]:
240
+ # run the query; do this here rather than in _open(), exceptions are only expected during iteration
241
+ assert self.ctx.conn is not None
242
+ try:
243
+ with warnings.catch_warnings(record=True) as w:
244
+ stmt = self._create_stmt()
245
+ try:
246
+ # log stmt, if possible
247
+ stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
248
+ _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
249
+ except Exception as e:
250
+ pass
251
+ self._log_explain(stmt)
252
+
253
+ result_cursor = self.ctx.conn.execute(stmt)
254
+ for warning in w:
255
+ pass
256
+ except Exception as e:
257
+ raise e
258
+
259
+ tbl_version = self.tbl.tbl_version if self.tbl is not None else None
260
+ output_batch = DataRowBatch(tbl_version, self.row_builder)
261
+ output_row: Optional[exprs.DataRow] = None
262
+ num_rows_returned = 0
150
263
 
151
- if needs_row:
152
- output_row = output_batch.add_row()
264
+ for sql_row in result_cursor:
265
+ output_row = output_batch.add_row(output_row)
266
+
267
+ # populate output_row
153
268
  if self.num_pk_cols > 0:
154
269
  output_row.set_pk(tuple(sql_row[-self.num_pk_cols:]))
155
270
  # copy the output of the SQL query into the output row
156
- for i, e in enumerate(self.sql_exprs):
271
+ for i, e in enumerate(self.select_list):
157
272
  slot_idx = e.slot_idx
158
- output_row[slot_idx] = sql_row[i]
273
+ # certain numerical operations can produce Decimals (eg, SUM(<int column>)); we need to convert them
274
+ if isinstance(sql_row[i], Decimal):
275
+ if e.col_type.is_int_type():
276
+ output_row[slot_idx] = int(sql_row[i])
277
+ elif e.col_type.is_float_type():
278
+ output_row[slot_idx] = float(sql_row[i])
279
+ else:
280
+ raise RuntimeError(f'Unexpected Decimal value for {e}')
281
+ else:
282
+ output_row[slot_idx] = sql_row[i]
283
+
159
284
  if self.filter is not None:
285
+ # evaluate filter
160
286
  self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
161
- if output_row[self.filter.slot_idx]:
162
- needs_row = True
163
- if self.limit > 0 and len(output_batch) >= self.limit:
164
- self.has_more_rows = False
165
- break
166
- else:
167
- # we re-use this row for the next sql row if it didn't pass the filter
168
- needs_row = False
169
- output_row.clear()
287
+ if self.filter is not None and not output_row[self.filter.slot_idx]:
288
+ # we re-use this row for the next sql row since it didn't pass the filter
289
+ output_row = output_batch.pop_row()
290
+ output_row.clear()
291
+ else:
292
+ # reset output_row in order to add new one
293
+ output_row = None
294
+ num_rows_returned += 1
170
295
 
171
- if not needs_row:
172
- # the last row didn't pass the filter
173
- assert self.filter is not None
174
- output_batch.pop_row()
296
+ if self.limit is not None and num_rows_returned == self.limit:
297
+ break
298
+
299
+ if self.ctx.batch_size > 0 and len(output_batch) == self.ctx.batch_size:
300
+ _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
301
+ yield output_batch
302
+ output_batch = DataRowBatch(tbl_version, self.row_builder)
175
303
 
176
- _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
177
- if len(output_batch) == 0:
178
- raise StopIteration
179
- return output_batch
304
+ if len(output_batch) > 0:
305
+ _logger.debug(f'SqlScanNode: returning {len(output_batch)} rows')
306
+ yield output_batch
180
307
 
181
308
  def _close(self) -> None:
182
309
  if self.result_cursor is not None:
@@ -189,12 +316,14 @@ class SqlScanNode(SqlNode):
189
316
 
190
317
  Supports filtering and ordering.
191
318
  """
319
+ where_clause: Optional[exprs.Expr]
320
+ exact_version_only: list[catalog.TableVersion]
321
+
192
322
  def __init__(
193
323
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
194
324
  select_list: Iterable[exprs.Expr],
195
325
  where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Expr] = None,
196
- order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
197
- limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
326
+ set_pk: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
198
327
  ):
199
328
  """
200
329
  Args:
@@ -208,52 +337,29 @@ class SqlScanNode(SqlNode):
208
337
  sql_elements = exprs.SqlElementCache()
209
338
  super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
210
339
  # create Select stmt
211
- if order_by_items is None:
212
- order_by_items = []
213
340
  if exact_version_only is None:
214
341
  exact_version_only = []
215
342
  target = tbl.tbl_version # the stored table we're scanning
216
343
  self.filter = filter
217
344
  self.filter_eval_ctx = \
218
345
  row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
219
- self.limit = limit
220
346
 
221
- where_clause_tbl_ids = where_clause.tbl_ids() if where_clause is not None else set()
222
- refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs) | where_clause_tbl_ids
223
- self.stmt = self.create_from_clause(
224
- tbl, self.stmt, refd_tbl_ids, exact_version_only={t.id for t in exact_version_only})
225
-
226
- # change rowid refs against a base table to rowid refs against the target table, so that we minimize
227
- # the number of tables that need to be joined to the target table
228
- for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
229
- rowid_ref.set_tbl(tbl)
230
- order_by_clause: List[sql.ClauseElement] = []
231
- for e, asc in order_by_items:
232
- if isinstance(e, exprs.SimilarityExpr):
233
- order_by_clause.append(e.as_order_by_clause(asc))
234
- else:
235
- order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
347
+ self.where_clause = where_clause
348
+ self.exact_version_only = exact_version_only
236
349
 
237
- if where_clause is not None:
238
- sql_where_clause = sql_elements.get(where_clause)
350
+ def _create_stmt(self) -> sql.Select:
351
+ stmt = super()._create_stmt()
352
+ where_clause_tbl_ids = self.where_clause.tbl_ids() if self.where_clause is not None else set()
353
+ refd_tbl_ids = exprs.Expr.list_tbl_ids(self.select_list) | where_clause_tbl_ids | self._ordering_tbl_ids()
354
+ stmt = self.create_from_clause(
355
+ self.tbl, stmt, refd_tbl_ids, exact_version_only={t.id for t in self.exact_version_only})
356
+
357
+ if self.where_clause is not None:
358
+ sql_where_clause = self.sql_elements.get(self.where_clause)
239
359
  assert sql_where_clause is not None
240
- self.stmt = self.stmt.where(sql_where_clause)
241
- if len(order_by_clause) > 0:
242
- self.stmt = self.stmt.order_by(*order_by_clause)
243
- elif target.id in row_builder.unstored_iter_args:
244
- # we are referencing unstored iter columns from this view and try to order by our primary key,
245
- # which ensures that iterators will see monotonically increasing pos values
246
- self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
247
- if limit != 0 and self.filter is None:
248
- # if we need to do post-SQL filtering, we can't use LIMIT
249
- self.stmt = self.stmt.limit(limit)
360
+ stmt = stmt.where(sql_where_clause)
250
361
 
251
- try:
252
- # log stmt, if possible
253
- stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
254
- _logger.debug(f'SqlScanNode stmt:\n{stmt_str}')
255
- except Exception as e:
256
- pass
362
+ return stmt
257
363
 
258
364
 
259
365
  class SqlLookupNode(SqlNode):
@@ -261,8 +367,7 @@ class SqlLookupNode(SqlNode):
261
367
  Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
262
368
  """
263
369
 
264
- stmt: sql.Select
265
- where_clause: sql.ColumnElement[bool]
370
+ where_clause: sql.ColumnElement
266
371
 
267
372
  def __init__(
268
373
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
@@ -276,21 +381,45 @@ class SqlLookupNode(SqlNode):
276
381
  """
277
382
  sql_elements = exprs.SqlElementCache()
278
383
  super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
279
- target = tbl.tbl_version # the stored table we're scanning
280
- refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
281
- self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
282
384
  # Where clause: (key-col-1, key-col-2, ...) IN ((val-1, val-2, ...), ...)
283
385
  self.where_clause = sql.tuple_(*sa_key_cols).in_(key_vals)
284
- self.stmt = self.stmt.where(self.where_clause)
285
386
 
286
- if target.id in row_builder.unstored_iter_args:
287
- # we are referencing unstored iter columns from this view and try to order by our primary key,
288
- # which ensures that iterators will see monotonically increasing pos values
289
- self.stmt = self.stmt.order_by(*self.tbl.store_tbl.rowid_columns())
387
+ def _create_stmt(self) -> sql.Select:
388
+ stmt = super()._create_stmt()
389
+ refd_tbl_ids = exprs.Expr.list_tbl_ids(self.select_list) | self._ordering_tbl_ids()
390
+ stmt = self.create_from_clause(self.tbl, stmt, refd_tbl_ids)
391
+ stmt = stmt.where(self.where_clause)
392
+ return stmt
290
393
 
291
- try:
292
- # log stmt, if possible
293
- stmt_str = str(self.stmt.compile(compile_kwargs={'literal_binds': True}))
294
- _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
295
- except Exception as e:
296
- pass
394
+ class SqlAggregationNode(SqlNode):
395
+ """
396
+ Materializes data from the store via a Select stmt with a WHERE clause that matches a list of key values
397
+ """
398
+
399
+ group_by_items: Optional[list[exprs.Expr]]
400
+
401
+ def __init__(
402
+ self, row_builder: exprs.RowBuilder,
403
+ input: SqlNode,
404
+ select_list: Iterable[exprs.Expr],
405
+ group_by_items: Optional[list[exprs.Expr]] = None,
406
+ limit: Optional[int] = None, exact_version_only: Optional[list[catalog.TableVersion]] = None
407
+ ):
408
+ """
409
+ Args:
410
+ select_list: can contain calls to AggregateFunctions
411
+ group_by_items: list of expressions to group by
412
+ limit: max number of rows to return: None = no limit
413
+ """
414
+ _, input_col_map = input.to_cte()
415
+ sql_elements = exprs.SqlElementCache(input_col_map)
416
+ super().__init__(None, row_builder, select_list, sql_elements)
417
+ self.group_by_items = group_by_items
418
+
419
+ def _create_stmt(self) -> sql.Select:
420
+ stmt = super()._create_stmt()
421
+ if self.group_by_items is not None:
422
+ sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
423
+ assert all(e is not None for e in sql_group_by_items)
424
+ stmt = stmt.group_by(*sql_group_by_items)
425
+ return stmt
@@ -6,6 +6,7 @@ from .comparison import Comparison
6
6
  from .compound_predicate import CompoundPredicate
7
7
  from .data_row import DataRow
8
8
  from .expr import Expr
9
+ from .expr_dict import ExprDict
9
10
  from .expr_set import ExprSet
10
11
  from .function_call import FunctionCall
11
12
  from .in_predicate import InPredicate
@@ -33,29 +33,40 @@ class DataRow:
33
33
  - ImageType: PIL.Image.Image
34
34
  - VideoType: local path if available, otherwise url
35
35
  """
36
- def __init__(self, size: int, img_slot_idxs: List[int], media_slot_idxs: List[int], array_slot_idxs: List[int]):
37
- self.vals: List[Any] = [None] * size # either cell values or exceptions
38
- self.has_val = [False] * size
39
- self.excs: List[Optional[Exception]] = [None] * size
40
36
 
41
- # control structures that are shared across all DataRows in a batch
42
- self.img_slot_idxs = img_slot_idxs
43
- self.media_slot_idxs = media_slot_idxs # all media types aside from image
44
- self.array_slot_idxs = array_slot_idxs
37
+ vals: list[Any]
38
+ has_val: list[bool]
39
+ excs: list[Optional[Exception]]
45
40
 
46
- # the primary key of a store row is a sequence of ints (the number is different for table vs view)
47
- self.pk: Optional[Tuple[int, ...]] = None
41
+ # control structures that are shared across all DataRows in a batch
42
+ img_slot_idxs: list[int]
43
+ media_slot_idxs: list[int]
44
+ array_slot_idxs: list[int]
48
45
 
49
- # file_urls:
50
- # - stored url of file for media in vals[i]
51
- # - None if vals[i] is not media type
52
- # - not None if file_paths[i] is not None
53
- self.file_urls: List[Optional[str]] = [None] * size
46
+ # the primary key of a store row is a sequence of ints (the number is different for table vs view)
47
+ pk: Optional[tuple[int, ...]]
54
48
 
55
- # file_paths:
56
- # - local path of media file in vals[i]; points to the file cache if file_urls[i] is remote
57
- # - None if vals[i] is not a media type or if there is no local file yet for file_urls[i]
58
- self.file_paths: List[Optional[str]] = [None] * size
49
+ # file_urls:
50
+ # - stored url of file for media in vals[i]
51
+ # - None if vals[i] is not media type
52
+ # - not None if file_paths[i] is not None
53
+ file_urls: list[Optional[str]]
54
+
55
+ # file_paths:
56
+ # - local path of media file in vals[i]; points to the file cache if file_urls[i] is remote
57
+ # - None if vals[i] is not a media type or if there is no local file yet for file_urls[i]
58
+ file_paths: list[Optional[str]]
59
+
60
+ def __init__(self, size: int, img_slot_idxs: List[int], media_slot_idxs: List[int], array_slot_idxs: List[int]):
61
+ self.vals = [None] * size
62
+ self.has_val = [False] * size
63
+ self.excs = [None] * size
64
+ self.img_slot_idxs = img_slot_idxs
65
+ self.media_slot_idxs = media_slot_idxs
66
+ self.array_slot_idxs = array_slot_idxs
67
+ self.pk = None
68
+ self.file_urls = [None] * size
69
+ self.file_paths = [None] * size
59
70
 
60
71
  def clear(self) -> None:
61
72
  size = len(self.vals)
pixeltable/exprs/expr.py CHANGED
@@ -7,11 +7,11 @@ import inspect
7
7
  import json
8
8
  import sys
9
9
  import typing
10
- from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TypeVar, Union, overload
10
+ from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, TypeVar, Union, overload, Iterable
11
11
  from uuid import UUID
12
12
 
13
13
  import sqlalchemy as sql
14
- from typing_extensions import Self
14
+ from typing_extensions import _AnnotatedAlias, Self
15
15
 
16
16
  import pixeltable
17
17
  import pixeltable.catalog as catalog
@@ -281,9 +281,10 @@ class Expr(abc.ABC):
281
281
  """
282
282
  Iterate over all subexprs, including self.
283
283
  """
284
- is_match = filter is None or filter(self)
285
- if expr_class is not None:
286
- is_match = is_match and isinstance(self, expr_class)
284
+ is_match = isinstance(self, expr_class) if expr_class is not None else True
285
+ # apply filter after checking for expr_class
286
+ if filter is not None and is_match:
287
+ is_match = filter(self)
287
288
  if not is_match or traverse_matches:
288
289
  for c in self.components:
289
290
  yield from c.subexprs(expr_class=expr_class, filter=filter, traverse_matches=traverse_matches)
@@ -292,7 +293,7 @@ class Expr(abc.ABC):
292
293
 
293
294
  @overload
294
295
  def list_subexprs(
295
- expr_list: list[Expr], *, filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
296
+ expr_list: Iterable[Expr], *, filter: Optional[Callable[[Expr], bool]] = None, traverse_matches: bool = True
296
297
  ) -> Iterator[Expr]: ...
297
298
 
298
299
  @overload
@@ -312,13 +313,11 @@ class Expr(abc.ABC):
312
313
 
313
314
  def _contains(self, cls: Optional[type[Expr]] = None, filter: Optional[Callable[[Expr], bool]] = None) -> bool:
314
315
  """
315
- Returns True if any subexpr is an instance of cls.
316
+ Returns True if any subexpr is an instance of cls and/or matches filter.
316
317
  """
317
- assert (cls is not None) != (filter is not None) # need one of them
318
- if cls is not None:
319
- filter = lambda e: isinstance(e, cls)
318
+ assert cls is not None or filter is not None
320
319
  try:
321
- _ = next(self.subexprs(filter=filter, traverse_matches=False))
320
+ _ = next(self.subexprs(expr_class=cls, filter=filter, traverse_matches=False))
322
321
  return True
323
322
  except StopIteration:
324
323
  return False
@@ -457,11 +456,13 @@ class Expr(abc.ABC):
457
456
  else:
458
457
  return InPredicate(self, value_set_literal=value_set)
459
458
 
460
- def astype(self, new_type: ts.ColumnType) -> 'pixeltable.exprs.TypeCast':
459
+ def astype(self, new_type: Union[ts.ColumnType, type, _AnnotatedAlias]) -> 'pixeltable.exprs.TypeCast':
461
460
  from pixeltable.exprs import TypeCast
462
- return TypeCast(self, new_type)
461
+ return TypeCast(self, ts.ColumnType.normalize_type(new_type))
463
462
 
464
- def apply(self, fn: Callable, *, col_type: Optional[ts.ColumnType] = None) -> 'pixeltable.exprs.FunctionCall':
463
+ def apply(self, fn: Callable, *, col_type: Union[ts.ColumnType, type, _AnnotatedAlias, None] = None) -> 'pixeltable.exprs.FunctionCall':
464
+ if col_type is not None:
465
+ col_type = ts.ColumnType.normalize_type(col_type)
465
466
  function = self._make_applicator_function(fn, col_type)
466
467
  # Return a `FunctionCall` obtained by passing this `Expr` to the new `function`.
467
468
  return function(self)