pixeltable 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. See the package's registry page for more details.

Files changed (78)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +559 -134
  5. pixeltable/catalog/column.py +36 -32
  6. pixeltable/catalog/dir.py +1 -2
  7. pixeltable/catalog/globals.py +12 -0
  8. pixeltable/catalog/insertable_table.py +30 -25
  9. pixeltable/catalog/schema_object.py +9 -6
  10. pixeltable/catalog/table.py +334 -267
  11. pixeltable/catalog/table_version.py +358 -241
  12. pixeltable/catalog/table_version_handle.py +18 -2
  13. pixeltable/catalog/table_version_path.py +86 -16
  14. pixeltable/catalog/view.py +47 -23
  15. pixeltable/dataframe.py +198 -19
  16. pixeltable/env.py +6 -4
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/__init__.py +1 -1
  19. pixeltable/exec/exec_node.py +2 -0
  20. pixeltable/exec/expr_eval/evaluators.py +4 -1
  21. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  22. pixeltable/exec/in_memory_data_node.py +1 -1
  23. pixeltable/exec/sql_node.py +188 -22
  24. pixeltable/exprs/column_property_ref.py +16 -6
  25. pixeltable/exprs/column_ref.py +33 -11
  26. pixeltable/exprs/comparison.py +1 -1
  27. pixeltable/exprs/data_row.py +5 -3
  28. pixeltable/exprs/expr.py +11 -4
  29. pixeltable/exprs/literal.py +2 -0
  30. pixeltable/exprs/row_builder.py +4 -6
  31. pixeltable/exprs/rowid_ref.py +8 -0
  32. pixeltable/exprs/similarity_expr.py +1 -0
  33. pixeltable/func/__init__.py +1 -0
  34. pixeltable/func/mcp.py +74 -0
  35. pixeltable/func/query_template_function.py +5 -3
  36. pixeltable/func/tools.py +12 -2
  37. pixeltable/func/udf.py +2 -2
  38. pixeltable/functions/__init__.py +1 -0
  39. pixeltable/functions/anthropic.py +19 -45
  40. pixeltable/functions/deepseek.py +19 -38
  41. pixeltable/functions/fireworks.py +9 -18
  42. pixeltable/functions/gemini.py +2 -3
  43. pixeltable/functions/groq.py +108 -0
  44. pixeltable/functions/llama_cpp.py +6 -6
  45. pixeltable/functions/mistralai.py +16 -53
  46. pixeltable/functions/ollama.py +1 -1
  47. pixeltable/functions/openai.py +82 -165
  48. pixeltable/functions/string.py +212 -58
  49. pixeltable/functions/together.py +22 -80
  50. pixeltable/globals.py +10 -4
  51. pixeltable/index/base.py +5 -0
  52. pixeltable/index/btree.py +5 -0
  53. pixeltable/index/embedding_index.py +5 -0
  54. pixeltable/io/external_store.py +10 -31
  55. pixeltable/io/label_studio.py +5 -5
  56. pixeltable/io/parquet.py +2 -2
  57. pixeltable/io/table_data_conduit.py +1 -32
  58. pixeltable/metadata/__init__.py +11 -2
  59. pixeltable/metadata/converters/convert_13.py +2 -2
  60. pixeltable/metadata/converters/convert_30.py +6 -11
  61. pixeltable/metadata/converters/convert_35.py +9 -0
  62. pixeltable/metadata/converters/convert_36.py +38 -0
  63. pixeltable/metadata/converters/convert_37.py +15 -0
  64. pixeltable/metadata/converters/util.py +3 -9
  65. pixeltable/metadata/notes.py +3 -0
  66. pixeltable/metadata/schema.py +13 -1
  67. pixeltable/plan.py +135 -12
  68. pixeltable/share/packager.py +138 -14
  69. pixeltable/share/publish.py +2 -2
  70. pixeltable/store.py +19 -13
  71. pixeltable/type_system.py +30 -0
  72. pixeltable/utils/dbms.py +1 -1
  73. pixeltable/utils/formatter.py +64 -42
  74. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/METADATA +2 -1
  75. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/RECORD +78 -73
  76. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/LICENSE +0 -0
  77. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/WHEEL +0 -0
  78. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@ from .exec_node import ExecNode
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  import pixeltable.plan
17
+ from pixeltable.plan import SampleClause
17
18
 
18
19
  _logger = logging.getLogger('pixeltable')
19
20
 
@@ -64,8 +65,12 @@ def print_order_by_clause(clause: OrderByClause) -> str:
64
65
 
65
66
  class SqlNode(ExecNode):
