pixeltable 0.4.0rc1__py3-none-any.whl → 0.4.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/catalog.py +4 -0
- pixeltable/catalog/table.py +16 -0
- pixeltable/catalog/table_version.py +17 -2
- pixeltable/catalog/view.py +24 -1
- pixeltable/dataframe.py +185 -9
- pixeltable/env.py +2 -0
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/expr_eval/evaluators.py +4 -1
- pixeltable/exec/sql_node.py +152 -12
- pixeltable/exprs/data_row.py +5 -3
- pixeltable/exprs/expr.py +7 -0
- pixeltable/exprs/literal.py +2 -0
- pixeltable/func/tools.py +1 -1
- pixeltable/functions/anthropic.py +19 -45
- pixeltable/functions/deepseek.py +19 -38
- pixeltable/functions/fireworks.py +9 -18
- pixeltable/functions/gemini.py +2 -3
- pixeltable/functions/llama_cpp.py +6 -6
- pixeltable/functions/mistralai.py +15 -41
- pixeltable/functions/ollama.py +1 -1
- pixeltable/functions/openai.py +82 -165
- pixeltable/functions/together.py +22 -80
- pixeltable/globals.py +5 -0
- pixeltable/metadata/__init__.py +11 -2
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +3 -0
- pixeltable/plan.py +217 -10
- pixeltable/share/packager.py +115 -6
- pixeltable/utils/formatter.py +64 -42
- pixeltable/utils/sample.py +25 -0
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/METADATA +2 -1
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/RECORD +37 -35
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/LICENSE +0 -0
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/WHEEL +0 -0
- {pixeltable-0.4.0rc1.dist-info → pixeltable-0.4.0rc3.dist-info}/entry_points.txt +0 -0
pixeltable/exec/sql_node.py
CHANGED
@@ -14,6 +14,7 @@ from .exec_node import ExecNode
 
 if TYPE_CHECKING:
     import pixeltable.plan
+    from pixeltable.plan import SampleClause
 
 _logger = logging.getLogger('pixeltable')
 
@@ -64,8 +65,12 @@ def print_order_by_clause(clause: OrderByClause) -> str:
 
 class SqlNode(ExecNode):
     """
-    Materializes data from the store via
+    Materializes data from the store via an SQL statement.
     This only provides the select list. The subclasses are responsible for the From clause and any additional clauses.
+    The pk columns are not included in the select list.
+    If set_pk is True, they are added to the end of the result set when creating the SQL statement
+    so they can always be referenced as cols[-num_pk_cols:] in the result set.
+    The pk_columns consist of the rowid columns of the target table followed by the version number.
     """
 
     tbl: Optional[catalog.TableVersionPath]
@@ -122,6 +127,7 @@ class SqlNode(ExecNode):
             # we also need to retrieve the pk columns
             assert tbl is not None
             self.num_pk_cols = len(tbl.tbl_version.get().store_tbl.pk_columns())
+            assert self.num_pk_cols > 1
 
         # additional state
         self.result_cursor = None
@@ -139,15 +145,20 @@ class SqlNode(ExecNode):
         if tv is not None:
             assert tv.is_validated
 
+    def _create_pk_cols(self) -> list[sql.Column]:
+        """Create a list of pk columns"""
+        # we need to retrieve the pk columns
+        if self.set_pk:
+            assert self.tbl is not None
+            assert self.tbl.tbl_version.get().is_validated
+            return self.tbl.tbl_version.get().store_tbl.pk_columns()
+        return []
+
     def _create_stmt(self) -> sql.Select:
         """Create Select from local state"""
 
         assert self.sql_elements.contains_all(self.select_list)
-        sql_select_list = [self.sql_elements.get(e) for e in self.select_list]
-        if self.set_pk:
-            assert self.tbl is not None
-            assert self.tbl.tbl_version.get().is_validated
-            sql_select_list += self.tbl.tbl_version.get().store_tbl.pk_columns()
+        sql_select_list = [self.sql_elements.get(e) for e in self.select_list] + self._create_pk_cols()
         stmt = sql.select(*sql_select_list)
 
         where_clause_element = (
@@ -173,9 +184,10 @@
     def _ordering_tbl_ids(self) -> set[UUID]:
         return exprs.Expr.all_tbl_ids(e for e, _ in self.order_by_clause)
 
-    def to_cte(self) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
+    def to_cte(self, keep_pk: bool = False) -> Optional[tuple[sql.CTE, exprs.ExprDict[sql.ColumnElement]]]:
         """
-
+        Creates a CTE that materializes the output of this node plus a mapping from select list expr to output column.
+        keep_pk: if True, the PK columns are included in the CTE Select statement
 
         Returns:
             (CTE, dict from Expr to output column)
@@ -183,11 +195,13 @@
         if self.py_filter is not None:
             # the filter needs to run in Python
             return None
-        self.set_pk = False  # we don't need the PK if we use this SqlNode as a CTE
         if self.cte is None:
+            if not keep_pk:
+                self.set_pk = False  # we don't need the PK if we use this SqlNode as a CTE
             self.cte = self._create_stmt().cte()
-
-
+        pk_count = self.num_pk_cols if self.set_pk else 0
+        assert len(self.select_list) + pk_count == len(self.cte.c)
+        return self.cte, exprs.ExprDict(zip(self.select_list, self.cte.c))  # skip pk cols
 
     @classmethod
     def retarget_rowid_refs(cls, target: catalog.TableVersionPath, expr_seq: Iterable[exprs.Expr]) -> None:
@@ -293,7 +307,9 @@
             stmt_str = str(stmt.compile(compile_kwargs={'literal_binds': True}))
             _logger.debug(f'SqlLookupNode stmt:\n{stmt_str}')
         except Exception:
-
+            # log something if we can't log the compiled stmt
+            stmt_str = repr(stmt)
+            _logger.debug(f'SqlLookupNode proto-stmt:\n{stmt_str}')
         self._log_explain(stmt)
 
         conn = Env.get().conn
@@ -510,3 +526,127 @@ class SqlJoinNode(SqlNode):
             full=join_clause == plan.JoinType.FULL_OUTER,
         )
         return stmt
+
+
+class SqlSampleNode(SqlNode):
+    """
+    Returns rows from a stratified sample with N samples per strata.
+    """
+
+    stratify_exprs: Optional[list[exprs.Expr]]
+    n_samples: Optional[int]
+    fraction_samples: Optional[float]
+    seed: int
+    input_cte: Optional[sql.CTE]
+    pk_count: int
+
+    def __init__(
+        self,
+        row_builder: exprs.RowBuilder,
+        input: SqlNode,
+        select_list: Iterable[exprs.Expr],
+        stratify_exprs: Optional[list[exprs.Expr]] = None,
+        sample_clause: Optional['SampleClause'] = None,
+    ):
+        """
+        Args:
+            select_list: can contain calls to AggregateFunctions
+            stratify_exprs: list of expressions to group by
+            n: number of samples per strata
+        """
+        self.input_cte, input_col_map = input.to_cte(keep_pk=True)
+        self.pk_count = input.num_pk_cols
+        assert self.pk_count > 1
+        sql_elements = exprs.SqlElementCache(input_col_map)
+        super().__init__(input.tbl, row_builder, select_list, sql_elements, set_pk=True)
+        self.stratify_exprs = stratify_exprs
+        self.n_samples = sample_clause.n
+        self.n_per_stratum = sample_clause.n_per_stratum
+        self.fraction_samples = sample_clause.fraction
+        self.seed = sample_clause.seed if sample_clause.seed is not None else 0
+
+    @classmethod
+    def key_sql_expr(cls, seed: sql.ColumnElement, sql_cols: Iterable[sql.ColumnElement]) -> sql.ColumnElement:
+        """Construct expression which is the ordering key for rows to be sampled
+        General SQL form is:
+        - MD5(<seed::text> [ + '___' + <rowid_col_val>::text]+
+        """
+        sql_expr: sql.ColumnElement = sql.cast(seed, sql.Text)
+        for e in sql_cols:
+            sql_expr = sql_expr + sql.literal_column("'___'") + sql.cast(e, sql.Text)
+        sql_expr = sql.func.md5(sql_expr)
+        return sql_expr
+
+    def _create_order_by(self, cte: sql.CTE) -> sql.ColumnElement:
+        """Create an expression for randomly ordering rows with a given seed"""
+        rowid_cols = [*cte.c[-self.pk_count : -1]]  # exclude the version column
+        assert len(rowid_cols) > 0
+        return self.key_sql_expr(sql.literal_column(str(self.seed)), rowid_cols)
+
+    def _create_stmt(self) -> sql.Select:
+        if self.fraction_samples is not None:
+            return self._create_stmt_fraction(self.fraction_samples)
+        return self._create_stmt_n(self.n_samples, self.n_per_stratum)
+
+    def _create_stmt_n(self, n: Optional[int], n_per_stratum: Optional[int]) -> sql.Select:
+        """Create a Select stmt that returns n samples across all strata"""
+        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
+        order_by = self._create_order_by(self.input_cte)
+
+        # Create a list of all columns plus the rank
+        # Get all columns from the input CTE dynamically
+        select_columns = [*self.input_cte.c]
+        select_columns.append(
+            sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
+        )
+        row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
+
+        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
+        if n_per_stratum is not None:
+            return sql.select(*final_columns).filter(row_rank_cte.c.rank <= n_per_stratum)
+        else:
+            secondary_order = self._create_order_by(row_rank_cte)
+            return sql.select(*final_columns).order_by(row_rank_cte.c.rank, secondary_order).limit(n)
+
+    def _create_stmt_fraction(self, fraction_samples: float) -> sql.Select:
+        """Create a Select stmt that returns a fraction of the rows per strata"""
+
+        # Build the strata count CTE
+        # Produces a table of the form:
+        # ([stratify_exprs], s_s_size)
+        # where s_s_size is the number of samples to take from each stratum
+        sql_strata_exprs = [self.sql_elements.get(e) for e in self.stratify_exprs]
+        per_strata_count_cte = (
+            sql.select(
+                *sql_strata_exprs,
+                sql.func.ceil(fraction_samples * sql.func.count(1).cast(sql.Integer)).label('s_s_size'),
+            )
+            .select_from(self.input_cte)
+            .group_by(*sql_strata_exprs)
+            .cte('per_strata_count_cte')
+        )
+
+        # Build a CTE that ranks the rows within each stratum
+        # Include all columns from the input CTE dynamically
+        order_by = self._create_order_by(self.input_cte)
+        select_columns = [*self.input_cte.c]
+        select_columns.append(
+            sql.func.row_number().over(partition_by=sql_strata_exprs, order_by=order_by).label('rank')
+        )
+        row_rank_cte = sql.select(*select_columns).select_from(self.input_cte).cte('row_rank_cte')
+
+        # Build the join criterion dynamically to accommodate any number of group by columns
+        join_c = sql.true()
+        for col in per_strata_count_cte.c[:-1]:
+            join_c &= row_rank_cte.c[col.name].isnot_distinct_from(col)
+
+        # Join srcp with per_strata_count_cte to limit returns to the requested fraction of rows
+        final_columns = [*row_rank_cte.c[:-1]]  # exclude the rank column
+        stmt = (
+            sql.select(*final_columns)
+            .select_from(row_rank_cte)
+            .join(per_strata_count_cte, join_c)
+            .where(row_rank_cte.c.rank <= per_strata_count_cte.c.s_s_size)
+        )
+
+        return stmt
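The new SqlSampleNode orders candidate rows by a deterministic pseudo-random key derived from the sample seed and the rowid columns, then keeps the top-ranked rows per stratum. The following standalone sketch (not package code; table and column names are made up) mirrors key_sql_expr with plain SQLAlchemy and prints the SQL it compiles to:

import sqlalchemy as sql

def key_sql_expr(seed, sql_cols):
    # concatenate the seed and the rowid columns (cast to text, joined with '___') and hash the result
    expr = sql.cast(seed, sql.Text)
    for c in sql_cols:
        expr = expr + sql.literal_column("'___'") + sql.cast(c, sql.Text)
    return sql.func.md5(expr)

t = sql.table('media', sql.column('rowid'), sql.column('v_min'))  # hypothetical store table
key = key_sql_expr(sql.literal_column('42'), [t.c.rowid])
print(key.compile(compile_kwargs={'literal_binds': True}))
# prints roughly: md5(CAST(42 AS TEXT) || '___' || CAST(media.rowid AS TEXT))

Ordering by this key inside row_number() OVER (PARTITION BY the stratify expressions) is what makes the sample reproducible for a fixed seed.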
pixeltable/exprs/data_row.py
CHANGED
@@ -214,6 +214,7 @@ class DataRow:
         """Assign in-memory cell value
         This allows overwriting
         """
+        assert isinstance(idx, int)
         assert self.excs[idx] is None
 
         if (idx in self.img_slot_idxs or idx in self.media_slot_idxs) and isinstance(val, str):
@@ -253,14 +254,15 @@
         assert self.excs[index] is None
         if self.file_paths[index] is None:
             if filepath is not None:
-                # we want to save this to a file
-                self.file_paths[index] = filepath
-                self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
                 image = self.vals[index]
                 assert isinstance(image, PIL.Image.Image)
                 # Default to JPEG unless the image has a transparency layer (which isn't supported by JPEG).
                 # In that case, use WebP instead.
                 format = 'webp' if image.has_transparency_data else 'jpeg'
+                if not filepath.endswith(f'.{format}'):
+                    filepath += f'.{format}'
+                self.file_paths[index] = filepath
+                self.file_urls[index] = urllib.parse.urljoin('file:', urllib.request.pathname2url(filepath))
                 image.save(filepath, format=format)
             else:
                 # we discard the content of this cell
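The second hunk changes the order of operations when an in-memory image is flushed to disk: the output format is chosen first, the matching extension is appended to the target path, and only then are the path and file URL recorded. A minimal sketch of that logic, assuming Pillow 10.1+ (for has_transparency_data) and a throwaway temp directory:

import os, tempfile
import PIL.Image

def save_with_format(image: PIL.Image.Image, filepath: str) -> str:
    # mirror of the new logic: pick the format first, then make sure the path carries its extension
    format = 'webp' if image.has_transparency_data else 'jpeg'
    if not filepath.endswith(f'.{format}'):
        filepath += f'.{format}'
    image.save(filepath, format=format)
    return filepath

img = PIL.Image.new('RGB', (8, 8))
path = save_with_format(img, os.path.join(tempfile.mkdtemp(), 'cell_0'))
print(path)  # ends in '.jpeg'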
pixeltable/exprs/expr.py
CHANGED
@@ -276,6 +276,13 @@ class Expr(abc.ABC):
         tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
         return self._retarget(tbl_versions)
 
+    @classmethod
+    def retarget_list(cls, expr_list: list[Expr], tbl: catalog.TableVersionPath) -> None:
+        """Retarget ColumnRefs in expr_list to the specific TableVersions in tbl."""
+        tbl_versions = {tbl_version.id: tbl_version.get() for tbl_version in tbl.get_tbl_versions()}
+        for i in range(len(expr_list)):
+            expr_list[i] = expr_list[i]._retarget(tbl_versions)
+
     def _retarget(self, tbl_versions: dict[UUID, catalog.TableVersion]) -> Self:
         for i in range(len(self.components)):
             self.components[i] = self.components[i]._retarget(tbl_versions)
pixeltable/exprs/literal.py
CHANGED
pixeltable/func/tools.py
CHANGED
@@ -8,7 +8,7 @@ the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anth
 import datetime
 import json
 import logging
-from typing import TYPE_CHECKING, Any, Iterable, Optional,
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 
 import httpx
 
pixeltable/functions/anthropic.py
CHANGED
@@ -73,16 +73,10 @@ async def messages(
     messages: list[dict[str, str]],
     *,
     model: str,
-    max_tokens: int
-
-
-
-    temperature: Optional[float] = None,
-    tool_choice: Optional[dict] = None,
-    tools: Optional[list[dict]] = None,
-    top_k: Optional[int] = None,
-    top_p: Optional[float] = None,
-    timeout: Optional[float] = None,
+    max_tokens: int,
+    model_kwargs: Optional[dict[str, Any]] = None,
+    tools: Optional[list[dict[str, Any]]] = None,
+    tool_choice: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Create a Message.
@@ -101,25 +95,27 @@ async def messages(
     Args:
         messages: Input messages.
         model: The model that will complete your prompt.
-
-
+        model_kwargs: Additional keyword args for the Anthropic `messages` API.
+            For details on the available parameters, see: <https://docs.anthropic.com/en/api/messages>
+        tools: An optional list of Pixeltable tools to use for the request.
+        tool_choice: An optional tool choice configuration.
 
     Returns:
         A dictionary containing the response and other metadata.
 
     Examples:
-        Add a computed column that applies the model `claude-3-
+        Add a computed column that applies the model `claude-3-5-sonnet-20241022`
        to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
 
        >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
-        ... tbl.add_computed_column(response=messages(msgs, model='claude-3-
+        ... tbl.add_computed_column(response=messages(msgs, model='claude-3-5-sonnet-20241022'))
     """
-
-
+    if model_kwargs is None:
+        model_kwargs = {}
 
     if tools is not None:
         # Reformat `tools` into Anthropic format
-        tools = [
+        model_kwargs['tools'] = [
            {
                'name': tool['name'],
                'description': tool['description'],
@@ -132,17 +128,16 @@ async def messages(
            for tool in tools
        ]
 
-    tool_choice_: Optional[dict] = None
    if tool_choice is not None:
        if tool_choice['auto']:
-
+            model_kwargs['tool_choice'] = {'type': 'auto'}
        elif tool_choice['required']:
-
+            model_kwargs['tool_choice'] = {'type': 'any'}
        else:
            assert tool_choice['tool'] is not None
-
+            model_kwargs['tool_choice'] = {'type': 'tool', 'name': tool_choice['tool']}
        if not tool_choice['parallel_tool_calls']:
-
+            model_kwargs['tool_choice']['disable_parallel_tool_use'] = True
 
    # make sure the pool info exists prior to making the request
    resource_pool_id = f'rate-limits:anthropic:{model}'
@@ -152,20 +147,8 @@
    # TODO: timeouts should be set system-wide and be user-configurable
    from anthropic.types import MessageParam
 
-    # cast(Any, ...): avoid mypy errors
    result = await _anthropic_client().messages.with_raw_response.create(
-        messages=cast(Iterable[MessageParam], messages),
-        model=model,
-        max_tokens=max_tokens,
-        metadata=_opt(cast(Any, metadata)),
-        stop_sequences=_opt(stop_sequences),
-        system=_opt(system),
-        temperature=_opt(cast(Any, temperature)),
-        tools=_opt(cast(Any, tools)),
-        tool_choice=_opt(cast(Any, tool_choice_)),
-        top_k=_opt(top_k),
-        top_p=_opt(top_p),
-        timeout=_opt(timeout),
+        messages=cast(Iterable[MessageParam], messages), model=model, max_tokens=max_tokens, **model_kwargs
    )
 
    requests_limit_str = result.headers.get('anthropic-ratelimit-requests-limit')
@@ -224,15 +207,6 @@ def _anthropic_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
    return pxt_tool_calls
 
 
-_T = TypeVar('_T')
-
-
-def _opt(arg: _T) -> Union[_T, 'anthropic.NotGiven']:
-    import anthropic
-
-    return arg if arg is not None else anthropic.NOT_GIVEN
-
-
 __all__ = local_public_names(__name__)
 
 
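With the Anthropic messages UDF reworked to accept model_kwargs, the optional sampling and tool parameters now travel in a single dict instead of dedicated keyword arguments. A hedged usage sketch based on the docstring example; the table my_table and its prompt column are assumed to already exist:

import pixeltable as pxt
from pixeltable.functions.anthropic import messages

tbl = pxt.get_table('my_table')  # assumed existing table with a string column `prompt`
msgs = [{'role': 'user', 'content': tbl.prompt}]
tbl.add_computed_column(
    response=messages(
        msgs,
        model='claude-3-5-sonnet-20241022',
        max_tokens=1024,
        # options that used to be dedicated parameters (temperature, top_p, ...) now go here
        model_kwargs={'temperature': 0.7, 'top_p': 0.9},
    )
)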
pixeltable/functions/deepseek.py
CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import httpx
 
@@ -7,8 +7,6 @@ import pixeltable as pxt
 from pixeltable import env
 from pixeltable.utils.code import local_public_names
 
-from .openai import _opt
-
 if TYPE_CHECKING:
     import openai
 
@@ -33,17 +31,9 @@ async def chat_completions(
     messages: list,
     *,
     model: str,
-
-
-
-    max_tokens: Optional[int] = None,
-    presence_penalty: Optional[float] = None,
-    response_format: Optional[dict] = None,
-    stop: Optional[list[str]] = None,
-    temperature: Optional[float] = None,
-    tools: Optional[list[dict]] = None,
-    tool_choice: Optional[dict] = None,
-    top_p: Optional[float] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+    tools: Optional[list[dict[str, Any]]] = None,
+    tool_choice: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Creates a model response for the given chat conversation.
@@ -60,8 +50,10 @@ async def chat_completions(
     Args:
         messages: A list of messages to use for chat completion, as described in the Deepseek API documentation.
         model: The model to use for chat completion.
-
-
+        model_kwargs: Additional keyword args for the Deepseek `chat/completions` API.
+            For details on the available parameters, see: <https://api-docs.deepseek.com/api/create-chat-completion>
+        tools: An optional list of Pixeltable tools to use for the request.
+        tool_choice: An optional tool choice configuration.
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -76,39 +68,28 @@ async def chat_completions(
         ]
         tbl.add_computed_column(response=chat_completions(messages, model='deepseek-chat'))
     """
+    if model_kwargs is None:
+        model_kwargs = {}
+
    if tools is not None:
-        tools = [{'type': 'function', 'function': tool} for tool in tools]
+        model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
 
-    tool_choice_: Union[str, dict, None] = None
    if tool_choice is not None:
        if tool_choice['auto']:
-
+            model_kwargs['tool_choice'] = 'auto'
        elif tool_choice['required']:
-
+            model_kwargs['tool_choice'] = 'required'
        else:
            assert tool_choice['tool'] is not None
-
+            model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
 
-    extra_body: Optional[dict[str, Any]] = None
    if tool_choice is not None and not tool_choice['parallel_tool_calls']:
-        extra_body
+        if 'extra_body' not in model_kwargs:
+            model_kwargs['extra_body'] = {}
+        model_kwargs['extra_body']['parallel_tool_calls'] = False
 
-    # cast(Any, ...): avoid mypy errors
    result = await _deepseek_client().chat.completions.with_raw_response.create(
-        messages=messages,
-        model=model,
-        frequency_penalty=_opt(frequency_penalty),
-        logprobs=_opt(logprobs),
-        top_logprobs=_opt(top_logprobs),
-        max_tokens=_opt(max_tokens),
-        presence_penalty=_opt(presence_penalty),
-        response_format=_opt(cast(Any, response_format)),
-        stop=_opt(stop),
-        temperature=_opt(temperature),
-        tools=_opt(cast(Any, tools)),
-        tool_choice=_opt(cast(Any, tool_choice_)),
-        top_p=_opt(top_p),
-        extra_body=extra_body,
+        messages=messages, model=model, **model_kwargs
    )
 
    return json.loads(result.text)
pixeltable/functions/fireworks.py
CHANGED
@@ -5,7 +5,7 @@ first `pip install fireworks-ai` and configure your Fireworks AI credentials, as
 the [Working with Fireworks](https://pixeltable.readme.io/docs/working-with-fireworks) tutorial.
 """
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import pixeltable as pxt
 from pixeltable import env
@@ -29,14 +29,7 @@ def _fireworks_client() -> 'fireworks.client.Fireworks':
 
 @pxt.udf(resource_pool='request-rate:fireworks')
 async def chat_completions(
-    messages: list[dict[str, str]],
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    top_k: Optional[int] = None,
-    top_p: Optional[float] = None,
-    temperature: Optional[float] = None,
-    request_timeout: Optional[int] = None,
+    messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
 ) -> dict:
     """
     Creates a model response for the given chat conversation.
@@ -55,8 +48,8 @@ async def chat_completions(
     Args:
         messages: A list of messages comprising the conversation so far.
         model: The name of the model to use.
-
-
+        model_kwargs: Additional keyword args for the Fireworks `chat_completions` API. For details on the available
+            parameters, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -70,20 +63,18 @@ async def chat_completions(
        ... response=chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct')
        ... )
     """
-
-
+    if model_kwargs is None:
+        model_kwargs = {}
 
    # for debugging purposes:
    # res_sync = _fireworks_client().chat.completions.create(model=model, messages=messages, **kwargs_not_none)
    # res_sync_dict = res_sync.dict()
 
-    if request_timeout
-        request_timeout = Config.get().get_int_value('timeout', section='fireworks') or 600
+    if 'request_timeout' not in model_kwargs:
+        model_kwargs['request_timeout'] = Config.get().get_int_value('timeout', section='fireworks') or 600
    # TODO: this timeout doesn't really work, I think it only applies to returning the stream, but not to the timing
    # of the chunks; addressing this would require a timeout for the task running this udf
-    stream = _fireworks_client().chat.completions.acreate(
-        model=model, messages=messages, request_timeout=request_timeout, **kwargs_not_none
-    )
+    stream = _fireworks_client().chat.completions.acreate(model=model, messages=messages, **model_kwargs)
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
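deepseek.chat_completions and fireworks.chat_completions follow the same pattern: everything beyond messages and model moves into model_kwargs. A hedged sketch for the Deepseek UDF, again assuming an existing table with a prompt column:

import pixeltable as pxt
from pixeltable.functions.deepseek import chat_completions

tbl = pxt.get_table('my_table')  # assumed existing table
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': tbl.prompt},
]
tbl.add_computed_column(
    response=chat_completions(
        messages,
        model='deepseek-chat',
        # former keyword parameters such as max_tokens and temperature now live in model_kwargs
        model_kwargs={'max_tokens': 300, 'temperature': 0.8},
    )
)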
pixeltable/functions/gemini.py
CHANGED
@@ -53,8 +53,8 @@ async def generate_content(
         config: Configuration for generation, corresponding to keyword arguments of
             `genai.types.GenerateContentConfig`. For details on the parameters, see:
             <https://googleapis.github.io/python-genai/genai.html#module-genai.types>
-        tools:
-            `config
+        tools: An optional list of Pixeltable tools to use. It is also possible to specify tools manually via the
+            `config['tools']` parameter, but at most one of `config['tools']` or `tools` may be used.
 
     Returns:
         A dictionary containing the response and other metadata.
@@ -103,7 +103,6 @@ def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDic
 
 @pxt.udf
 def _gemini_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
-    print(response)
     pxt_tool_calls: dict[str, list[dict]] = {}
     for part in response['candidates'][0]['content']['parts']:
         tool_call = part.get('function_call')
pixeltable/functions/llama_cpp.py
CHANGED
@@ -17,7 +17,7 @@ def create_chat_completion(
     model_path: Optional[str] = None,
     repo_id: Optional[str] = None,
     repo_filename: Optional[str] = None,
-
+    model_kwargs: Optional[dict[str, Any]] = None,
 ) -> dict:
     """
     Generate a chat completion from a list of messages.
@@ -35,14 +35,14 @@
         repo_id: The Hugging Face model repo id (if using a pretrained model).
         repo_filename: A filename or glob pattern to match the model file in the repo (optional, if using a
             pretrained model).
-
-            `top_p`, and `top_k`. For details, see the
+        model_kwargs: Additional keyword args for the llama_cpp `create_chat_completions` API, such as `max_tokens`,
+            `temperature`, `top_p`, and `top_k`. For details, see the
             [llama_cpp create_chat_completions documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
     """
     Env.get().require_package('llama_cpp', min_version=[0, 3, 1])
 
-    if
-
+    if model_kwargs is None:
+        model_kwargs = {}
 
     if (model_path is None) == (repo_id is None):
         raise excs.Error('Exactly one of `model_path` or `repo_id` must be provided.')
@@ -56,7 +56,7 @@
     else:
         Env.get().require_package('huggingface_hub')
         llm = _lookup_pretrained_model(repo_id, repo_filename, n_gpu_layers)
-    return llm.create_chat_completion(messages, **
+    return llm.create_chat_completion(messages, **model_kwargs)  # type: ignore
 
 
 def _is_gpu_available() -> bool: