pixeltable 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +1 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +9 -1
- pixeltable/catalog/catalog.py +559 -134
- pixeltable/catalog/column.py +36 -32
- pixeltable/catalog/dir.py +1 -2
- pixeltable/catalog/globals.py +12 -0
- pixeltable/catalog/insertable_table.py +30 -25
- pixeltable/catalog/schema_object.py +9 -6
- pixeltable/catalog/table.py +334 -267
- pixeltable/catalog/table_version.py +360 -241
- pixeltable/catalog/table_version_handle.py +18 -2
- pixeltable/catalog/table_version_path.py +86 -23
- pixeltable/catalog/view.py +47 -23
- pixeltable/dataframe.py +198 -19
- pixeltable/env.py +6 -4
- pixeltable/exceptions.py +6 -0
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/exec_node.py +2 -0
- pixeltable/exec/expr_eval/evaluators.py +4 -1
- pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
- pixeltable/exec/in_memory_data_node.py +1 -1
- pixeltable/exec/sql_node.py +188 -22
- pixeltable/exprs/column_property_ref.py +16 -6
- pixeltable/exprs/column_ref.py +33 -11
- pixeltable/exprs/comparison.py +1 -1
- pixeltable/exprs/data_row.py +5 -3
- pixeltable/exprs/expr.py +11 -4
- pixeltable/exprs/literal.py +2 -0
- pixeltable/exprs/row_builder.py +4 -6
- pixeltable/exprs/rowid_ref.py +8 -0
- pixeltable/exprs/similarity_expr.py +1 -0
- pixeltable/func/__init__.py +1 -0
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +5 -3
- pixeltable/func/tools.py +12 -2
- pixeltable/func/udf.py +2 -2
- pixeltable/functions/__init__.py +1 -0
- pixeltable/functions/anthropic.py +19 -45
- pixeltable/functions/deepseek.py +19 -38
- pixeltable/functions/fireworks.py +9 -18
- pixeltable/functions/gemini.py +165 -33
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/llama_cpp.py +6 -6
- pixeltable/functions/math.py +63 -0
- pixeltable/functions/mistralai.py +16 -53
- pixeltable/functions/ollama.py +1 -1
- pixeltable/functions/openai.py +82 -165
- pixeltable/functions/string.py +212 -58
- pixeltable/functions/together.py +22 -80
- pixeltable/globals.py +10 -4
- pixeltable/index/base.py +5 -0
- pixeltable/index/btree.py +5 -0
- pixeltable/index/embedding_index.py +5 -0
- pixeltable/io/external_store.py +10 -31
- pixeltable/io/label_studio.py +5 -5
- pixeltable/io/parquet.py +4 -4
- pixeltable/io/table_data_conduit.py +1 -32
- pixeltable/metadata/__init__.py +11 -2
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_30.py +6 -11
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/util.py +3 -9
- pixeltable/metadata/notes.py +3 -0
- pixeltable/metadata/schema.py +13 -1
- pixeltable/plan.py +135 -12
- pixeltable/share/packager.py +321 -20
- pixeltable/share/publish.py +2 -2
- pixeltable/store.py +31 -13
- pixeltable/type_system.py +30 -0
- pixeltable/utils/dbms.py +1 -1
- pixeltable/utils/formatter.py +64 -42
- {pixeltable-0.3.14.dist-info → pixeltable-0.4.0.dist-info}/METADATA +2 -1
- {pixeltable-0.3.14.dist-info → pixeltable-0.4.0.dist-info}/RECORD +79 -74
- {pixeltable-0.3.14.dist-info → pixeltable-0.4.0.dist-info}/LICENSE +0 -0
- {pixeltable-0.3.14.dist-info → pixeltable-0.4.0.dist-info}/WHEEL +0 -0
- {pixeltable-0.3.14.dist-info → pixeltable-0.4.0.dist-info}/entry_points.txt +0 -0
pixeltable/exec/sql_node.py
CHANGED
|
@@ -14,6 +14,7 @@ from .exec_node import ExecNode
|
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
16
|
import pixeltable.plan
|
|
17
|
+
from pixeltable.plan import SampleClause
|
|
17
18
|
|
|
18
19
|
_logger = logging.getLogger('pixeltable')
|
|
19
20
|
|
|
@@ -64,8 +65,12 @@ def print_order_by_clause(clause: OrderByClause) -> str:
|
|
|
64
65
|
|
|
65
66
|
class SqlNode(ExecNode):
|
|
66
67
|
"""
|
|
67
|
-
Materializes data from the store via
|
|
68
|
+
Materializes data from the store via an SQL statement.
|
|
68
69
|
This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
|
|
70
|
+
The pk columns are not included in the select list.
|
|
71
|
+
If set_pk is True, they are added to the end of the result set when creating the SQL statement
|
|
72
|
+
so they can always be referenced as cols[-num_pk_cols:] in the result set.
|
|
73
|
+
The pk_columns consist of the rowid columns of the target table followed by the version number.
|
|
69
74
|
"""
|
|
70
75
|
|
|
71
76
|
tbl: Optional[catalog.TableVersionPath]
|
|
@@ -122,6 +127,7 @@ class SqlNode(ExecNode):
|
|
|
122
127
|
# we also need to retrieve the pk columns
|
|
123
128
|
assert tbl is not None
|
|
124
129
|
self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
|
|
130
|
+
assert self.num_pk_cols > 1
|
|
125
131
|
|
|
126
132
|
# additional state
|
|
127
133
|
self.result_cursor = None
|
|
@@ -134,14 +140,25 @@ class SqlNode(ExecNode):
|
|
|
134
140
|
self.where_clause_element = None
|
|
135
141
|
self.order_by_clause = []
|
|
136
142
|
|
|
143
|
+
if self.tbl is not None:
|
|
144
|
+
tv = self.tbl.tbl_version._tbl_version
|
|
145
|
+
if tv is not None:
|
|
146
|
+
assert tv.is_validated
|
|
147
|
+
|
|
148
|
+
def _create_pk_cols(self) -> list[sql.Column]:
    """Return the pk columns of the store table if set_pk is enabled, otherwise an empty list."""
    if not self.set_pk:
        # the pk is not part of the result set
        return []
    assert self.tbl is not None
    tbl_version = self.tbl.tbl_version.get()
    assert tbl_version.is_validated
    return tbl_version.store_tbl.pk_columns()