66
67
  """
67
- Materializes data from the store via a Select stmt.
68
+ Materializes data from the store via an SQL statement.
68
69
  This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
70
+ The pk columns are not included in the select list.
71
+ If set_pk is True, they are added to the end of the result set when creating the SQL statement
72
+ so they can always be referenced as cols[-num_pk_cols:] in the result set.
73
+ The pk_columns consist of the rowid columns of the target table followed by the version number.
69
74
  """
70
75
 
71
76
  tbl: Optional[catalog.TableVersionPath]
@@ -122,6 +127,7 @@ class SqlNode(ExecNode):
122
127
  # we also need to retrieve the pk columns
123
128
  assert tbl is not None
124
129
  self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
130
+ assert self.num_pk_cols > 1
125
131
 
126
132
  # additional state
127
133
  self.result_cursor = None
@@ -134,14 +140,25 @@ class SqlNode(ExecNode):
134
140
  self.where_clause_element = None
135
141
  self.order_by_clause = []
136
142
 
143
+ if self.tbl is not None:
144
+ tv = self.tbl.tbl_version._tbl_version
145
+ if tv is not None:
146
+ assert tv.is_validated
147
+
148
def _create_pk_cols(self) -> list[sql.Column]:
    """Return the store table's pk columns if set_pk is set, otherwise an empty list.

    The pk columns are the rowid columns of the target table followed by the version number; _create_stmt()
    appends them to the end of the select list.
    """
    if not self.set_pk:
        # the caller does not need the pk columns in the result set
        return []
    assert self.tbl is not None
    assert self.tbl.tbl_version.get().is_validated
    store_tbl = self.tbl.tbl_version.get().store_tbl
    return store_tbl.pk_columns()
156
+
137
157
  def _create_stmt(self) -> sql.Select:
138
158
  """Create Select from local state"""
139
159
 
140
160
  assert self.sql_elements.contains_all(self.select_list)
141
- sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
142
- if self.set_pk:
143
- assert self.tbl is not None
144
- sql_select_list += self.tbl.tbl_version.get().store_tbl.pk_columns()
161
+ sql_select_list = [self.sql_elements.get(e) for e in self.select_list] + self._create_pk_cols()
145
162
  stmt = sql.select(*sql_select_list)
146
163
 
