pixeltable 0.2.30__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +1 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/table.py +212 -173
- pixeltable/catalog/table_version.py +2 -1
- pixeltable/catalog/view.py +3 -5
- pixeltable/dataframe.py +52 -39
- pixeltable/env.py +94 -5
- pixeltable/exec/__init__.py +1 -1
- pixeltable/exec/aggregation_node.py +3 -3
- pixeltable/exec/cache_prefetch_node.py +13 -7
- pixeltable/exec/component_iteration_node.py +3 -9
- pixeltable/exec/data_row_batch.py +17 -5
- pixeltable/exec/exec_node.py +32 -12
- pixeltable/exec/expr_eval/__init__.py +1 -0
- pixeltable/exec/expr_eval/evaluators.py +245 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +404 -0
- pixeltable/exec/expr_eval/globals.py +114 -0
- pixeltable/exec/expr_eval/row_buffer.py +76 -0
- pixeltable/exec/expr_eval/schedulers.py +232 -0
- pixeltable/exec/in_memory_data_node.py +2 -2
- pixeltable/exec/row_update_node.py +14 -14
- pixeltable/exec/sql_node.py +2 -2
- pixeltable/exprs/column_ref.py +5 -1
- pixeltable/exprs/data_row.py +50 -40
- pixeltable/exprs/expr.py +57 -12
- pixeltable/exprs/function_call.py +54 -19
- pixeltable/exprs/inline_expr.py +12 -21
- pixeltable/exprs/literal.py +25 -8
- pixeltable/exprs/row_builder.py +23 -0
- pixeltable/exprs/similarity_expr.py +4 -4
- pixeltable/func/__init__.py +5 -5
- pixeltable/func/aggregate_function.py +4 -0
- pixeltable/func/callable_function.py +54 -6
- pixeltable/func/expr_template_function.py +5 -1
- pixeltable/func/function.py +54 -13
- pixeltable/func/query_template_function.py +56 -10
- pixeltable/func/tools.py +51 -14
- pixeltable/func/udf.py +7 -1
- pixeltable/functions/__init__.py +1 -1
- pixeltable/functions/anthropic.py +108 -21
- pixeltable/functions/gemini.py +2 -6
- pixeltable/functions/huggingface.py +10 -28
- pixeltable/functions/openai.py +225 -28
- pixeltable/globals.py +8 -5
- pixeltable/index/embedding_index.py +90 -38
- pixeltable/io/label_studio.py +1 -1
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_24.py +11 -2
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/plan.py +24 -9
- pixeltable/store.py +6 -0
- pixeltable/type_system.py +4 -7
- pixeltable/utils/arrow.py +3 -3
- {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/METADATA +5 -11
- {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/RECORD +59 -53
- pixeltable/exec/expr_eval_node.py +0 -232
- {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.30.dist-info → pixeltable-0.3.1.dist-info}/entry_points.txt +0 -0
pixeltable/func/function.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import abc
|
|
4
3
|
import importlib
|
|
5
4
|
import inspect
|
|
5
|
+
from abc import abstractmethod, ABC
|
|
6
6
|
from copy import copy
|
|
7
7
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, cast
|
|
8
8
|
|
|
@@ -12,7 +12,6 @@ from typing_extensions import Self
|
|
|
12
12
|
import pixeltable as pxt
|
|
13
13
|
import pixeltable.exceptions as excs
|
|
14
14
|
import pixeltable.type_system as ts
|
|
15
|
-
|
|
16
15
|
from .globals import resolve_symbol
|
|
17
16
|
from .signature import Signature
|
|
18
17
|
|
|
@@ -20,7 +19,7 @@ if TYPE_CHECKING:
|
|
|
20
19
|
from .expr_template_function import ExprTemplate, ExprTemplateFunction
|
|
21
20
|
|
|
22
21
|
|
|
23
|
-
class Function(
|
|
22
|
+
class Function(ABC):
|
|
24
23
|
"""Base class for Pixeltable's function interface.
|
|
25
24
|
|
|
26
25
|
A function in Pixeltable is an object that has a signature and implements __call__().
|
|
@@ -44,6 +43,12 @@ class Function(abc.ABC):
|
|
|
44
43
|
# parameter names as the original function. Each parameter is going to be of type sql.ColumnElement.
|
|
45
44
|
_to_sql: Callable[..., Optional[sql.ColumnElement]]
|
|
46
45
|
|
|
46
|
+
# Returns the resource pool to use for calling this function with the given arguments.
|
|
47
|
+
# Overriden for specific Function instances via the resource_pool() decorator. The override must accept a subset
|
|
48
|
+
# of the parameters of the original function, with the same type.
|
|
49
|
+
_resource_pool: Callable[..., Optional[str]]
|
|
50
|
+
|
|
51
|
+
|
|
47
52
|
def __init__(
|
|
48
53
|
self,
|
|
49
54
|
signatures: list[Signature],
|
|
@@ -60,9 +65,9 @@ class Function(abc.ABC):
|
|
|
60
65
|
self.is_method = is_method
|
|
61
66
|
self.is_property = is_property
|
|
62
67
|
self._conditional_return_type = None
|
|
63
|
-
self._to_sql = self.__default_to_sql
|
|
64
|
-
|
|
65
68
|
self.__resolved_fns = []
|
|
69
|
+
self._to_sql = self.__default_to_sql
|
|
70
|
+
self._resource_pool = self.__default_resource_pool
|
|
66
71
|
|
|
67
72
|
@property
|
|
68
73
|
def name(self) -> str:
|
|
@@ -92,6 +97,10 @@ class Function(abc.ABC):
|
|
|
92
97
|
assert not self.is_polymorphic
|
|
93
98
|
return len(self.signature.parameters)
|
|
94
99
|
|
|
100
|
+
@property
|
|
101
|
+
@abstractmethod
|
|
102
|
+
def is_async(self) -> bool: ...
|
|
103
|
+
|
|
95
104
|
def _docstring(self) -> Optional[str]:
|
|
96
105
|
return None
|
|
97
106
|
|
|
@@ -119,6 +128,7 @@ class Function(abc.ABC):
|
|
|
119
128
|
for idx in range(len(self.signatures)):
|
|
120
129
|
resolution = cast(Self, copy(self))
|
|
121
130
|
resolution.signatures = [self.signatures[idx]]
|
|
131
|
+
resolution.__resolved_fns = [resolution] # Resolves to itself
|
|
122
132
|
resolution._update_as_overload_resolution(idx)
|
|
123
133
|
self.__resolved_fns.append(resolution)
|
|
124
134
|
|
|
@@ -183,6 +193,26 @@ class Function(abc.ABC):
|
|
|
183
193
|
"""Override this to do custom validation of the arguments"""
|
|
184
194
|
assert not self.is_polymorphic
|
|
185
195
|
|
|
196
|
+
def _get_callable_args(self, callable: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
197
|
+
"""Return the kwargs to pass to callable, given kwargs passed to this function"""
|
|
198
|
+
bound_args = self.signature.py_signature.bind(**kwargs).arguments
|
|
199
|
+
# add defaults to bound_args, if not already present
|
|
200
|
+
bound_args.update({
|
|
201
|
+
name: param.default
|
|
202
|
+
for name, param in self.signature.parameters.items() if name not in bound_args and param.has_default()
|
|
203
|
+
})
|
|
204
|
+
result: dict[str, Any] = {}
|
|
205
|
+
sig = inspect.signature(callable)
|
|
206
|
+
for param in sig.parameters.values():
|
|
207
|
+
if param.name in bound_args:
|
|
208
|
+
result[param.name] = bound_args[param.name]
|
|
209
|
+
return result
|
|
210
|
+
|
|
211
|
+
def call_resource_pool(self, kwargs: dict[str, Any]) -> str:
|
|
212
|
+
"""Return the resource pool to use for calling this function with the given arguments"""
|
|
213
|
+
kw_args = self._get_callable_args(self._resource_pool, kwargs)
|
|
214
|
+
return self._resource_pool(**kw_args)
|
|
215
|
+
|
|
186
216
|
def call_return_type(self, args: Sequence[Any], kwargs: dict[str, Any]) -> ts.ColumnType:
|
|
187
217
|
"""Return the type of the value returned by calling this function with the given arguments"""
|
|
188
218
|
assert not self.is_polymorphic
|
|
@@ -198,13 +228,12 @@ class Function(abc.ABC):
|
|
|
198
228
|
|
|
199
229
|
def conditional_return_type(self, fn: Callable[..., ts.ColumnType]) -> Callable[..., ts.ColumnType]:
|
|
200
230
|
"""Instance decorator for specifying a conditional return type for this function"""
|
|
201
|
-
if self.is_polymorphic:
|
|
202
|
-
raise excs.Error('`conditional_return_type` is not supported for functions with multiple signatures')
|
|
203
231
|
# verify that call_return_type only has parameters that are also present in the signature
|
|
204
|
-
|
|
205
|
-
for param in
|
|
206
|
-
|
|
207
|
-
|
|
232
|
+
fn_sig = inspect.signature(fn)
|
|
233
|
+
for param in fn_sig.parameters.values():
|
|
234
|
+
for self_sig in self.signatures:
|
|
235
|
+
if param.name not in self_sig.parameters:
|
|
236
|
+
raise ValueError(f'`conditional_return_type` has parameter `{param.name}` that is not in a signature')
|
|
208
237
|
self._conditional_return_type = fn
|
|
209
238
|
return fn
|
|
210
239
|
|
|
@@ -268,10 +297,13 @@ class Function(abc.ABC):
|
|
|
268
297
|
|
|
269
298
|
return ExprTemplate(call, new_signature)
|
|
270
299
|
|
|
271
|
-
@abc.abstractmethod
|
|
272
300
|
def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
|
|
273
301
|
"""Execute the function with the given arguments and return the result."""
|
|
274
|
-
|
|
302
|
+
raise NotImplementedError()
|
|
303
|
+
|
|
304
|
+
async def aexec(self, *args: Any, **kwargs: Any) -> Any:
|
|
305
|
+
"""Execute the function with the given arguments and return the result."""
|
|
306
|
+
raise NotImplementedError()
|
|
275
307
|
|
|
276
308
|
def to_sql(self, fn: Callable[..., Optional[sql.ColumnElement]]) -> Callable[..., Optional[sql.ColumnElement]]:
|
|
277
309
|
"""Instance decorator for specifying the SQL translation of this function"""
|
|
@@ -282,6 +314,15 @@ class Function(abc.ABC):
|
|
|
282
314
|
"""The default implementation of SQL translation, which provides no translation"""
|
|
283
315
|
return None
|
|
284
316
|
|
|
317
|
+
def resource_pool(self, fn: Callable[..., str]) -> Callable[..., str]:
|
|
318
|
+
"""Instance decorator for specifying the resource pool of this function"""
|
|
319
|
+
# TODO: check that fn's parameters are a subset of our parameters
|
|
320
|
+
self._resource_pool = fn
|
|
321
|
+
return fn
|
|
322
|
+
|
|
323
|
+
def __default_resource_pool(self) -> Optional[str]:
|
|
324
|
+
return None
|
|
325
|
+
|
|
285
326
|
def __eq__(self, other: object) -> bool:
|
|
286
327
|
if not isinstance(other, self.__class__):
|
|
287
328
|
return False
|
|
@@ -1,23 +1,31 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from typing import Any, Callable, Optional, Sequence
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, overload
|
|
5
5
|
|
|
6
6
|
import sqlalchemy as sql
|
|
7
7
|
|
|
8
|
-
import pixeltable as
|
|
8
|
+
import pixeltable.exceptions as excs
|
|
9
|
+
import pixeltable.type_system as ts
|
|
9
10
|
from pixeltable import exprs
|
|
10
11
|
|
|
11
12
|
from .function import Function
|
|
12
13
|
from .signature import Signature
|
|
13
14
|
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from pixeltable import DataFrame
|
|
17
|
+
|
|
14
18
|
|
|
15
19
|
class QueryTemplateFunction(Function):
|
|
16
20
|
"""A parameterized query/DataFrame from which an executable DataFrame is created with a function call."""
|
|
21
|
+
template_df: Optional['DataFrame']
|
|
22
|
+
self_name: Optional[str]
|
|
23
|
+
conn: Optional[sql.engine.Connection]
|
|
24
|
+
defaults: dict[str, exprs.Literal]
|
|
17
25
|
|
|
18
26
|
@classmethod
|
|
19
27
|
def create(
|
|
20
|
-
cls, template_callable: Callable, param_types: Optional[list[
|
|
28
|
+
cls, template_callable: Callable, param_types: Optional[list[ts.ColumnType]], path: str, name: str
|
|
21
29
|
) -> QueryTemplateFunction:
|
|
22
30
|
# we need to construct a template df and a signature
|
|
23
31
|
py_sig = inspect.signature(template_callable)
|
|
@@ -29,11 +37,11 @@ class QueryTemplateFunction(Function):
|
|
|
29
37
|
from pixeltable import DataFrame
|
|
30
38
|
assert isinstance(template_df, DataFrame)
|
|
31
39
|
# we take params and return json
|
|
32
|
-
sig = Signature(return_type=
|
|
40
|
+
sig = Signature(return_type=ts.JsonType(), parameters=params)
|
|
33
41
|
return QueryTemplateFunction(template_df, sig, path=path, name=name)
|
|
34
42
|
|
|
35
43
|
def __init__(
|
|
36
|
-
self, template_df: Optional['
|
|
44
|
+
self, template_df: Optional['DataFrame'], sig: Signature, path: Optional[str] = None,
|
|
37
45
|
name: Optional[str] = None,
|
|
38
46
|
):
|
|
39
47
|
assert sig is not None
|
|
@@ -44,10 +52,10 @@ class QueryTemplateFunction(Function):
|
|
|
44
52
|
# if we're running as part of an ongoing update operation, we need to use the same connection, otherwise
|
|
45
53
|
# we end up with a deadlock
|
|
46
54
|
# TODO: figure out a more general way to make execution state available
|
|
47
|
-
self.conn
|
|
55
|
+
self.conn = None
|
|
48
56
|
|
|
49
57
|
# convert defaults to Literals
|
|
50
|
-
self.defaults
|
|
58
|
+
self.defaults = {} # key: param name, value: default value converted to a Literal
|
|
51
59
|
param_types = self.template_df.parameters()
|
|
52
60
|
for param in [p for p in sig.parameters.values() if p.has_default()]:
|
|
53
61
|
assert param.name in param_types
|
|
@@ -61,14 +69,18 @@ class QueryTemplateFunction(Function):
|
|
|
61
69
|
def set_conn(self, conn: Optional[sql.engine.Connection]) -> None:
|
|
62
70
|
self.conn = conn
|
|
63
71
|
|
|
64
|
-
|
|
65
|
-
|
|
72
|
+
@property
|
|
73
|
+
def is_async(self) -> bool:
|
|
74
|
+
return True
|
|
75
|
+
|
|
76
|
+
async def aexec(self, *args: Any, **kwargs: Any) -> Any:
|
|
77
|
+
#assert not self.is_polymorphic
|
|
66
78
|
bound_args = self.signature.py_signature.bind(*args, **kwargs).arguments
|
|
67
79
|
# apply defaults, otherwise we might have Parameters left over
|
|
68
80
|
bound_args.update(
|
|
69
81
|
{param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
|
|
70
82
|
bound_df = self.template_df.bind(bound_args)
|
|
71
|
-
result = bound_df.
|
|
83
|
+
result = await bound_df._acollect(self.conn)
|
|
72
84
|
return list(result)
|
|
73
85
|
|
|
74
86
|
@property
|
|
@@ -86,3 +98,37 @@ class QueryTemplateFunction(Function):
|
|
|
86
98
|
def _from_dict(cls, d: dict) -> Function:
|
|
87
99
|
from pixeltable.dataframe import DataFrame
|
|
88
100
|
return cls(DataFrame.from_dict(d['df']), Signature.from_dict(d['signature']), name=d['name'])
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@overload
|
|
104
|
+
def query(py_fn: Callable) -> QueryTemplateFunction: ...
|
|
105
|
+
|
|
106
|
+
@overload
|
|
107
|
+
def query(
|
|
108
|
+
*,
|
|
109
|
+
param_types: Optional[list[ts.ColumnType]] = None
|
|
110
|
+
) -> Callable[[Callable], QueryTemplateFunction]: ...
|
|
111
|
+
|
|
112
|
+
def query(*args: Any, **kwargs: Any) -> Any:
|
|
113
|
+
def make_query_template(
|
|
114
|
+
py_fn: Callable, param_types: Optional[list[ts.ColumnType]]
|
|
115
|
+
) -> QueryTemplateFunction:
|
|
116
|
+
if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
|
|
117
|
+
# this is a named function in a module
|
|
118
|
+
function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
|
|
119
|
+
else:
|
|
120
|
+
function_path = None
|
|
121
|
+
query_name = py_fn.__name__
|
|
122
|
+
query_fn = QueryTemplateFunction.create(
|
|
123
|
+
py_fn, param_types=param_types, path=function_path, name=query_name)
|
|
124
|
+
return query_fn
|
|
125
|
+
|
|
126
|
+
# TODO: verify that the inferred return type matches that of the template
|
|
127
|
+
# TODO: verify that the signature doesn't contain batched parameters
|
|
128
|
+
|
|
129
|
+
if len(args) == 1:
|
|
130
|
+
assert len(kwargs) == 0 and callable(args[0])
|
|
131
|
+
return make_query_template(args[0], None)
|
|
132
|
+
else:
|
|
133
|
+
assert len(args) == 0 and len(kwargs) == 1 and 'param_types' in kwargs
|
|
134
|
+
return lambda py_fn: make_query_template(py_fn, kwargs['param_types'])
|
pixeltable/func/tools.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
from
|
|
2
|
-
import dataclasses
|
|
3
|
-
import json
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union
|
|
5
2
|
|
|
6
3
|
import pydantic
|
|
7
4
|
|
|
5
|
+
import pixeltable.exceptions as excs
|
|
6
|
+
|
|
8
7
|
from .function import Function
|
|
9
8
|
from .signature import Parameter
|
|
10
9
|
from .udf import udf
|
|
@@ -75,6 +74,13 @@ class Tool(pydantic.BaseModel):
|
|
|
75
74
|
assert False
|
|
76
75
|
|
|
77
76
|
|
|
77
|
+
class ToolChoice(pydantic.BaseModel):
|
|
78
|
+
auto: bool
|
|
79
|
+
required: bool
|
|
80
|
+
tool: Optional[str]
|
|
81
|
+
parallel_tool_calls: bool
|
|
82
|
+
|
|
83
|
+
|
|
78
84
|
class Tools(pydantic.BaseModel):
|
|
79
85
|
tools: list[Tool]
|
|
80
86
|
|
|
@@ -92,25 +98,56 @@ class Tools(pydantic.BaseModel):
|
|
|
92
98
|
for tool in self.tools
|
|
93
99
|
})
|
|
94
100
|
|
|
101
|
+
def choice(
|
|
102
|
+
self,
|
|
103
|
+
auto: bool = False,
|
|
104
|
+
required: bool = False,
|
|
105
|
+
tool: Union[str, Function, None] = None,
|
|
106
|
+
parallel_tool_calls: bool = True,
|
|
107
|
+
) -> ToolChoice:
|
|
108
|
+
if sum([auto, required, tool is not None]) != 1:
|
|
109
|
+
raise excs.Error('Exactly one of `auto`, `required`, or `tool` must be specified.')
|
|
110
|
+
tool_name: Optional[str] = None
|
|
111
|
+
if tool is not None:
|
|
112
|
+
try:
|
|
113
|
+
tool_obj = next(
|
|
114
|
+
t for t in self.tools
|
|
115
|
+
if (isinstance(tool, Function) and t.fn == tool)
|
|
116
|
+
or (isinstance(tool, str) and (t.name or t.fn.name) == tool)
|
|
117
|
+
)
|
|
118
|
+
tool_name = tool_obj.name or tool_obj.fn.name
|
|
119
|
+
except StopIteration:
|
|
120
|
+
raise excs.Error(f'That tool is not in the specified list of tools: {tool}')
|
|
121
|
+
return ToolChoice(auto=auto, required=required, tool=tool_name, parallel_tool_calls=parallel_tool_calls)
|
|
122
|
+
|
|
95
123
|
|
|
96
124
|
@udf
|
|
97
|
-
def _extract_str_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[str]:
|
|
98
|
-
return
|
|
125
|
+
def _extract_str_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[str]:
|
|
126
|
+
return _extract_arg(str, tool_calls, func_name, param_name)
|
|
127
|
+
|
|
99
128
|
|
|
100
129
|
@udf
|
|
101
|
-
def _extract_int_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[int]:
|
|
102
|
-
return
|
|
130
|
+
def _extract_int_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[int]:
|
|
131
|
+
return _extract_arg(int, tool_calls, func_name, param_name)
|
|
132
|
+
|
|
103
133
|
|
|
104
134
|
@udf
|
|
105
|
-
def _extract_float_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[float]:
|
|
106
|
-
return
|
|
135
|
+
def _extract_float_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[float]:
|
|
136
|
+
return _extract_arg(float, tool_calls, func_name, param_name)
|
|
137
|
+
|
|
107
138
|
|
|
108
139
|
@udf
|
|
109
|
-
def _extract_bool_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[bool]:
|
|
110
|
-
return
|
|
140
|
+
def _extract_bool_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[bool]:
|
|
141
|
+
return _extract_arg(bool, tool_calls, func_name, param_name)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
T = TypeVar('T')
|
|
145
|
+
|
|
111
146
|
|
|
112
|
-
def _extract_arg(tool_calls: dict, func_name: str, param_name: str) ->
|
|
147
|
+
def _extract_arg(eval_fn: Callable[[Any], T], tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[T]:
|
|
113
148
|
if func_name in tool_calls:
|
|
114
149
|
arguments = tool_calls[func_name]['args']
|
|
115
|
-
|
|
150
|
+
if param_name in arguments:
|
|
151
|
+
return eval_fn(arguments[param_name])
|
|
152
|
+
return None
|
|
116
153
|
return None
|
pixeltable/func/udf.py
CHANGED
|
@@ -15,7 +15,7 @@ from .signature import Signature
|
|
|
15
15
|
|
|
16
16
|
# Decorator invoked without parentheses: @pxt.udf
|
|
17
17
|
@overload
|
|
18
|
-
def udf(decorated_fn: Callable) ->
|
|
18
|
+
def udf(decorated_fn: Callable) -> CallableFunction: ...
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
# Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
|
|
@@ -26,6 +26,7 @@ def udf(
|
|
|
26
26
|
substitute_fn: Optional[Callable] = None,
|
|
27
27
|
is_method: bool = False,
|
|
28
28
|
is_property: bool = False,
|
|
29
|
+
resource_pool: Optional[str] = None,
|
|
29
30
|
type_substitutions: Optional[Sequence[dict]] = None,
|
|
30
31
|
_force_stored: bool = False
|
|
31
32
|
) -> Callable[[Callable], CallableFunction]: ...
|
|
@@ -53,6 +54,7 @@ def udf(*args, **kwargs):
|
|
|
53
54
|
substitute_fn = kwargs.pop('substitute_fn', None)
|
|
54
55
|
is_method = kwargs.pop('is_method', None)
|
|
55
56
|
is_property = kwargs.pop('is_property', None)
|
|
57
|
+
resource_pool = kwargs.pop('resource_pool', None)
|
|
56
58
|
type_substitutions = kwargs.pop('type_substitutions', None)
|
|
57
59
|
force_stored = kwargs.pop('_force_stored', False)
|
|
58
60
|
if len(kwargs) > 0:
|
|
@@ -67,6 +69,7 @@ def udf(*args, **kwargs):
|
|
|
67
69
|
substitute_fn=substitute_fn,
|
|
68
70
|
is_method=is_method,
|
|
69
71
|
is_property=is_property,
|
|
72
|
+
resource_pool=resource_pool,
|
|
70
73
|
type_substitutions=type_substitutions,
|
|
71
74
|
force_stored=force_stored
|
|
72
75
|
)
|
|
@@ -82,6 +85,7 @@ def make_function(
|
|
|
82
85
|
substitute_fn: Optional[Callable] = None,
|
|
83
86
|
is_method: bool = False,
|
|
84
87
|
is_property: bool = False,
|
|
88
|
+
resource_pool: Optional[str] = None,
|
|
85
89
|
type_substitutions: Optional[Sequence[dict]] = None,
|
|
86
90
|
function_name: Optional[str] = None,
|
|
87
91
|
force_stored: bool = False
|
|
@@ -162,6 +166,8 @@ def make_function(
|
|
|
162
166
|
is_method=is_method,
|
|
163
167
|
is_property=is_property
|
|
164
168
|
)
|
|
169
|
+
if resource_pool is not None:
|
|
170
|
+
result.resource_pool(lambda: resource_pool)
|
|
165
171
|
|
|
166
172
|
# If this function is part of a module, register it
|
|
167
173
|
if function_path is not None:
|
pixeltable/functions/__init__.py
CHANGED
|
@@ -2,7 +2,7 @@ from pixeltable.utils.code import local_public_names
|
|
|
2
2
|
|
|
3
3
|
from . import (anthropic, audio, fireworks, gemini, huggingface, image, json, llama_cpp, math, mistralai, ollama,
|
|
4
4
|
openai, string, timestamp, together, video, vision, whisper)
|
|
5
|
-
from .globals import
|
|
5
|
+
from .globals import count, max, mean, min, sum
|
|
6
6
|
|
|
7
7
|
__all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
|
|
8
8
|
|
|
@@ -5,9 +5,12 @@ first `pip install anthropic` and configure your Anthropic credentials, as descr
|
|
|
5
5
|
the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anthropic) tutorial.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
import datetime
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast, Iterable
|
|
9
12
|
|
|
10
|
-
import
|
|
13
|
+
import httpx
|
|
11
14
|
|
|
12
15
|
import pixeltable as pxt
|
|
13
16
|
from pixeltable import env, exprs
|
|
@@ -17,28 +20,54 @@ from pixeltable.utils.code import local_public_names
|
|
|
17
20
|
if TYPE_CHECKING:
|
|
18
21
|
import anthropic
|
|
19
22
|
|
|
23
|
+
_logger = logging.getLogger('pixeltable')
|
|
20
24
|
|
|
21
25
|
@env.register_client('anthropic')
|
|
22
|
-
def _(api_key: str) -> 'anthropic.
|
|
26
|
+
def _(api_key: str) -> 'anthropic.AsyncAnthropic':
|
|
23
27
|
import anthropic
|
|
24
|
-
return anthropic.
|
|
28
|
+
return anthropic.AsyncAnthropic(
|
|
29
|
+
api_key=api_key,
|
|
30
|
+
# recommended to increase limits for async client to avoid connection errors
|
|
31
|
+
http_client = httpx.AsyncClient(limits=httpx.Limits(max_keepalive_connections=100, max_connections=500)))
|
|
25
32
|
|
|
26
33
|
|
|
27
|
-
def _anthropic_client() -> 'anthropic.
|
|
34
|
+
def _anthropic_client() -> 'anthropic.AsyncAnthropic':
|
|
28
35
|
return env.Env.get().get_client('anthropic')
|
|
29
36
|
|
|
30
37
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
+
class AnthropicRateLimitsInfo(env.RateLimitsInfo):
|
|
39
|
+
|
|
40
|
+
def __init__(self):
|
|
41
|
+
super().__init__(self._get_request_resources)
|
|
42
|
+
|
|
43
|
+
def _get_request_resources(self, messages: dict, max_tokens: int) -> dict[str, int]:
|
|
44
|
+
input_len = 0
|
|
45
|
+
for message in messages:
|
|
46
|
+
if 'role' in message:
|
|
47
|
+
input_len += len(message['role'])
|
|
48
|
+
if 'content' in message:
|
|
49
|
+
input_len += len(message['content'])
|
|
50
|
+
return {'requests': 1, 'input_tokens': int(input_len / 4), 'output_tokens': max_tokens}
|
|
51
|
+
|
|
52
|
+
def get_retry_delay(self, exc: Exception) -> Optional[float]:
|
|
53
|
+
import anthropic
|
|
54
|
+
|
|
55
|
+
# deal with timeouts separately, they don't come with headers
|
|
56
|
+
if isinstance(exc, anthropic.APITimeoutError):
|
|
57
|
+
return 1.0
|
|
58
|
+
|
|
59
|
+
if not isinstance(exc, anthropic.APIStatusError):
|
|
60
|
+
return None
|
|
61
|
+
_logger.debug(f'headers={exc.response.headers}')
|
|
62
|
+
should_retry_str = exc.response.headers.get('x-should-retry', '')
|
|
63
|
+
if should_retry_str.lower() != 'true':
|
|
64
|
+
return None
|
|
65
|
+
retry_after_str = exc.response.headers.get('retry-after', '1')
|
|
66
|
+
return int(retry_after_str)
|
|
38
67
|
|
|
39
68
|
|
|
40
69
|
@pxt.udf
|
|
41
|
-
def messages(
|
|
70
|
+
async def messages(
|
|
42
71
|
messages: list[dict[str, str]],
|
|
43
72
|
*,
|
|
44
73
|
model: str,
|
|
@@ -47,7 +76,7 @@ def messages(
|
|
|
47
76
|
stop_sequences: Optional[list[str]] = None,
|
|
48
77
|
system: Optional[str] = None,
|
|
49
78
|
temperature: Optional[float] = None,
|
|
50
|
-
tool_choice: Optional[
|
|
79
|
+
tool_choice: Optional[dict] = None,
|
|
51
80
|
tools: Optional[list[dict]] = None,
|
|
52
81
|
top_k: Optional[int] = None,
|
|
53
82
|
top_p: Optional[float] = None,
|
|
@@ -78,6 +107,9 @@ def messages(
|
|
|
78
107
|
>>> msgs = [{'role': 'user', 'content': tbl.prompt}]
|
|
79
108
|
... tbl['response'] = messages(msgs, model='claude-3-haiku-20240307')
|
|
80
109
|
"""
|
|
110
|
+
|
|
111
|
+
# it doesn't look like count_tokens() actually exists in the current version of the library
|
|
112
|
+
|
|
81
113
|
if tools is not None:
|
|
82
114
|
# Reformat `tools` into Anthropic format
|
|
83
115
|
tools = [
|
|
@@ -93,19 +125,74 @@ def messages(
|
|
|
93
125
|
for tool in tools
|
|
94
126
|
]
|
|
95
127
|
|
|
96
|
-
|
|
97
|
-
|
|
128
|
+
tool_choice_: Optional[dict] = None
|
|
129
|
+
if tool_choice is not None:
|
|
130
|
+
if tool_choice['auto']:
|
|
131
|
+
tool_choice_ = {'type': 'auto'}
|
|
132
|
+
elif tool_choice['required']:
|
|
133
|
+
tool_choice_ = {'type': 'any'}
|
|
134
|
+
else:
|
|
135
|
+
assert tool_choice['tool'] is not None
|
|
136
|
+
tool_choice_ = {'type': 'tool', 'name': tool_choice['tool']}
|
|
137
|
+
if not tool_choice['parallel_tool_calls']:
|
|
138
|
+
tool_choice_['disable_parallel_tool_use'] = True
|
|
139
|
+
|
|
140
|
+
# TODO: timeouts should be set system-wide and be user-configurable
|
|
141
|
+
from anthropic.types import MessageParam
|
|
142
|
+
|
|
143
|
+
# cast(Any, ...): avoid mypy errors
|
|
144
|
+
result = await _anthropic_client().messages.with_raw_response.create(
|
|
145
|
+
messages=cast(Iterable[MessageParam], messages),
|
|
98
146
|
model=model,
|
|
99
147
|
max_tokens=max_tokens,
|
|
100
|
-
metadata=_opt(metadata),
|
|
148
|
+
metadata=_opt(cast(Any, metadata)),
|
|
101
149
|
stop_sequences=_opt(stop_sequences),
|
|
102
150
|
system=_opt(system),
|
|
103
|
-
temperature=_opt(temperature),
|
|
104
|
-
|
|
105
|
-
|
|
151
|
+
temperature=_opt(cast(Any, temperature)),
|
|
152
|
+
tools=_opt(cast(Any, tools)),
|
|
153
|
+
tool_choice=_opt(cast(Any, tool_choice_)),
|
|
106
154
|
top_k=_opt(top_k),
|
|
107
155
|
top_p=_opt(top_p),
|
|
108
|
-
|
|
156
|
+
timeout=10,
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
requests_limit_str = result.headers.get('anthropic-ratelimit-requests-limit')
|
|
160
|
+
requests_limit = int(requests_limit_str) if requests_limit_str is not None else None
|
|
161
|
+
requests_remaining_str = result.headers.get('anthropic-ratelimit-requests-remaining')
|
|
162
|
+
requests_remaining = int(requests_remaining_str) if requests_remaining_str is not None else None
|
|
163
|
+
requests_reset_str = result.headers.get('anthropic-ratelimit-requests-reset')
|
|
164
|
+
requests_reset = datetime.datetime.fromisoformat(requests_reset_str.replace('Z', '+00:00'))
|
|
165
|
+
input_tokens_limit_str = result.headers.get('anthropic-ratelimit-input-tokens-limit')
|
|
166
|
+
input_tokens_limit = int(input_tokens_limit_str) if input_tokens_limit_str is not None else None
|
|
167
|
+
input_tokens_remaining_str = result.headers.get('anthropic-ratelimit-input-tokens-remaining')
|
|
168
|
+
input_tokens_remaining = int(input_tokens_remaining_str) if input_tokens_remaining_str is not None else None
|
|
169
|
+
input_tokens_reset_str = result.headers.get('anthropic-ratelimit-input-tokens-reset')
|
|
170
|
+
input_tokens_reset = datetime.datetime.fromisoformat(input_tokens_reset_str.replace('Z', '+00:00'))
|
|
171
|
+
output_tokens_limit_str = result.headers.get('anthropic-ratelimit-output-tokens-limit')
|
|
172
|
+
output_tokens_limit = int(output_tokens_limit_str) if output_tokens_limit_str is not None else None
|
|
173
|
+
output_tokens_remaining_str = result.headers.get('anthropic-ratelimit-output-tokens-remaining')
|
|
174
|
+
output_tokens_remaining = int(output_tokens_remaining_str) if output_tokens_remaining_str is not None else None
|
|
175
|
+
output_tokens_reset_str = result.headers.get('anthropic-ratelimit-output-tokens-reset')
|
|
176
|
+
output_tokens_reset = datetime.datetime.fromisoformat(output_tokens_reset_str.replace('Z', '+00:00'))
|
|
177
|
+
retry_after_str = result.headers.get('retry-after')
|
|
178
|
+
if retry_after_str is not None:
|
|
179
|
+
_logger.debug(f'retry-after: {retry_after_str}')
|
|
180
|
+
|
|
181
|
+
resource_pool_id = f'rate-limits:anthropic:{model}'
|
|
182
|
+
rate_limits_info = env.Env.get().get_resource_pool_info(resource_pool_id, AnthropicRateLimitsInfo)
|
|
183
|
+
assert isinstance(rate_limits_info, env.RateLimitsInfo)
|
|
184
|
+
rate_limits_info.record(
|
|
185
|
+
requests=(requests_limit, requests_remaining, requests_reset),
|
|
186
|
+
input_tokens=(input_tokens_limit, input_tokens_remaining, input_tokens_reset),
|
|
187
|
+
output_tokens=(output_tokens_limit, output_tokens_remaining, output_tokens_reset))
|
|
188
|
+
|
|
189
|
+
result_dict = json.loads(result.text)
|
|
190
|
+
return result_dict
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@messages.resource_pool
|
|
194
|
+
def _(model: str) -> str:
|
|
195
|
+
return f'rate-limits:anthropic:{model}'
|
|
109
196
|
|
|
110
197
|
|
|
111
198
|
def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
|
pixeltable/functions/gemini.py
CHANGED
|
@@ -13,7 +13,7 @@ from pixeltable import env
|
|
|
13
13
|
|
|
14
14
|
@env.register_client('gemini')
|
|
15
15
|
def _(api_key: str) -> None:
|
|
16
|
-
import google.generativeai as genai
|
|
16
|
+
import google.generativeai as genai
|
|
17
17
|
genai.configure(api_key=api_key)
|
|
18
18
|
|
|
19
19
|
|
|
@@ -36,8 +36,6 @@ def generate_content(
|
|
|
36
36
|
response_schema: Optional[dict] = None,
|
|
37
37
|
presence_penalty: Optional[float] = None,
|
|
38
38
|
frequency_penalty: Optional[float] = None,
|
|
39
|
-
response_logprobs: Optional[bool] = None,
|
|
40
|
-
logprobs: Optional[int] = None,
|
|
41
39
|
) -> dict:
|
|
42
40
|
"""
|
|
43
41
|
Generate content from the specified model. For additional details, see:
|
|
@@ -60,7 +58,7 @@ def generate_content(
|
|
|
60
58
|
Add a computed column that applies the model `gemini-1.5-flash`
|
|
61
59
|
to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
|
|
62
60
|
|
|
63
|
-
>>> tbl
|
|
61
|
+
>>> tbl.add_computed_column(response=generate_content(tbl.prompt, model_name='gemini-1.5-flash'))
|
|
64
62
|
"""
|
|
65
63
|
env.Env.get().require_package('google.generativeai')
|
|
66
64
|
_ensure_loaded()
|
|
@@ -78,8 +76,6 @@ def generate_content(
|
|
|
78
76
|
response_schema=response_schema,
|
|
79
77
|
presence_penalty=presence_penalty,
|
|
80
78
|
frequency_penalty=frequency_penalty,
|
|
81
|
-
response_logprobs=response_logprobs,
|
|
82
|
-
logprobs=logprobs,
|
|
83
79
|
)
|
|
84
80
|
response = model.generate_content(contents, generation_config=gc)
|
|
85
81
|
return response.to_dict()
|