|
|
156
|
+
|
|
137
157
|
def _create_stmt(self) -> sql.Select:
|
|
138
158
|
"""Create Select from local state"""
|
|
139
159
|
|
|
140
160
|
assert self.sql_elements.contains_all(self.select_list)
|
|
141
|
-
sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
|
|
142
|
-
if self.set_pk:
|
|
143
|
-
assert self.tbl is not None
|
|
144
|
-
sql_select_list += self.tbl.tbl_version.get().store_tbl.pk_columns()
|
|
161
|
+
sql_select_list = [self.sql_elements.get(e) for e in self.select_list] + self._create_pk_cols()
|
|
145
162
|
stmt = sql.select(*sql_select_list)
|
|
146
163
|
|
|
147
164
|
where_clause_element = (
|
|
@@ -167,9 +184,10 @@ class SqlNode(ExecNode):
|
|
|
167
184
|
def _ordering_tbl_ids(self) -> set[UUID]:
    """Return the ids of the tables referenced by the expressions of the order-by clause."""
    # order_by_clause holds (expr, <ordering flag>) pairs; only the expressions matter here
    ordering_exprs = (expr for expr, _ in self.order_by_clause)
    return exprs.Expr.all_tbl_ids(ordering_exprs)
|
|
169
186
|
|
|
170
|
-
def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
|
|
187
|
+
def to_cte(self, keep_pk: bool = False) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
|
|
171
188
|
"""
|
|
172
|
-
|
|
189
|
+
Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
|
|
190
|
+
keep_pk: if True, the PK columns are included in the CTE Select statement
|
|
173
191
|
|
|
174
192
|
Returns:
|
|
175
193
|
(CTE, dict from Expr to output column)
|
|
@@ -177,11 +195,13 @@ class SqlNode(ExecNode):
|
|
|
177
195
|
if self.py_filter is not None:
|
|
178
196
|
# the filter needs to run in Python
|
|
179
197
|
return None
|
|
180
|
-
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
181
198
|
if self.cte is None:
|
|
199
|
+
if not keep_pk:
|
|
200
|
+
self.set_pk = False # we don't need the PK if we use this SqlNode as a CTE
|
|
182
201
|
self.cte = self._create_stmt().cte()
|
|
183
|
-
|
|
184
|
-
|
|
202
|
+
pk_count = self.num_pk_cols if self.set_pk else 0
|
|
203
|
+
assert len(self.select_list) + pk_count == len(self.cte.c)
|
|
204
|
+
return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c)) # skip pk cols
|
|
185
205
|
|
|
186
206
|
@classmethod
|
|
187
207
|
def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
|
|
@@ -220,26 +240,29 @@ class SqlNode(ExecNode):
|
|
|
220
240
|
joined_tbls.append(t)
|
|
221
241
|
|
|
222
242
|
first = True
|
|
223
|
-
|
|
243
|
+
prev_tv: Optional[catalog.TableVersion] = None
|
|
224
244
|
for t in joined_tbls[::-1]:
|
|
245
|
+
tv = t.get()
|
|
246
|
+
# _logger.debug(f'create_from_clause: tbl_id={tv.id} {id(tv.store_tbl.sa_tbl)}')
|
|
225
247
|
if first:
|
|
226
|
-
stmt = stmt.select_from(
|
|
248
|
+
stmt = stmt.select_from(tv.store_tbl.sa_tbl)
|
|
227
249
|
first = False
|
|
228
250
|
else:
|
|
229
|
-
# join
|
|
230
|
-
prev_tbl_rowid_cols =
|
|
231
|
-
tbl_rowid_cols =
|
|
251
|
+
# join tv to prev_tv on prev_tv's rowid cols
|
|
252
|
+
prev_tbl_rowid_cols = prev_tv.store_tbl.rowid_columns()
|
|
253
|
+
tbl_rowid_cols = tv.store_tbl.rowid_columns()
|
|
232
254
|
rowid_clauses = [
|
|
233
255
|
c1 == c2 for c1, c2 in zip(prev_tbl_rowid_cols, tbl_rowid_cols[: len(prev_tbl_rowid_cols)])
|
|
234
256
|
]
|
|
235
|
-
stmt = stmt.join(
|
|
257
|
+
stmt = stmt.join(tv.store_tbl.sa_tbl, sql.and_(*rowid_clauses))
|
|
258
|
+
|
|
236
259
|
if t.id in exact_version_only:
|
|
237
|
-
stmt = stmt.where(
|
|
260
|
+
stmt = stmt.where(tv.store_tbl.v_min_col == tv.version)
|
|
238
261
|
else:
|
|
239
|
-
stmt = stmt.where(
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
262
|
+
stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_min <= tv.version)
|
|
263
|
+
stmt = stmt.where(tv.store_tbl.sa_tbl.c.v_max > tv.version)
|
|
264
|
+
prev_tv = tv
|
|
265
|
+
|
|
243
266
|
return stmt
|
|
244
267
|
|
|
245
268
|
def set_where(self, where_clause: exprs.Expr) -> None:
|
|
@@ -284,7 +307,8 @@ class SqlNode(ExecNode):
|
|
|
284
307
|
stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
|
|
285
308
|
_logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
|
|
286
309
|
except Exception:
|
|
287
|
-
|
|
310
|
+
# log something if we can't log the compiled stmt
|
|
311
|
+
_logger.debug(f'SqlLookupNode proto-stmt:\n{stmt}')
|
|
288
312
|
self._log_explain(stmt)
|
|
289
313
|
|
|
290
314
|
conn = Env.get().conn
|
|
@@ -501,3 +525,145 @@ class SqlJoinNode(SqlNode):
|
|
|
501
525
|
full=join_clause == plan.JoinType.FULL_OUTER,
|
|
502
526
|
)
|
|
503
527
|
return stmt
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
class SqlSampleNode(SqlNode):
    """
    Returns rows sampled from the input node.

    The input node is materialized as a CTE that keeps its pk columns (the rowid columns followed by the
    version column) as its trailing columns. Rows are ordered by the MD5 hash of the seed and the rowid
    column values, and the sample is taken from that ordering.
    """

    input_cte: Optional[sql.CTE]
    pk_count: int
    stratify_exprs: Optional[list[exprs.Expr]]
    sample_clause: 'SampleClause'

    def __init__(
        self,
        row_builder: exprs.RowBuilder,
        input: SqlNode,
        select_list: Iterable[exprs.Expr],
        sample_clause: 'SampleClause',
        stratify_exprs: list[exprs.Expr],
    ):
        """
        Args:
            input: SqlNode to sample from
            select_list: can contain calls to AggregateFunctions
            sample_clause: specifies the sampling method
            stratify_exprs: Analyzer processed list of expressions to stratify by.
        """
        assert isinstance(input, SqlNode)
        # to_cte() returns None when the input cannot be expressed as a CTE (eg, it has a Python filter);
        # fail with a clear message instead of an unpacking error
        cte_info = input.to_cte(keep_pk=True)
        assert cte_info is not None, 'SqlSampleNode requires an input node that can be turned into a CTE'
        self.input_cte, input_col_map = cte_info
        self.pk_count = input.num_pk_cols
        assert self.pk_count > 1
        sql_elements = exprs.SqlElementCache(input_col_map)
        assert sql_elements.contains_all(stratify_exprs)
        super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
        self.stratify_exprs = stratify_exprs
        self.sample_clause = sample_clause
        assert isinstance(self.sample_clause.seed, int)

    @classmethod
    def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
        """Construct expression which is the ordering key for rows to be sampled
        General SQL form is:
        - MD5(<seed::text> [ + '___' + <rowid_col_val>::text]+
        """
        sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
        for e in sql_cols:
            # Quotes are required below to guarantee that the string is properly presented in SQL
            sql_expr = sql_expr + sql.literal_column("'___'", sql.Text) + sql.cast(e, sql.Text)
        sql_expr = sql.func.md5(sql_expr)
        return sql_expr

    def _create_key_sql(self, cte: sql.CTE) -> sql.ColumnElement:
        """Create an expression for randomly ordering rows with a given seed

        `cte` is either the input CTE or a CTE derived from it that starts with the same columns
        (eg, row_rank_cte, which appends a 'rank' column).
        """
        # Locate the rowid columns relative to the width of the input CTE rather than from the end of `cte`:
        # a derived CTE has extra trailing columns, which would shift a slice taken from the end.
        num_input_cols = len(self.input_cte.c)
        all_cols = [*cte.c]
        assert len(all_cols) >= num_input_cols
        rowid_cols = all_cols[num_input_cols - self.pk_count : num_input_cols - 1]  # exclude the version column
        assert len(rowid_cols) > 0
        return self.key_sql_expr(sql.literal_column(str(self.sample_clause.seed)), rowid_cols)

    def _create_stmt(self) -> sql.Select:
        """Create the Select stmt for the sampling method requested by the sample clause"""
        # imported locally: the module-level import of SampleClause is only visible to type checkers
        from pixeltable.plan import SampleClause

        is_stratified = len(self.stratify_exprs) > 0

        if self.sample_clause.fraction is not None:
            if is_stratified:
                return self._create_stmt_stratified_fraction(self.sample_clause.fraction)

            # Non-stratified: keep the rows whose key is below the fraction threshold, ordered by that key
            key = self._create_key_sql(self.input_cte)
            fraction_sql = sql.cast(SampleClause.fraction_to_md5_hex(float(self.sample_clause.fraction)), sql.Text)
            return sql.select(*self.input_cte.c).where(key < fraction_sql).order_by(key)

        if is_stratified:
            return self._create_stmt_stratified_n(self.sample_clause.n, self.sample_clause.n_per_stratum)

        # No stratification, just return n samples from the input CTE
        key = self._create_key_sql(self.input_cte)
        return sql.select(*self.input_cte.c).order_by(key).limit(self.sample_clause.n)

    def _create_row_rank_cte(self, sql_strata_exprs: list[sql.ColumnElement]) -> sql.CTE:
        """Create a CTE with all columns of the input CTE plus a trailing 'rank' column

        'rank' is the row number within the row's stratum, in the order of the sampling key.
        """
        order_by = self._create_key_sql(self.input_cte)
        rank_col = sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
        return sql.select(*self.input_cte.c, rank_col).select_from(self.input_cte).cte('row_rank_cte')

    def _create_stmt_stratified_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
        """Create a Select stmt that returns n samples across all strata or n_per_stratum samples per stratum"""
        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        row_rank_cte = self._create_row_rank_cte(sql_strata_exprs)

        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        if n_per_stratum is not None:
            return sql.select(*final_columns).where(row_rank_cte.c.rank <= n_per_stratum)

        # order by rank first, then by the sampling key, and take the first n rows
        secondary_order = self._create_key_sql(row_rank_cte)
        return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)

    def _create_stmt_stratified_fraction(self, fraction_samples: float) -> sql.Select:
        """Create a Select stmt that returns a fraction of the rows per strata"""

        # Build the strata count CTE
        # Produces a table of the form:
        #   (*stratify_exprs, s_s_size)
        # where s_s_size is the number of samples to take from each stratum
        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
        per_strata_count_cte = (
            sql.select(
                *sql_strata_exprs,
                sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
            )
            .select_from(self.input_cte)
            .group_by(*sql_strata_exprs)
            .cte('per_strata_count_cte')
        )

        # Build a CTE that ranks the rows within each stratum
        row_rank_cte = self._create_row_rank_cte(sql_strata_exprs)

        # Build the join criterion dynamically to accommodate any number of stratify_by expressions
        join_c = sql.true()
        for col in per_strata_count_cte.c[:-1]:
            join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)

        # Join with per_strata_count_cte to limit returns to the requested fraction of rows
        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
        stmt = (
            sql.select(*final_columns)
            .select_from(row_rank_cte)
            .join(per_strata_count_cte, join_c)
            .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
        )

        return stmt
|
|
pixeltable/exprs/column_property_ref.py
CHANGED
|
@@ -58,20 +58,30 @@ class ColumnPropertyRef(Expr):
|
|
|
58
58
|
if not self._col_ref.col.is_stored:
|
|
59
59
|
return None
|
|
60
60
|
|
|
61
|
+
# we need to reestablish that we have the correct Column instance, there could have been a metadata
|
|
62
|
+
# reload since init()
|
|
63
|
+
# TODO: add an explicit prepare phase (ie, Expr.prepare()) that gives every subclass instance a chance to
|
|
64
|
+
# perform runtime checks and update state
|
|
65
|
+
tv = self._col_ref.tbl_version.get()
|
|
66
|
+
assert tv.is_validated
|
|
67
|
+
# we can assume at this point during query execution that the column exists
|
|
68
|
+
assert self._col_ref.col_id in tv.cols_by_id
|
|
69
|
+
col = tv.cols_by_id[self._col_ref.col_id]
|
|
70
|
+
|
|
61
71
|
# the errortype/-msg properties of a read-validated media column need to be extracted from the DataRow
|
|
62
72
|
if (
|
|
63
|
-
|
|
64
|
-
and
|
|
73
|
+
col.col_type.is_media_type()
|
|
74
|
+
and col.media_validation == catalog.MediaValidation.ON_READ
|
|
65
75
|
and self.is_error_prop()
|
|
66
76
|
):
|
|
67
77
|
return None
|
|
68
78
|
|
|
69
79
|
if self.prop == self.Property.ERRORTYPE:
|
|
70
|
-
assert
|
|
71
|
-
return
|
|
80
|
+
assert col.sa_errortype_col is not None
|
|
81
|
+
return col.sa_errortype_col
|
|
72
82
|
if self.prop == self.Property.ERRORMSG:
|
|
73
|
-
assert
|
|
74
|
-
return
|
|
83
|
+
assert col.sa_errormsg_col is not None
|
|
84
|
+
return col.sa_errormsg_col
|
|
75
85
|
if self.prop == self.Property.FILEURL:
|
|
76
86
|
# the file url is stored as the column value
|
|
77
87
|
return sql_elements.get(self._col_ref)
|
pixeltable/exprs/column_ref.py
CHANGED
|
@@ -52,6 +52,10 @@ class ColumnRef(Expr):
|
|
|
52
52
|
id: int
|
|
53
53
|
perform_validation: bool # if True, performs media validation
|
|
54
54
|
|
|
55
|
+
# needed by sql_expr() to re-resolve Column instance after a metadata reload
|
|
56
|
+
tbl_version: catalog.TableVersionHandle
|
|
57
|
+
col_id: int
|
|
58
|
+
|
|
55
59
|
def __init__(
|
|
56
60
|
self,
|
|
57
61
|
col: catalog.Column,
|
|
@@ -62,16 +66,17 @@ class ColumnRef(Expr):
|
|
|
62
66
|
assert col.tbl is not None
|
|
63
67
|
self.col = col
|
|
64
68
|
self.reference_tbl = reference_tbl
|
|
65
|
-
self.
|
|
66
|
-
|
|
67
|
-
|
|
69
|
+
self.tbl_version = catalog.TableVersionHandle(col.tbl.id, col.tbl.effective_version)
|
|
70
|
+
self.col_id = col.id
|
|
71
|
+
|
|
72
|
+
self.is_unstored_iter_col = col.tbl.is_component_view and col.tbl.is_iterator_column(col) and not col.is_stored
|
|
68
73
|
self.iter_arg_ctx = None
|
|
69
74
|
# number of rowid columns in the base table
|
|
70
|
-
self.base_rowid_len = col.tbl.
|
|
75
|
+
self.base_rowid_len = col.tbl.base.get().num_rowid_columns() if self.is_unstored_iter_col else 0
|
|
71
76
|
self.base_rowid = [None] * self.base_rowid_len
|
|
72
77
|
self.iterator = None
|
|
73
78
|
# index of the position column in the view's primary key; don't try to reference tbl.store_tbl here
|
|
74
|
-
self.pos_idx = col.tbl.
|
|
79
|
+
self.pos_idx = col.tbl.num_rowid_columns() - 1 if self.is_unstored_iter_col else None
|
|
75
80
|
|
|
76
81
|
self.perform_validation = False
|
|
77
82
|
if col.col_type.is_media_type():
|
|
@@ -175,7 +180,7 @@ class ColumnRef(Expr):
|
|
|
175
180
|
assert len(idx_info) == 1
|
|
176
181
|
col = copy.copy(next(iter(idx_info.values())).val_col)
|
|
177
182
|
col.name = f'{self.col.name}_embedding_{idx if idx is not None else ""}'
|
|
178
|
-
col.create_sa_cols()
|
|
183
|
+
# col.create_sa_cols()
|
|
179
184
|
return ColumnRef(col)
|
|
180
185
|
|
|
181
186
|
def default_column_name(self) -> Optional[str]:
|
|
@@ -226,7 +231,7 @@ class ColumnRef(Expr):
|
|
|
226
231
|
def _descriptors(self) -> DescriptionHelper:
|
|
227
232
|
tbl = catalog.Catalog.get().get_table_by_id(self.col.tbl.id)
|
|
228
233
|
helper = DescriptionHelper()
|
|
229
|
-
helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path!r})')
|
|
234
|
+
helper.append(f'Column\n{self.col.name!r}\n(of table {tbl._path()!r})')
|
|
230
235
|
helper.append(tbl._col_descriptor([self.col.name]))
|
|
231
236
|
idxs = tbl._index_descriptor([self.col.name])
|
|
232
237
|
if len(idxs) > 0:
|
|
@@ -234,7 +239,19 @@ class ColumnRef(Expr):
|
|
|
234
239
|
return helper
|
|
235
240
|
|
|
236
241
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
    """Return the SQLAlchemy column for the referenced column, or None for a validating ColumnRef."""
    if self.perform_validation:
        return None
    # There could have been a metadata reload since init(), so the Column instance is looked up again
    # by id in the current TableVersion.
    # TODO: add an explicit prepare phase (ie, Expr.prepare()) that gives every subclass instance a chance to
    # perform runtime checks and update state
    live_tv = self.tbl_version.get()
    assert live_tv.is_validated
    live_cols = live_tv.cols_by_id
    # at this point during query execution the column is assumed to exist
    assert self.col_id in live_cols
    resolved = live_cols[self.col_id]
    self.col = resolved
    assert resolved.tbl is live_tv
    return resolved.sa_col
|
|
238
255
|
|
|
239
256
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
240
257
|
if self.perform_validation:
|
|
@@ -275,7 +292,7 @@ class ColumnRef(Expr):
|
|
|
275
292
|
if self.base_rowid != data_row.pk[: self.base_rowid_len]:
|
|
276
293
|
row_builder.eval(data_row, self.iter_arg_ctx)
|
|
277
294
|
iterator_args = data_row[self.iter_arg_ctx.target_slot_idxs[0]]
|
|
278
|
-
self.iterator = self.col.tbl.
|
|
295
|
+
self.iterator = self.col.tbl.iterator_cls(**iterator_args)
|
|
279
296
|
self.base_rowid = data_row.pk[: self.base_rowid_len]
|
|
280
297
|
self.iterator.set_pos(data_row.pk[self.pos_idx])
|
|
281
298
|
res = next(self.iterator)
|
|
@@ -283,17 +300,22 @@ class ColumnRef(Expr):
|
|
|
283
300
|
|
|
284
301
|
def _as_dict(self) -> dict:
    """Return the dict form of this ColumnRef; the table version is recorded only for snapshots."""
    owner = self.col.tbl
    if self.reference_tbl is None:
        reference_tbl_dict = None
    else:
        reference_tbl_dict = self.reference_tbl.as_dict()
    # we omit self.components, even if this is a validating ColumnRef, because init() will recreate the
    # non-validating component ColumnRef
    return {
        'tbl_id': str(owner.id),
        'tbl_version': owner.version if owner.is_snapshot else None,
        'col_id': self.col.id,
        'reference_tbl': reference_tbl_dict,
        'perform_validation': self.perform_validation,
    }
|
|
296
313
|
|
|
314
|
+
@classmethod
def get_column_id(cls, d: dict) -> catalog.QColumnId:
    """Build the qualified column id from the 'tbl_id' and 'col_id' entries of a ColumnRef dict."""
    return catalog.QColumnId(UUID(d['tbl_id']), d['col_id'])
|
|
318
|
+
|
|
297
319
|
@classmethod
|
|
298
320
|
def get_column(cls, d: dict) -> catalog.Column:
|
|
299
321
|
tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -81,7 +81,7 @@ class Comparison(Expr):
|
|
|
81
81
|
if self.is_search_arg_comparison:
|
|
82
82
|
# reference the index value column if there is an index and this is not a snapshot
|
|
83
83
|
# (indices don't apply to snapshots)
|
|
84
|
-
tbl = self._op1.col.tbl
|
|
84
|
+
tbl = self._op1.col.tbl
|
|
85
85
|
idx_info = [
|
|
86
86
|
info for info in self._op1.col.get_idx_info().values() if isinstance(info.idx, index.BtreeIndex)
|
|
87
87
|
]
|
pixeltable/exprs/data_row.py
CHANGED
|
@@ -214,6 +214,7 @@ class DataRow:
|
|
|
214
214
|
"""Assign in-memory cell value
|
|
215
215
|
This allows overwriting
|
|
216
216
|
"""
|
|
217
|
+
assert isinstance(idx, int)
|
|
217
218
|
assert self.excs[idx] is None
|
|
218
219
|
|
|
219
220
|
if (idx in self.img_slot_idxs or idx in self.media_slot_idxs) and isinstance(val, str):
|
|
@@ -253,14 +254,15 @@ class DataRow:
|
|
|
253
254
|
assert self.excs[index] is None
|
|
254
255
|
if self.file_paths[index] is None:
|
|
255
256
|
if filepath is not None:
|
|
256
|
-
# we want to save this to a file
|
|
257
|
-
self.file_paths[index] = filepath
|
|
258
|
-
self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
|
|
259
257
|
image = self.vals[index]
|
|
260
258
|
assert isinstance(image, PIL.Image.Image)
|
|
261
259
|
# Default to JPEG unless the image has a transparency layer (which isn't supported by JPEG).
|
|
262
260
|
# In that case, use WebP instead.
|
|
263
261
|
format = 'webp' if image.has_transparency_data else 'jpeg'
|
|
262
|
+
if not filepath.endswith(f'.{format}'):
|
|
263
|
+
filepath += f'.{format}'
|
|
264
|
+
self.file_paths[index] = filepath
|
|
265
|
+
self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
|
|
264
266
|
image.save(filepath, format=format)
|
|
265
267
|
else:
|
|
266
268
|
# we discard the content of this cell
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -276,6 +276,13 @@ class Expr(abc.ABC):
|
|
|
276
276
|
tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
|
|
277
277
|
return self._retarget(tbl_versions)
|
|
278
278
|
|
|
279
|
+
@classmethod
def retarget_list(cls, expr_list: list[Expr], tbl: catalog.TableVersionPath) -> None:
    """Retarget, in place, the ColumnRefs of every expr in expr_list to the specific TableVersions in tbl."""
    versions_by_id = {handle.id: handle.get() for handle in tbl.get_tbl_versions()}
    for pos, expr in enumerate(expr_list):
        # _retarget() returns the retargeted expr, which replaces the list entry
        expr_list[pos] = expr._retarget(versions_by_id)
|
|
285
|
+
|
|
279
286
|
def _retarget(self, tbl_versions: dict[UUID, catalog.TableVersion]) -> Self:
|
|
280
287
|
for i in range(len(self.components)):
|
|
281
288
|
self.components[i] = self.components[i]._retarget(tbl_versions)
|
|
@@ -387,17 +394,17 @@ class Expr(abc.ABC):
|
|
|
387
394
|
return {tbl_id for e in exprs_ for tbl_id in e.tbl_ids()}
|
|
388
395
|
|
|
389
396
|
@classmethod
def get_refd_column_ids(cls, expr_dict: dict[str, Any]) -> set[catalog.QColumnId]:
    """Return the ids of the columns referenced by expr_dict, including those of its nested components."""
    assert '_classname' in expr_dict
    from .column_ref import ColumnRef

    col_ids: set[catalog.QColumnId] = set()
    if expr_dict['_classname'] == 'ColumnRef':
        col_ids.add(ColumnRef.get_column_id(expr_dict))
    # recurse into the component dicts, if there are any
    for component in expr_dict.get('components', ()):
        col_ids |= cls.get_refd_column_ids(component)
    return col_ids
|
|
402
409
|
|
|
403
410
|
def as_literal(self) -> Optional[Expr]:
|
pixeltable/exprs/literal.py
CHANGED
pixeltable/exprs/row_builder.py
CHANGED
|
@@ -172,13 +172,11 @@ class RowBuilder:
|
|
|
172
172
|
|
|
173
173
|
def refs_unstored_iter_col(col_ref: ColumnRef) -> bool:
|
|
174
174
|
tbl = col_ref.col.tbl
|
|
175
|
-
return (
|
|
176
|
-
tbl.get().is_component_view and tbl.get().is_iterator_column(col_ref.col) and not col_ref.col.is_stored
|
|
177
|
-
)
|
|
175
|
+
return tbl.is_component_view and tbl.is_iterator_column(col_ref.col) and not col_ref.col.is_stored
|
|
178
176
|
|
|
179
177
|
unstored_iter_col_refs = [col_ref for col_ref in col_refs if refs_unstored_iter_col(col_ref)]
|
|
180
178
|
component_views = [col_ref.col.tbl for col_ref in unstored_iter_col_refs]
|
|
181
|
-
unstored_iter_args = {view.id: view.
|
|
179
|
+
unstored_iter_args = {view.id: view.iterator_args.copy() for view in component_views}
|
|
182
180
|
self.unstored_iter_args = {
|
|
183
181
|
id: self._record_unique_expr(arg, recursive=True) for id, arg in unstored_iter_args.items()
|
|
184
182
|
}
|
|
@@ -450,9 +448,9 @@ class RowBuilder:
|
|
|
450
448
|
else:
|
|
451
449
|
if col.col_type.is_image_type() and data_row.file_urls[slot_idx] is None:
|
|
452
450
|
# we have yet to store this image
|
|
453
|
-
filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.
|
|
451
|
+
filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.version))
|
|
454
452
|
data_row.flush_img(slot_idx, filepath)
|
|
455
|
-
val = data_row.get_stored_val(slot_idx, col.
|
|
453
|
+
val = data_row.get_stored_val(slot_idx, col.get_sa_col_type())
|
|
456
454
|
table_row[col.store_name()] = val
|
|
457
455
|
# we unfortunately need to set these, even if there are no errors
|
|
458
456
|
table_row[col.errortype_store_name()] = None
|
pixeltable/exprs/rowid_ref.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
from typing import Any, Optional, cast
|
|
4
5
|
from uuid import UUID
|
|
5
6
|
|
|
@@ -12,6 +13,8 @@ from .expr import Expr
|
|
|
12
13
|
from .row_builder import RowBuilder
|
|
13
14
|
from .sql_element_cache import SqlElementCache
|
|
14
15
|
|
|
16
|
+
_logger = logging.getLogger('pixeltable')
|
|
17
|
+
|
|
15
18
|
|
|
16
19
|
class RowidRef(Expr):
|
|
17
20
|
"""A reference to a part of a table rowid
|
|
@@ -97,10 +100,15 @@ class RowidRef(Expr):
|
|
|
97
100
|
|
|
98
101
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
    """Return the rowid column of the store table that this RowidRef refers to.

    The table version is taken from self.tbl if set, otherwise it is looked up in the catalog by self.tbl_id.
    """
    tbl = self.tbl.get() if self.tbl is not None else catalog.Catalog.get().get_tbl_version(self.tbl_id, None)
    assert tbl.is_validated
    rowid_cols = tbl.store_tbl.rowid_columns()
    # the index has to be strictly smaller than the number of rowid columns, otherwise the lookup below
    # would fail with an IndexError
    assert self.rowid_component_idx < len(rowid_cols), (
        f'{self.rowid_component_idx} not consistent with {rowid_cols}'
    )
    # _logger.debug(
    #     f'RowidRef.sql_expr: tbl={tbl.id}{tbl.effective_version} sa_tbl={id(tbl.store_tbl.sa_tbl):x} '
    #     f'tv={id(tbl):x}'
    # )
    return rowid_cols[self.rowid_component_idx]
|
|
105
113
|
|
|
106
114
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
pixeltable/exprs/similarity_expr.py
CHANGED
|
@@ -54,6 +54,7 @@ class SimilarityExpr(Expr):
|
|
|
54
54
|
return 'similarity'
|
|
55
55
|
|
|
56
56
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
57
|
+
# TODO: validate that the index still exists
|
|
57
58
|
if not isinstance(self.components[1], Literal):
|
|
58
59
|
raise excs.Error('similarity(): requires a string or a PIL.Image.Image object, not an expression')
|
|
59
60
|
item = self.components[1].val
|
pixeltable/func/__init__.py
CHANGED
|
@@ -5,6 +5,7 @@ from .callable_function import CallableFunction
|
|
|
5
5
|
from .expr_template_function import ExprTemplateFunction
|
|
6
6
|
from .function import Function, InvalidFunction
|
|
7
7
|
from .function_registry import FunctionRegistry
|
|
8
|
+
from .mcp import mcp_udfs
|
|
8
9
|
from .query_template_function import QueryTemplateFunction, query, retrieval_udf
|
|
9
10
|
from .signature import Batch, Parameter, Signature
|
|
10
11
|
from .tools import Tool, ToolChoice, Tools
|
pixeltable/func/mcp.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
4
|
+
|
|
5
|
+
import pixeltable as pxt
|
|
6
|
+
from pixeltable import exceptions as excs, type_system as ts
|
|
7
|
+
from pixeltable.func.signature import Parameter
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import mcp
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def mcp_udfs(url: str) -> list['pxt.func.Function']:
    """Return one Pixeltable function per tool listed by the MCP server at `url`.

    Blocking wrapper around mcp_udfs_async(): the coroutine is run to completion with asyncio.run().
    """
    pending = mcp_udfs_async(url)
    return asyncio.run(pending)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def mcp_udfs_async(url: str) -> list['pxt.func.Function']:
    """Connect to the MCP server at `url`, list its tools and wrap each of them as a Pixeltable function."""
    import mcp
    from mcp.client.streamable_http import streamablehttp_client

    tools_response: Optional[mcp.types.ListToolsResult] = None
    async with (
        streamablehttp_client(url) as (reader, writer, _),
        mcp.ClientSession(reader, writer) as client,
    ):
        await client.initialize()
        tools_response = await client.list_tools()
    assert tools_response is not None

    # the conversion happens after the session has been closed
    udfs: list['pxt.func.Function'] = []
    for tool in tools_response.tools:
        udfs.append(mcp_tool_to_udf(url, tool))
    return udfs
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def mcp_tool_to_udf(url: str, mcp_tool: 'mcp.types.Tool') -> 'pxt.func.Function':
    """Wrap a single MCP tool as a Pixeltable function.

    The function takes the tool's input properties as keyword-only parameters and returns the text of the
    first content item of the tool's response. Every invocation opens a new session to the server at `url`.

    Raises:
        excs.Error: if the type schema of a tool parameter cannot be converted to a Pixeltable type
    """
    import mcp
    from mcp.client.streamable_http import streamablehttp_client

    async def invoke(**kwargs: Any) -> str:
        # TODO: Cache session objects rather than creating a new one each time?
        async with (
            streamablehttp_client(url) as (read_stream, write_stream, _),
            mcp.ClientSession(read_stream, write_stream) as session,
        ):
            await session.initialize()
            res = await session.call_tool(name=mcp_tool.name, arguments=kwargs)
        # TODO Handle image/audio responses?
        # report an empty or non-text response explicitly instead of failing on the lookup below
        if len(res.content) == 0 or not hasattr(res.content[0], 'text'):
            raise excs.Error(f'MCP tool {mcp_tool.name!r} did not return a text response')
        return res.content[0].text  # type: ignore[union-attr]

    if mcp_tool.description is not None:
        invoke.__doc__ = mcp_tool.description

    input_schema = mcp_tool.inputSchema
    # a schema without a 'properties' entry is treated as declaring no parameters
    properties = input_schema.get('properties', {})
    params = {name: __mcp_param_to_pxt_type(mcp_tool.name, name, param) for name, param in properties.items()}
    required = input_schema.get('required', [])

    # Ensure that any params not appearing in `required` are nullable.
    # (A required param might or might not be nullable, since its type might be an 'anyOf' containing a null.)
    for name in params.keys() - required:
        params[name] = params[name].copy(nullable=True)

    signature = pxt.func.Signature(
        return_type=ts.StringType(),  # Return type is always string
        parameters=[Parameter(name, col_type, inspect.Parameter.KEYWORD_ONLY) for name, col_type in params.items()],
    )

    return pxt.func.CallableFunction(signatures=[signature], py_fns=[invoke], self_name=mcp_tool.name)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def __mcp_param_to_pxt_type(tool_name: str, name: str, param: dict[str, Any]) -> ts.ColumnType:
    """Convert the JSON schema of an MCP tool parameter into a Pixeltable column type.

    Raises:
        excs.Error: if ColumnType.from_json_schema() yields None for the schema
    """
    converted = ts.ColumnType.from_json_schema(param)
    if converted is not None:
        return converted
    raise excs.Error(f'Unknown type schema for MCP parameter {name!r} of tool {tool_name!r}: {param}')
|