147
164
  where_clause_element = (
@@ -167,9 +184,10 @@ class SqlNode(ExecNode):
167
184
  def _ordering_tbl_ids(self) -> set[UUID]:
168
185
  return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
169
186
 
170
- def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
187
+ def to_cte(self, keep_pk: bool = False) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
171
188
  """
172
- Returns a CTE that materializes the output of this node plus a mapping from select list expr to output column
189
+ Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
190
+ keep_pk: if True, the PK columns are included in the CTE Select statement
173
191
 
174
192
  Returns:
175
193
  (CTE, dict from Expr to output column)
@@ -177,11 +195,13 @@ class SqlNode(ExecNode):
177
195
  if self.py_filter is not None:
178
196
  # the filter needs to run in Python
179
197
  return None
180
- self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
181
198
  if self.cte is None:
199
+ if not keep_pk:
200
+ self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
182
201
  self.cte = self._create_stmt().cte()
183
- assert len(self.cte.c) == len(self.select_list)
184
- return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))
202
+ pk_count = self.num_pk_cols if self.set_pk else 0
203
+ assert len(self.select_list) + pk_count == len(self.cte.c)
204
+ return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c)) # skip pk cols
185
205
 
186
206
  @classmethod
187
207
  def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
@@ -220,26 +240,29 @@ class SqlNode(ExecNode):
220
240
  joined_tbls.append(t)
221
241
 
222
242
  first = True
223
- prev_tbl: Optional[catalog.TableVersionHandle] = None
243
+ prev_tv: Optional[catalog.TableVersion] = None
224
244
  for t in joined_tbls[::-1]:
245
+ tv = t.get()
246
+ # _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
225
247
  if first:
226
- stmt = stmt.select_from(t.get().store_tbl.sa_tbl)
248
+ stmt = stmt.select_from(tv.store_tbl.sa_tbl)
227
249
  first = False
228
250
  else:
229
- # join tbl to prev_tbl on prev_tbl's rowid cols
230
- prev_tbl_rowid_cols = prev_tbl.get().store_tbl.rowid_columns()
231
- tbl_rowid_cols = t.get().store_tbl.rowid_columns()
251
+ # join tv to prev_tv on prev_tv's rowid cols
252
+ prev_tbl_rowid_cols = prev_tv.store_tbl.rowid_columns()
253
+ tbl_rowid_cols = tv.store_tbl.rowid_columns()
232
254
  rowid_clauses = [
233
255
  c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
234
256
  ]
235
- stmt = stmt.join(t.get().store_tbl.sa_tbl, sql.and_(*rowid_clauses))
257
+ stmt = stmt.join(tv.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
258
+
236
259
  if t.id in exact_version_only:
237
- stmt = stmt.where(t.get().store_tbl.v_min_col == t.get().version)
260
+ stmt = stmt.where(tv.store_tbl.v_min_col == tv.version)
238
261
  else:
239
- stmt = stmt.where(t.get().store_tbl.v_min_col <= t.get().version).where(
240
- t.get().store_tbl.v_max_col > t.get().version
241
- )
242
- prev_tbl = t
262
+ stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_min <= tv.version)
263
+ stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_max > tv.version)
264
+ prev_tv = tv
265
+
243
266
  return stmt
244
267
 
245
268
  def set_where(self, where_clause: exprs.Expr) -> None:
@@ -284,7 +307,8 @@ class SqlNode(ExecNode):
284
307
  stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
285
308
  _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
286
309
  except Exception:
287
- pass
310
+ # log something if we can't log the compiled stmt
311
+ _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
288
312
  self._log_explain(stmt)
289
313
 
290
314
  conn = Env.get().conn
@@ -501,3 +525,145 @@ class SqlJoinNode(SqlNode):
501
525
  full=join_clause == plan.JoinType.FULL_OUTER,
502
526
  )
503
527
  return stmt
528
+
529
+
530
class SqlSampleNode(SqlNode):
    """
    Returns rows sampled from the input node.

    The input node is materialized as a CTE that keeps its pk columns at the end of its column list (the rowid
    columns of the target table followed by the version number). Each row is ordered by a seed-dependent key,
    MD5(seed, rowid columns), so that a given seed always produces the same sample.
    """

    input_cte: Optional[sql.CTE]
    pk_count: int
    stratify_exprs: Optional[list[exprs.Expr]]
    sample_clause: 'SampleClause'

    def __init__(
        self,
        row_builder: exprs.RowBuilder,
        input: SqlNode,
        select_list: Iterable[exprs.Expr],
        sample_clause: 'SampleClause',
        stratify_exprs: list[exprs.Expr],
    ):
        """
        Args:
            input: SqlNode to sample from
            select_list: can contain calls to AggregateFunctions
            sample_clause: specifies the sampling method
            stratify_exprs: Analyzer processed list of expressions to stratify by.
        """
        assert isinstance(input, SqlNode)
        # keep the pk columns in the CTE: the sampling key is computed from the rowid columns
        self.input_cte, input_col_map = input.to_cte(keep_pk=True)
        self.pk_count = input.num_pk_cols
        assert self.pk_count > 1
        sql_elements = exprs.SqlElementCache(input_col_map)
        assert sql_elements.contains_all(stratify_exprs)
        super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
        self.stratify_exprs = stratify_exprs
        self.sample_clause = sample_clause
        assert isinstance(self.sample_clause.seed, int)

    @classmethod
    def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
        """Construct expression which is the ordering key for rows to be sampled
        General SQL form is:
        - MD5(<seed>::text || '___' || <col_1>::text || ... || '___' || <col_n>::text)
        """
        sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
        for e in sql_cols:
            # Quotes are required below to guarantee that the string is properly presented in SQL
            sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + sql.cast(e, sql.Text)
        sql_expr = sql.func.md5(sql_expr)
        return sql_expr

    def _create_key_sql(self, cte: sql.CTE, num_trailing_cols: int = 0) -> sql.ColumnElement:
        """Create an expression for randomly ordering rows with a given seed

        Args:
            cte: a CTE whose column list ends with the pk columns (rowid columns followed by the version column),
                optionally followed by `num_trailing_cols` additional columns
            num_trailing_cols: number of columns (e.g. a rank column) that follow the pk columns in `cte`
        """
        # skip the trailing columns and the version column; what remains of the pk are the rowid columns
        stop = -1 - num_trailing_cols
        start = stop - (self.pk_count - 1)
        rowid_cols = [*cte.c[start:stop]]
        assert len(rowid_cols) > 0
        return self.key_sql_expr(sql.literal_column(str(self.sample_clause.seed)), rowid_cols)

    def _create_stmt(self) -> sql.Select:
        from pixeltable.plan import SampleClause

        if self.sample_clause.fraction is not None:
            if len(self.stratify_exprs) == 0:
                # Non-stratified sampling: keep the rows whose key falls below the threshold derived from the
                # fraction, ordered by that same key
                s_key = self._create_key_sql(self.input_cte)
                fraction_sql = sql.cast(SampleClause.fraction_to_md5_hex(float(self.sample_clause.fraction)), sql.Text)
                return sql.select(*self.input_cte.c).where(s_key < fraction_sql).order_by(s_key)

            return self._create_stmt_stratified_fraction(self.sample_clause.fraction)
        else:
            if len(self.stratify_exprs) == 0:
                # No stratification, just return n samples from the input CTE
                order_by = self._create_key_sql(self.input_cte)
                return sql.select(*self.input_cte.c).order_by(order_by).limit(self.sample_clause.n)

            return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)

    def _create_stmt_stratified_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
        """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""

        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        order_by = self._create_key_sql(self.input_cte)

        # Create a list of all columns plus the rank
        # Get all columns from the input CTE dynamically
        select_columns = [*self.input_cte.c]
        select_columns.append(
            sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
        )
        row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')

        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        if n_per_stratum is not None:
            return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
        else:
            # Break ties between rows of equal rank with the same per-row key that was used for ranking.
            # row_rank_cte has one extra column (rank) after the pk columns, which must be skipped; otherwise the
            # key would drop the first rowid column and hash the version column instead, which can make the
            # LIMIT non-deterministic.
            secondary_order = self._create_key_sql(row_rank_cte, num_trailing_cols=1)
            return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)

    def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
        """Create a Select stmt that returns a fraction of the rows per strata"""

        # Build the strata count CTE
        # Produces a table of the form:
        #   (*stratify_exprs, s_s_size)
        # where s_s_size is the number of samples to take from each stratum
        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        per_strata_count_cte = (
            sql.select(
                *sql_strata_exprs,
                sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
            )
            .select_from(self.input_cte)
            .group_by(*sql_strata_exprs)
            .cte('per_strata_count_cte')
        )

        # Build a CTE that ranks the rows within each stratum
        # Include all columns from the input CTE dynamically
        order_by = self._create_key_sql(self.input_cte)
        select_columns = [*self.input_cte.c]
        select_columns.append(
            sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
        )
        row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')

        # Build the join criterion dynamically to accommodate any number of stratify_by expressions
        # (IS NOT DISTINCT FROM so that NULL strata values match each other)
        join_c = sql.true()
        for col in per_strata_count_cte.c[:-1]:
            join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)

        # Join with per_strata_count_cte to limit returns to the requested fraction of rows
        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        stmt = (
            sql.select(*final_columns)
            .select_from(row_rank_cte)
            .join(per_strata_count_cte, join_c)
            .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
        )

        return stmt
