chatlas 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +79 -45
- chatlas/_auto.py +3 -12
- chatlas/_chat.py +800 -169
- chatlas/_content.py +149 -29
- chatlas/_databricks.py +4 -14
- chatlas/_github.py +21 -25
- chatlas/_google.py +71 -32
- chatlas/_groq.py +15 -18
- chatlas/_interpolate.py +3 -4
- chatlas/_mcp_manager.py +306 -0
- chatlas/_ollama.py +14 -18
- chatlas/_openai.py +74 -39
- chatlas/_perplexity.py +14 -18
- chatlas/_provider.py +78 -8
- chatlas/_snowflake.py +29 -18
- chatlas/_tokens.py +93 -5
- chatlas/_tools.py +181 -22
- chatlas/_turn.py +2 -18
- chatlas/_utils.py +27 -1
- chatlas/_version.py +2 -2
- chatlas/data/prices.json +264 -0
- chatlas/types/anthropic/_submit.py +2 -0
- chatlas/types/openai/_client.py +1 -0
- chatlas/types/openai/_client_azure.py +1 -0
- chatlas/types/openai/_submit.py +4 -1
- chatlas-0.9.0.dist-info/METADATA +141 -0
- chatlas-0.9.0.dist-info/RECORD +48 -0
- chatlas-0.8.0.dist-info/METADATA +0 -383
- chatlas-0.8.0.dist-info/RECORD +0 -46
- {chatlas-0.8.0.dist-info → chatlas-0.9.0.dist-info}/WHEEL +0 -0
- {chatlas-0.8.0.dist-info → chatlas-0.9.0.dist-info}/licenses/LICENSE +0 -0
chatlas/_provider.py
CHANGED

@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from typing import (
-    Any,
     AsyncIterable,
     Generic,
     Iterable,
@@ -17,6 +16,7 @@ from pydantic import BaseModel
 from ._content import Content
 from ._tools import Tool
 from ._turn import Turn
+from ._typing_extensions import TypedDict
 
 ChatCompletionT = TypeVar("ChatCompletionT")
 ChatCompletionChunkT = TypeVar("ChatCompletionChunkT")
@@ -24,8 +24,52 @@ ChatCompletionChunkT = TypeVar("ChatCompletionChunkT")
 ChatCompletionDictT = TypeVar("ChatCompletionDictT")
 
 
+class AnyTypeDict(TypedDict, total=False):
+    pass
+
+
+SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
+"""
+A TypedDict representing the provider specific arguments that can specified when
+submitting input to a model provider.
+"""
+
+
+class StandardModelParams(TypedDict, total=False):
+    """
+    A TypedDict representing the standard model parameters that can be set
+    when using a [](`~chatlas.Chat`) instance.
+    """
+
+    temperature: float
+    top_p: float
+    top_k: int
+    frequency_penalty: float
+    presence_penalty: float
+    seed: int
+    max_tokens: int
+    log_probs: bool
+    stop_sequences: list[str]
+
+
+StandardModelParamNames = Literal[
+    "temperature",
+    "top_p",
+    "top_k",
+    "frequency_penalty",
+    "presence_penalty",
+    "seed",
+    "max_tokens",
+    "log_probs",
+    "stop_sequences",
+]
+
+
 class Provider(
-    ABC,
+    ABC,
+    Generic[
+        ChatCompletionT, ChatCompletionChunkT, ChatCompletionDictT, SubmitInputArgsT
+    ],
 ):
     """
     A model provider interface for a [](`~chatlas.Chat`).
@@ -40,6 +84,24 @@ class Provider(
     directly.
     """
 
+    def __init__(self, *, name: str, model: str):
+        self._name = name
+        self._model = model
+
+    @property
+    def name(self):
+        """
+        Get the name of the provider
+        """
+        return self._name
+
+    @property
+    def model(self):
+        """
+        Get the model used by the provider
+        """
+        return self._model
+
     @overload
     @abstractmethod
     def chat_perform(
@@ -49,7 +111,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> ChatCompletionT: ...
 
     @overload
@@ -61,7 +123,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> Iterable[ChatCompletionChunkT]: ...
 
     @abstractmethod
@@ -72,7 +134,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
 
     @overload
@@ -84,7 +146,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> ChatCompletionT: ...
 
     @overload
@@ -96,7 +158,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> AsyncIterable[ChatCompletionChunkT]: ...
 
     @abstractmethod
@@ -107,7 +169,7 @@ class Provider(
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
-        kwargs: Any,
+        kwargs: SubmitInputArgsT,
     ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
 
     @abstractmethod
@@ -149,3 +211,11 @@ class Provider(
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]],
     ) -> int: ...
+
+    @abstractmethod
+    def translate_model_params(
+        self, params: StandardModelParams
+    ) -> SubmitInputArgsT: ...
+
+    @abstractmethod
+    def supported_model_params(self) -> set[StandardModelParamNames]: ...
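Taken together, a concrete provider in 0.9.0 now (1) parameterizes `Provider` over a fourth `SubmitInputArgsT` TypedDict, (2) passes `name`/`model` up to the new base `__init__`, and (3) implements `translate_model_params()` and `supported_model_params()`. A minimal sketch of that contract (not from the package): `EchoSubmitArgs` and `EchoProvider` are hypothetical names, and the pre-existing abstract methods (`chat_perform()` and friends) are omitted, so the class stays abstract and only illustrates the new pieces.

# Hypothetical provider sketch; only the 0.9.0 additions are shown.
from typing import TypedDict

from chatlas._provider import (
    Provider,
    StandardModelParamNames,
    StandardModelParams,
)


class EchoSubmitArgs(TypedDict, total=False):
    # Provider-specific kwargs accepted at submit time (hypothetical).
    temperature: float
    max_tokens: int


class EchoProvider(Provider[object, object, dict, EchoSubmitArgs]):
    def __init__(self):
        # New in 0.9.0: the base class stores name/model and exposes them
        # via the read-only .name and .model properties.
        super().__init__(name="Echo", model="echo-1")

    def translate_model_params(self, params: StandardModelParams) -> EchoSubmitArgs:
        # Map the provider-agnostic params onto this provider's submit kwargs,
        # keeping only what the backend understands.
        res: EchoSubmitArgs = {}
        if "temperature" in params:
            res["temperature"] = params["temperature"]
        if "max_tokens" in params:
            res["max_tokens"] = params["max_tokens"]
        return res

    def supported_model_params(self) -> set[StandardModelParamNames]:
        return {"temperature", "max_tokens"}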
chatlas/_snowflake.py
CHANGED

@@ -20,10 +20,10 @@ from ._content import (
     ContentToolResult,
 )
 from ._logging import log_model_default
-from ._provider import Provider
+from ._provider import Provider, StandardModelParamNames, StandardModelParams
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn
+from ._turn import Turn
 from ._utils import drop_none
 
 if TYPE_CHECKING:
@@ -61,7 +61,6 @@ def ChatSnowflake(
     *,
     system_prompt: Optional[str] = None,
     model: Optional[str] = None,
-    turns: Optional[list[Turn]] = None,
     connection_name: Optional[str] = None,
     account: Optional[str] = None,
     user: Optional[str] = None,
@@ -111,13 +110,6 @@ def ChatSnowflake(
         The model to use for the chat. The default, None, will pick a reasonable
         default, and warn you about it. We strongly recommend explicitly
         choosing a model for all but the most casual use.
-    turns
-        A list of turns to start the chat with (i.e., continuing a previous
-        conversation). If not provided, the conversation begins from scratch. Do
-        not provide non-None values for both `turns` and `system_prompt`. Each
-        message in the list should be a dictionary with at least `role` (usually
-        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
-        there is also a `content` field, which is a string.
     connection_name
         The name of the connection (i.e., section) within the connections.toml file.
         This is useful if you want to keep your credentials in a connections.toml file
@@ -157,14 +149,13 @@ def ChatSnowflake(
             private_key_file_pwd=private_key_file_pwd,
             kwargs=kwargs,
         ),
-
-            turns or [],
-            system_prompt,
-        ),
+        system_prompt=system_prompt,
     )
 
 
-class SnowflakeProvider(
+class SnowflakeProvider(
+    Provider["Completion", "CompletionChunk", "CompletionChunk", "CompleteRequest"]
+):
     def __init__(
         self,
         *,
@@ -175,6 +166,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         password: Optional[str],
         private_key_file: Optional[str],
         private_key_file_pwd: Optional[str],
+        name: str = "Snowflake",
         kwargs: Optional[dict[str, "str | int"]],
     ):
         try:
@@ -185,6 +177,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
                 "`ChatSnowflake()` requires the `snowflake-ml-python` package. "
                 "Please install it via `pip install snowflake-ml-python`."
             )
+        super().__init__(name=name, model=model)
 
         configs: dict[str, str | int] = drop_none(
             {
@@ -198,8 +191,6 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
             }
         )
 
-        self._model = model
-
         session = Session.builder.configs(configs).create()
         self._cortex_service = Root(session).cortex_inference_service
 
@@ -314,7 +305,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
         from snowflake.core.cortex.inference_service import CompleteRequest
 
         req = CompleteRequest(
-            model=self._model,
+            model=self.model,
             messages=self._as_request_messages(turns),
             stream=stream,
         )
@@ -599,6 +590,26 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChu
 
         return models.Tool(tool_spec=spec)
 
+    def translate_model_params(self, params: StandardModelParams) -> "CompleteRequest":
+        res: "CompleteRequest" = {}
+        if "temperature" in params:
+            res["temperature"] = params["temperature"]
+
+        if "top_p" in params:
+            res["top_p"] = params["top_p"]
+
+        if "max_tokens" in params:
+            res["max_tokens"] = params["max_tokens"]
+
+        return res
+
+    def supported_model_params(self) -> set[StandardModelParamNames]:
+        return {
+            "temperature",
+            "top_p",
+            "max_tokens",
+        }
+
 
 # Yield parsed event data from the Snowflake SSEClient
 # (this is only needed for the streaming case).
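In practice the new pieces compose like this. A hedged sketch only: the model id and connection name below are placeholders, and it assumes the chat object exposes its provider as `chat.provider`.

from chatlas import ChatSnowflake

# Note the `turns=` argument to ChatSnowflake() is gone in 0.9.0.
chat = ChatSnowflake(
    model="claude-3-5-sonnet",   # placeholder; choose your Cortex model explicitly
    connection_name="default",   # placeholder connections.toml section
    system_prompt="You are a terse assistant.",
)

provider = chat.provider
print(provider.name, provider.model)  # Snowflake claude-3-5-sonnet

# Only temperature/top_p/max_tokens appear in supported_model_params(), so
# anything else (seed, top_k, ...) is dropped by the translation:
print(provider.translate_model_params(
    {"temperature": 0.2, "top_p": 0.9, "max_tokens": 256, "seed": 42}
))
# {'temperature': 0.2, 'top_p': 0.9, 'max_tokens': 256}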
chatlas/_tokens.py
CHANGED

@@ -1,9 +1,13 @@
 from __future__ import annotations
 
 import copy
+import importlib.resources as resources
+import warnings
 from threading import Lock
 from typing import TYPE_CHECKING
 
+import orjson
+
 from ._logging import logger
 from ._typing_extensions import TypedDict
 
@@ -17,8 +21,10 @@ class TokenUsage(TypedDict):
     """
 
     name: str
+    model: str
     input: int
     output: int
+    cost: float | None
 
 
 class ThreadSafeTokenCounter:
@@ -26,7 +32,9 @@
         self._lock = Lock()
         self._tokens: dict[str, TokenUsage] = {}
 
-    def log_tokens(
+    def log_tokens(
+        self, name: str, model: str, input_tokens: int, output_tokens: int
+    ) -> None:
         logger.info(
             f"Provider '{name}' generated a response of {output_tokens} tokens "
             f"from an input of {input_tokens} tokens."
@@ -36,12 +44,21 @@
         if name not in self._tokens:
             self._tokens[name] = {
                 "name": name,
+                "model": model,
                 "input": input_tokens,
                 "output": output_tokens,
+                "cost": compute_price(name, model, input_tokens, output_tokens),
             }
         else:
             self._tokens[name]["input"] += input_tokens
             self._tokens[name]["output"] += output_tokens
+            price = compute_price(name, model, input_tokens, output_tokens)
+            if price is not None:
+                cost = self._tokens[name]["cost"]
+                if cost is None:
+                    self._tokens[name]["cost"] = price
+                else:
+                    self._tokens[name]["cost"] = cost + price
 
     def get_usage(self) -> list[TokenUsage] | None:
         with self._lock:
@@ -59,8 +76,7 @@ def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None:
     """
     Log token usage for a provider in a thread-safe manner.
     """
-    name
-    _token_counter.log_tokens(name, tokens[0], tokens[1])
+    _token_counter.log_tokens(provider.name, provider.model, tokens[0], tokens[1])
 
 
 def tokens_reset() -> None:
@@ -71,17 +87,89 @@
     _token_counter = ThreadSafeTokenCounter()
 
 
+class TokenPrice(TypedDict):
+    """
+    Defines the necessary information to look up pricing for a given turn.
+    """
+
+    provider: str
+    """The provider name (e.g., "OpenAI", "Anthropic", etc.)"""
+    model: str
+    """The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.)"""
+    cached_input: float
+    """The cost per user token in USD per million tokens for cached input"""
+    input: float
+    """The cost per user token in USD per million tokens"""
+    output: float
+    """The cost per assistant token in USD per million tokens"""
+
+
+# Load in pricing pulled from ellmer
+f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8")
+pricing_list: list[TokenPrice] = orjson.loads(f)
+
+
+def get_token_pricing(name: str, model: str) -> TokenPrice | None:
+    """
+    Get token pricing information given a provider name and model
+
+    Note
+    ----
+    Only a subset of providers and models and currently supported.
+    The pricing information derives from ellmer.
+
+    Returns
+    -------
+    TokenPrice | None
+    """
+    result = next(
+        (
+            item
+            for item in pricing_list
+            if item["provider"] == name and item["model"] == model
+        ),
+        None,
+    )
+    if result is None:
+        warnings.warn(
+            f"Token pricing for the provider '{name}' and model '{model}' you selected is not available. "
+            "Please check the provider's documentation."
+        )
+
+    return result
+
+
+def compute_price(
+    name: str, model: str, input_tokens: int, output_tokens: int
+) -> float | None:
+    """
+    Compute the cost of a turn.
+
+    Returns
+    -------
+    float | None
+        The cost of the turn in USD, or None if the cost could not be calculated.
+    """
+    price = get_token_pricing(name, model)
+    if price is None:
+        return None
+    input_price = input_tokens * (price["input"] / 1e6)
+    output_price = output_tokens * (price["output"] / 1e6)
+    return input_price + output_price
+
+
 def token_usage() -> list[TokenUsage] | None:
     """
     Report on token usage in the current session
 
     Call this function to find out the cumulative number of tokens that you
-    have sent and received in the current session.
+    have sent and received in the current session. The price will be shown if known
 
     Returns
     -------
     list[TokenUsage] | None
-        A list of dictionaries with the following keys: "name", "input", and "
+        A list of dictionaries with the following keys: "name", "input", "output", and "cost".
+        If no cost data is available for the name/model combination chosen, then "cost" will be None.
         If no tokens have been logged, then None is returned.
     """
     return _token_counter.get_usage()
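The cost arithmetic in compute_price() is plain per-million-token math. A worked example with invented rates (the real rates ship in chatlas/data/prices.json):

# Invented per-million-token rates, purely for illustration.
input_rate = 3.00    # USD per 1M input tokens
output_rate = 15.00  # USD per 1M output tokens

input_tokens, output_tokens = 12_000, 1_500
cost = input_tokens * (input_rate / 1e6) + output_tokens * (output_rate / 1e6)
print(f"${cost:.4f}")  # $0.0585

# The running totals (now including "model" and "cost") are surfaced via:
from chatlas import token_usage

print(token_usage())  # None until at least one request has been logged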
chatlas/_tools.py
CHANGED

@@ -2,11 +2,17 @@ from __future__ import annotations
 
 import inspect
 import warnings
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Optional
 
+import openai
 from pydantic import BaseModel, Field, create_model
 
 from . import _utils
+from ._content import (
+    ContentToolResult,
+    ContentToolResultImage,
+    ContentToolResultResource,
+)
 
 __all__ = (
     "Tool",
@@ -14,6 +20,8 @@ __all__ = (
 )
 
 if TYPE_CHECKING:
+    from mcp import ClientSession as MCPClientSession
+    from mcp import Tool as MCPTool
     from openai.types.chat import ChatCompletionToolParam
 
 
@@ -28,26 +36,168 @@ class Tool:
     ----------
     func
         The function to be invoked when the tool is called.
-
-
-
-
-
-
+    name
+        The name of the tool.
+    description
+        A description of what the tool does.
+    parameters
+        A dictionary describing the input parameters and their types.
     """
 
     func: Callable[..., Any] | Callable[..., Awaitable[Any]]
 
     def __init__(
         self,
-        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
         *,
-
+        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
+        name: str,
+        description: str,
+        parameters: dict[str, Any],
     ):
+        self.name = name
         self.func = func
         self._is_async = _utils.is_async_callable(func)
-        self.schema =
-
+        self.schema: "ChatCompletionToolParam" = {
+            "type": "function",
+            "function": {
+                "name": name,
+                "description": description,
+                "parameters": parameters,
+            },
+        }
+
+    @classmethod
+    def from_func(
+        cls: type["Tool"],
+        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
+        *,
+        model: Optional[type[BaseModel]] = None,
+    ) -> "Tool":
+        """
+        Create a Tool from a Python function
+
+        Parameters
+        ----------
+        func
+            The function to wrap as a tool.
+        model
+            A Pydantic model that describes the input parameters for the function.
+            If not provided, the model will be inferred from the function's type hints.
+            The primary reason why you might want to provide a model in
+            Note that the name and docstring of the model takes precedence over the
+            name and docstring of the function.
+
+        Returns
+        -------
+        Tool
+            A new Tool instance wrapping the provided function.
+
+        Raises
+        ------
+        ValueError
+            If there is a mismatch between model fields and function parameters.
+        """
+
+        if model is None:
+            model = func_to_basemodel(func)
+
+        # Throw if there is a mismatch between the model and the function parameters
+        params = inspect.signature(func).parameters
+        fields = model.model_fields
+        diff = set(params) ^ set(fields)
+        if diff:
+            raise ValueError(
+                f"`model` fields must match tool function parameters exactly. "
+                f"Fields found in one but not the other: {diff}"
+            )
+
+        params = basemodel_to_param_schema(model)
+
+        return cls(
+            func=func,
+            name=model.__name__ or func.__name__,
+            description=model.__doc__ or func.__doc__ or "",
+            parameters=params,
+        )
+
+    @classmethod
+    def from_mcp(
+        cls: type["Tool"],
+        session: "MCPClientSession",
+        mcp_tool: "MCPTool",
+    ) -> "Tool":
+        """
+        Create a Tool from an MCP tool
+
+        Parameters
+        ----------
+        session
+            The MCP client session to use for calling the tool.
+        mcp_tool
+            The MCP tool to wrap.
+
+        Returns
+        -------
+        Tool
+            A new Tool instance wrapping the MCP tool.
+        """
+
+        async def _call(**args: Any) -> AsyncGenerator[ContentToolResult, None]:
+            result = await session.call_tool(mcp_tool.name, args)
+
+            # Raise an error if the tool call resulted in an error. It doesn't seem to be
+            # very well defined how to get at the error message, but it appears that it gets
+            # stored in the `text` attribute of the content. Also, empirically, the error
+            # message seems to include `Error executing tool {tool_name}: ...`, so
+            if result.isError:
+                err_msg = getattr(
+                    result.content[0],
+                    "text",
+                    f"Error executing tool {mcp_tool.name}.",
+                )
+                raise RuntimeError(err_msg)
+
+            for content in result.content:
+                if content.type == "text":
+                    yield ContentToolResult(value=content.text)
+                elif content.type == "image":
+                    if content.mimeType not in (
+                        "image/png",
+                        "image/jpeg",
+                        "image/webp",
+                        "image/gif",
+                    ):
+                        raise ValueError(
+                            f"Unsupported image MIME type: {content.mimeType}"
+                        )
+
+                    yield ContentToolResultImage(
+                        value=content.data,
+                        mime_type=content.mimeType,
+                    )
+                elif content.type == "resource":
+                    from mcp.types import TextResourceContents
+
+                    resource = content.resource
+                    if isinstance(resource, TextResourceContents):
+                        blob = resource.text.encode("utf-8")
+                    else:
+                        blob = resource.blob.encode("utf-8")
+
+                    yield ContentToolResultResource(
+                        value=blob, mime_type=content.resource.mimeType
+                    )
+                else:
+                    raise RuntimeError(f"Unexpected content type: {content.type}")
+
+        params = mcp_tool_input_schema_to_param_schema(mcp_tool.inputSchema)
+
+        return cls(
+            func=_utils.wrap_async(_call),
+            name=mcp_tool.name,
+            description=mcp_tool.description or "",
+            parameters=params,
+        )
 
 
 class ToolRejectError(Exception):
@@ -160,14 +310,6 @@ def func_to_basemodel(func: Callable) -> type[BaseModel]:
 
 
 def basemodel_to_param_schema(model: type[BaseModel]) -> dict[str, object]:
-    try:
-        import openai
-    except ImportError:
-        raise ImportError(
-            "The openai package is required for this functionality. "
-            "Please install it with `pip install openai`."
-        )
-
     # Lean on openai's ability to translate BaseModel.model_json_schema()
     # to a valid tool schema (this wouldn't be impossible to do ourselves,
    # but it's fair amount of logic to substitute `$refs`, etc.)
@@ -177,10 +319,27 @@ def basemodel_to_param_schema(model: type[BaseModel]) -> dict[str, object]:
     if "parameters" not in fn:
         raise ValueError("Expected `parameters` in function definition.")
 
-    params = fn["parameters"]
+    params = rm_param_titles(fn["parameters"])
+
+    return params
+
+
+def mcp_tool_input_schema_to_param_schema(
+    input_schema: dict[str, Any],
+) -> dict[str, object]:
+    params = rm_param_titles(input_schema)
+
+    if "additionalProperties" not in params:
+        params["additionalProperties"] = False
+
+    return params
+
 
-
-
+def rm_param_titles(
+    params: dict[str, object],
+) -> dict[str, object]:
+    # For some reason, pydantic wants to include a title at the model and field
+    # level. I don't think we actually need or want this.
     if "title" in params:
         del params["title"]
 
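A short sketch of the reworked construction path: Tool.from_func() infers a Pydantic model from the function's type hints, validates that its fields match the function's parameters, and builds the name/description/parameters triple that now feeds Tool.__init__(). The example function below is invented, and the printed output is approximate.

from chatlas import Tool


def add(x: int, y: int) -> int:
    """Add two numbers."""
    return x + y


tool = Tool.from_func(add)
print(tool.name)  # "add"
print(tool.schema)
# Roughly: {"type": "function",
#           "function": {"name": "add",
#                        "description": "Add two numbers.",
#                        "parameters": {...JSON schema, "title" keys removed...}}}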