@@ -58,20 +58,30 @@ class ColumnPropertyRef(Expr):
58
58
  if not self._col_ref.col.is_stored:
59
59
  return None
60
60
 
61
+ # we need to reestablish that we have the correct Column instance, there could have been a metadata
62
+ # reload since init()
63
+ # TODO: add an explicit prepare phase (ie, Expr.prepare()) that gives every subclass instance a chance to
64
+ # perform runtime checks and update state
65
+ tv = self._col_ref.tbl_version.get()
66
+ assert tv.is_validated
67
+ # we can assume at this point during query execution that the column exists
68
+ assert self._col_ref.col_id in tv.cols_by_id
69
+ col = tv.cols_by_id[self._col_ref.col_id]
70
+
61
71
  # the errortype/-msg properties of a read-validated media column need to be extracted from the DataRow
62
72
  if (
63
- self._col_ref.col.col_type.is_media_type()
64
- and self._col_ref.col.media_validation == catalog.MediaValidation.ON_READ
73
+ col.col_type.is_media_type()
74
+ and col.media_validation == catalog.MediaValidation.ON_READ
65
75
  and self.is_error_prop()
66
76
  ):
67
77
  return None
68
78
 
69
79
  if self.prop == self.Property.ERRORTYPE:
70
- assert self._col_ref.col.sa_errortype_col is not None
71
- return self._col_ref.col.sa_errortype_col
80
+ assert col.sa_errortype_col is not None
81
+ return col.sa_errortype_col
72
82
  if self.prop == self.Property.ERRORMSG:
73
- assert self._col_ref.col.sa_errormsg_col is not None
74
- return self._col_ref.col.sa_errormsg_col
83
+ assert col.sa_errormsg_col is not None
84
+ return col.sa_errormsg_col
75
85
  if self.prop == self.Property.FILEURL:
76
86
  # the file url is stored as the column value
77
87
  return sql_elements.get(self._col_ref)
@@ -52,6 +52,10 @@ class ColumnRef(Expr):
52
52
  id: int
53
53
  perform_validation: bool # if True, performs media validation
54
54
 
55
+ # needed by sql_expr() to re-resolve Column instance after a metadata reload
56
+ tbl_version: catalog.TableVersionHandle
57
+ col_id: int
58
+
55
59
  def __init__(
56
60
  self,
57
61
  col: catalog.Column,
@@ -62,16 +66,17 @@ class ColumnRef(Expr):
62
66
  assert col.tbl is not None
63
67
  self.col = col
64
68
  self.reference_tbl = reference_tbl
65
- self.is_unstored_iter_col = (
66
- col.tbl.get().is_component_view and col.tbl.get().is_iterator_column(col) and not col.is_stored
67
- )
69
+ self.tbl_version = catalog.TableVersionHandle(col.tbl.id, col.tbl.effective_version)
70
+ self.col_id = col.id
71
+
72
+ self.is_unstored_iter_col = col.tbl.is_component_view and col.tbl.is_iterator_column(col) and not col.is_stored
68
73
  self.iter_arg_ctx = None
69
74
  # number of rowid columns in the base table
70
- self.base_rowid_len = col.tbl.get().base.get().num_rowid_columns() if self.is_unstored_iter_col else 0
75
+ self.base_rowid_len = col.tbl.base.get().num_rowid_columns() if self.is_unstored_iter_col else 0
71
76
  self.base_rowid = [None] * self.base_rowid_len
72
77
  self.iterator = None
73
78
  # index of the position column in the view's primary key; don't try to reference tbl.store_tbl here
74
- self.pos_idx = col.tbl.get().num_rowid_columns() - 1 if self.is_unstored_iter_col else None
79
+ self.pos_idx = col.tbl.num_rowid_columns() - 1 if self.is_unstored_iter_col else None
75
80
 
76
81
  self.perform_validation = False
77
82
  if col.col_type.is_media_type():
@@ -175,7 +180,7 @@ class ColumnRef(Expr):
175
180
  assert len(idx_info) == 1
176
181
  col = copy.copy(next(iter(idx_info.values())).val_col)
177
182
  col.name = f'{self.col.name}_embedding_{idx if idx is not None else ""}'
178
- col.create_sa_cols()
183
+ # col.create_sa_cols()
179
184
  return ColumnRef(col)
180
185
 
181
186
  def default_column_name(self) -> Optional[str]:
@@ -226,7 +231,7 @@ class ColumnRef(Expr):
226
231
  def _descriptors(self) -> DescriptionHelper:
227
232
  tbl = catalog.Catalog.get().get_table_by_id(self.col.tbl.id)
228
233
  helper = DescriptionHelper()
229
- helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
234
+ helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path()!r})')
230
235
  helper.append(tbl._col_descriptor([self.col.name]))
231
236
  idxs = tbl._index_descriptor([self.col.name])
232
237
  if len(idxs) > 0:
@@ -234,7 +239,19 @@ class ColumnRef(Expr):
234
239
  return helper
235
240
 
236
241
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
237
- return None if self.perform_validation else self.col.sa_col
242
+ if self.perform_validation:
243
+ return None
244
+ # we need to reestablish that we have the correct Column instance, there could have been a metadata
245
+ # reload since init()
246
+ # TODO: add an explicit prepare phase (ie, Expr.prepare()) that gives every subclass instance a chance to
247
+ # perform runtime checks and update state
248
+ tv = self.tbl_version.get()
249
+ assert tv.is_validated
250
+ # we can assume at this point during query execution that the column exists
251
+ assert self.col_id in tv.cols_by_id
252
+ self.col = tv.cols_by_id[self.col_id]
253
+ assert self.col.tbl is tv
254
+ return self.col.sa_col
238
255
 
239
256
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
240
257
  if self.perform_validation:
@@ -275,7 +292,7 @@ class ColumnRef(Expr):
275
292
  if self.base_rowid != data_row.pk[: self.base_rowid_len]:
276
293
  row_builder.eval(data_row, self.iter_arg_ctx)
277
294
  iterator_args = data_row[self.iter_arg_ctx.target_slot_idxs[0]]
278
- self.iterator = self.col.tbl.get().iterator_cls(**iterator_args)
295
+ self.iterator = self.col.tbl.iterator_cls(**iterator_args)
279
296
  self.base_rowid = data_row.pk[: self.base_rowid_len]
280
297
  self.iterator.set_pos(data_row.pk[self.pos_idx])
281
298
  res = next(self.iterator)
@@ -283,17 +300,22 @@ class ColumnRef(Expr):
283
300
 
284
301
  def _as_dict(self) -> dict:
285
302
  tbl = self.col.tbl
286
- tbl_version = tbl.get().version if tbl.get().is_snapshot else None
303
+ version = tbl.version if tbl.is_snapshot else None
287
304
  # we omit self.components, even if this is a validating ColumnRef, because init() will recreate the
288
305
  # non-validating component ColumnRef
289
306
  return {
290
307
  'tbl_id': str(tbl.id),
291
- 'tbl_version': tbl_version,
308
+ 'tbl_version': version,
292
309
  'col_id': self.col.id,
293
310
  'reference_tbl': self.reference_tbl.as_dict() if self.reference_tbl is not None else None,
294
311
  'perform_validation': self.perform_validation,
295
312
  }
296
313
 
314
@classmethod
def get_column_id(cls, d: dict) -> catalog.QColumnId:
    """Return the qualified id (table id, column id) of the column referenced by a serialized ColumnRef.

    Unlike get_column(), this does not look up a catalog.Column instance; it only reads the 'tbl_id' and
    'col_id' entries of `d`.
    """
    return catalog.QColumnId(UUID(d['tbl_id']), d['col_id'])
318
+
297
319
  @classmethod
298
320
  def get_column(cls, d: dict) -> catalog.Column:
299
321
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
@@ -81,7 +81,7 @@ class Comparison(Expr):
81
81
  if self.is_search_arg_comparison:
82
82
  # reference the index value column if there is an index and this is not a snapshot
83
83
  # (indices don't apply to snapshots)
84
- tbl = self._op1.col.tbl.get()
84
+ tbl = self._op1.col.tbl
85
85
  idx_info = [
86
86
  info for info in self._op1.col.get_idx_info().values() if isinstance(info.idx, index.BtreeIndex)
87
87
  ]
@@ -214,6 +214,7 @@ class DataRow:
214
214
  """Assign in-memory cell value
215
215
  This allows overwriting
216
216
  """
217
+ assert isinstance(idx, int)
217
218
  assert self.excs[idx] is None
218
219
 
219
220
  if (idx in self.img_slot_idxs or idx in self.media_slot_idxs) and isinstance(val, str):
@@ -253,14 +254,15 @@ class DataRow:
253
254
  assert self.excs[index] is None
254
255
  if self.file_paths[index] is None:
255
256
  if filepath is not None:
256
- # we want to save this to a file
257
- self.file_paths[index] = filepath
258
- self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
259
257
  image = self.vals[index]
260
258
  assert isinstance(image, PIL.Image.Image)
261
259
  # Default to JPEG unless the image has a transparency layer (which isn't supported by JPEG).
262
260
  # In that case, use WebP instead.
263
261
  format = 'webp' if image.has_transparency_data else 'jpeg'
262
+ if not filepath.endswith(f'.{format}'):
263
+ filepath += f'.{format}'
264
+ self.file_paths[index] = filepath
265
+ self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
264
266
  image.save(filepath, format=format)
265
267
  else:
266
268
  # we discard the content of this cell
pixeltable/exprs/expr.py CHANGED
@@ -276,6 +276,13 @@ class Expr(abc.ABC):
276
276
  tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
277
277
  return self._retarget(tbl_versions)
278
278
 
279
@classmethod
def retarget_list(cls, expr_list: list[Expr], tbl: catalog.TableVersionPath) -> None:
    """Retarget ColumnRefs in expr_list to the specific TableVersions in tbl.

    The list is modified in place: each element is replaced with the result of its _retarget() call.
    """
    targets = {handle.id: handle.get() for handle in tbl.get_tbl_versions()}
    for pos, expr in enumerate(expr_list):
        expr_list[pos] = expr._retarget(targets)
285
+
279
286
  def _retarget(self, tbl_versions: dict[UUID, catalog.TableVersion]) -> Self:
280
287
  for i in range(len(self.components)):
281
288
  self.components[i] = self.components[i]._retarget(tbl_versions)
@@ -387,17 +394,17 @@ class Expr(abc.ABC):
387
394
  return {tbl_id for e in exprs_ for tbl_id in e.tbl_ids()}
388
395
 
389
396
  @classmethod
390
- def get_refd_columns(cls, expr_dict: dict[str, Any]) -> list[catalog.Column]:
397
+ def get_refd_column_ids(cls, expr_dict: dict[str, Any]) -> set[catalog.QColumnId]:
391
398
  """Return Columns referenced by expr_dict."""
392
- result: list[catalog.Column] = []
399
+ result: set[catalog.QColumnId] = set()
393
400
  assert '_classname' in expr_dict
394
401
  from .column_ref import ColumnRef
395
402
 
396
403
  if expr_dict['_classname'] == 'ColumnRef':
397
- result.append(ColumnRef.get_column(expr_dict))
404
+ result.add(ColumnRef.get_column_id(expr_dict))
398
405
  if 'components' in expr_dict:
399
406
  for component_dict in expr_dict['components']:
400
- result.extend(cls.get_refd_columns(component_dict))
407
+ result.update(cls.get_refd_column_ids(component_dict))
401
408
  return result
402
409
 
403
410
  def as_literal(self) -> Optional[Expr]:
@@ -16,6 +16,8 @@ from .sql_element_cache import SqlElementCache
16
16
 
17
17
 
18
18
  class Literal(Expr):
19
+ val: Any
20
+
19
21
  def __init__(self, val: Any, col_type: Optional[ts.ColumnType] = None):
20
22
  if col_type is not None:
21
23
  val = col_type.create_literal(val)
@@ -172,13 +172,11 @@ class RowBuilder:
172
172
 
173
173
  def refs_unstored_iter_col(col_ref: ColumnRef) -> bool:
174
174
  tbl = col_ref.col.tbl
175
- return (
176
- tbl.get().is_component_view and tbl.get().is_iterator_column(col_ref.col) and not col_ref.col.is_stored
177
- )
175
+ return tbl.is_component_view and tbl.is_iterator_column(col_ref.col) and not col_ref.col.is_stored
178
176
 
179
177
  unstored_iter_col_refs = [col_ref for col_ref in col_refs if refs_unstored_iter_col(col_ref)]
180
178
  component_views = [col_ref.col.tbl for col_ref in unstored_iter_col_refs]
181
- unstored_iter_args = {view.id: view.get().iterator_args.copy() for view in component_views}
179
+ unstored_iter_args = {view.id: view.iterator_args.copy() for view in component_views}
182
180
  self.unstored_iter_args = {
183
181
  id: self._record_unique_expr(arg, recursive=True) for id, arg in unstored_iter_args.items()
184
182
  }
@@ -450,9 +448,9 @@ class RowBuilder:
450
448
  else:
451
449
  if col.col_type.is_image_type() and data_row.file_urls[slot_idx] is None:
452
450
  # we have yet to store this image
453
- filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.get().version))
451
+ filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.version))
454
452
  data_row.flush_img(slot_idx, filepath)
455
- val = data_row.get_stored_val(slot_idx, col.sa_col.type)
453
+ val = data_row.get_stored_val(slot_idx, col.get_sa_col_type())
456
454
  table_row[col.store_name()] = val
457
455
  # we unfortunately need to set these, even if there are no errors
458
456
  table_row[col.errortype_store_name()] = None
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
3
4
  from typing import Any, Optional, cast
4
5
  from uuid import UUID
5
6
 
@@ -12,6 +13,8 @@ from .expr import Expr
12
13
  from .row_builder import RowBuilder
13
14
  from .sql_element_cache import SqlElementCache
14
15
 
16
+ _logger = logging.getLogger('pixeltable')
17
+
15
18
 
16
19
  class RowidRef(Expr):
17
20
  """A reference to a part of a table rowid
@@ -97,10 +100,15 @@ class RowidRef(Expr):
97
100
 
98
101
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
99
102
  tbl = self.tbl.get() if self.tbl is not None else catalog.Catalog.get().get_tbl_version(self.tbl_id, None)
103
+ assert tbl.is_validated
100
104
  rowid_cols = tbl.store_tbl.rowid_columns()
101
105
  assert self.rowid_component_idx <= len(rowid_cols), (
102
106
  f'{self.rowid_component_idx} not consistent with {rowid_cols}'
103
107
  )
108
+ # _logger.debug(
109
+ # f'RowidRef.sql_expr: tbl={tbl.id}{tbl.effective_version} sa_tbl={id(tbl.store_tbl.sa_tbl):x} '
110
+ # f'tv={id(tbl):x}'
111
+ # )
104
112
  return rowid_cols[self.rowid_component_idx]
105
113
 
106
114
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -54,6 +54,7 @@ class SimilarityExpr(Expr):
54
54
  return 'similarity'
55
55
 
56
56
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
57
+ # TODO: validate that the index still exists
57
58
  if not isinstance(self.components[1], Literal):
58
59
  raise excs.Error('similarity(): requires a string or a PIL.Image.Image object, not an expression')
59
60
  item = self.components[1].val
@@ -5,6 +5,7 @@ from .callable_function import CallableFunction
5
5
  from .expr_template_function import ExprTemplateFunction
6
6
  from .function import Function, InvalidFunction
7
7
  from .function_registry import FunctionRegistry
8
+ from .mcp import mcp_udfs
8
9
  from .query_template_function import QueryTemplateFunction, query, retrieval_udf
9
10
  from .signature import Batch, Parameter, Signature
10
11
  from .tools import Tool, ToolChoice, Tools
pixeltable/func/mcp.py ADDED
@@ -0,0 +1,74 @@
1
+ import asyncio
2
+ import inspect
3
+ from typing import TYPE_CHECKING, Any, Optional
4
+
5
+ import pixeltable as pxt
6
+ from pixeltable import exceptions as excs, type_system as ts
7
+ from pixeltable.func.signature import Parameter
8
+
9
+ if TYPE_CHECKING:
10
+ import mcp
11
+
12
+
13
def mcp_udfs(url: str) -> list['pxt.func.Function']:
    """Synchronously create one UDF per tool offered by the MCP server at `url`.

    This wraps mcp_udfs_async(). asyncio.run() cannot be called from a thread that already has a running
    event loop (e.g. inside a Jupyter notebook). In that case the coroutine is run to completion in its own
    event loop on a short-lived worker thread, and this call blocks until it finishes.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # no event loop is running in this thread, so asyncio.run() is safe to use directly
        return asyncio.run(mcp_udfs_async(url))

    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, mcp_udfs_async(url)).result()
15
+
16
+
17
async def mcp_udfs_async(url: str) -> list['pxt.func.Function']:
    """Connect to the MCP server at `url`, list its tools, and return one UDF per tool.

    The session used to list the tools is closed before the UDFs are constructed; each UDF opens its own
    session when invoked (see mcp_tool_to_udf()).
    """
    import mcp
    from mcp.client.streamable_http import streamablehttp_client

    tools_listing: Optional[mcp.types.ListToolsResult] = None
    async with streamablehttp_client(url) as (read_stream, write_stream, _):
        async with mcp.ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools_listing = await session.list_tools()
    assert tools_listing is not None

    udfs = []
    for tool in tools_listing.tools:
        udfs.append(mcp_tool_to_udf(url, tool))
    return udfs
31
+
32
+
33
def mcp_tool_to_udf(url: str, mcp_tool: 'mcp.types.Tool') -> 'pxt.func.Function':
    """Create a Pixeltable UDF that calls the tool `mcp_tool` on the MCP server at `url`.

    The UDF's parameters are derived from the tool's JSON Schema `inputSchema`. All parameters are
    keyword-only, and those not listed in `required` are made nullable. The UDF always returns a string.

    Raises:
        excs.Error: if a parameter's JSON Schema cannot be mapped to a Pixeltable type
    """
    import mcp
    from mcp.client.streamable_http import streamablehttp_client

    async def invoke(**kwargs: Any) -> str:
        # TODO: Cache session objects rather than creating a new one each time?
        async with (
            streamablehttp_client(url) as (read_stream, write_stream, _),
            mcp.ClientSession(read_stream, write_stream) as session,
        ):
            await session.initialize()
            res = await session.call_tool(name=mcp_tool.name, arguments=kwargs)
            # TODO Handle image/audio responses?
            # NOTE(review): assumes the first content item of the result is a text item
            return res.content[0].text  # type: ignore[union-attr]

    if mcp_tool.description is not None:
        invoke.__doc__ = mcp_tool.description

    input_schema = mcp_tool.inputSchema
    # A tool that takes no arguments may omit 'properties' from its object-typed input schema
    properties: dict[str, Any] = input_schema.get('properties') or {}
    params = {name: __mcp_param_to_pxt_type(mcp_tool.name, name, param) for name, param in properties.items()}
    required = input_schema.get('required', [])

    # Ensure that any params not appearing in `required` are nullable.
    # (A required param might or might not be nullable, since its type might be an 'anyOf' containing a null.)
    for name in params.keys() - required:
        params[name] = params[name].copy(nullable=True)

    signature = pxt.func.Signature(
        return_type=ts.StringType(),  # Return type is always string
        parameters=[Parameter(name, col_type, inspect.Parameter.KEYWORD_ONLY) for name, col_type in params.items()],
    )

    return pxt.func.CallableFunction(signatures=[signature], py_fns=[invoke], self_name=mcp_tool.name)
68
+
69
+
70
def __mcp_param_to_pxt_type(tool_name: str, name: str, param: dict[str, Any]) -> ts.ColumnType:
    """Map the JSON Schema of MCP parameter `name` (of tool `tool_name`) to a Pixeltable column type.

    Raises:
        excs.Error: if ts.ColumnType.from_json_schema() cannot interpret the schema
    """
    col_type = ts.ColumnType.from_json_schema(param)
    if col_type is not None:
        return col_type
    raise excs.Error(f'Unknown type schema for MCP parameter {name!r} of tool {tool_name!r}: {param}')