chatlas 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +79 -45
- chatlas/_auto.py +3 -12
- chatlas/_chat.py +774 -148
- chatlas/_content.py +149 -29
- chatlas/_databricks.py +4 -14
- chatlas/_github.py +21 -25
- chatlas/_google.py +71 -32
- chatlas/_groq.py +15 -18
- chatlas/_interpolate.py +3 -4
- chatlas/_mcp_manager.py +306 -0
- chatlas/_ollama.py +14 -18
- chatlas/_openai.py +74 -39
- chatlas/_perplexity.py +14 -18
- chatlas/_provider.py +78 -8
- chatlas/_snowflake.py +29 -18
- chatlas/_tokens.py +85 -5
- chatlas/_tools.py +181 -22
- chatlas/_turn.py +2 -18
- chatlas/_utils.py +27 -1
- chatlas/_version.py +2 -2
- chatlas/data/prices.json +264 -0
- chatlas/types/anthropic/_submit.py +2 -0
- chatlas/types/openai/_client.py +1 -0
- chatlas/types/openai/_client_azure.py +1 -0
- chatlas/types/openai/_submit.py +4 -1
- chatlas-0.9.1.dist-info/METADATA +141 -0
- chatlas-0.9.1.dist-info/RECORD +48 -0
- chatlas-0.8.1.dist-info/METADATA +0 -383
- chatlas-0.8.1.dist-info/RECORD +0 -46
- {chatlas-0.8.1.dist-info → chatlas-0.9.1.dist-info}/WHEEL +0 -0
- {chatlas-0.8.1.dist-info → chatlas-0.9.1.dist-info}/licenses/LICENSE +0 -0
chatlas/_chat.py
CHANGED
@@ -42,27 +42,41 @@ from ._display import (
     MockMarkdownDisplay,
 )
 from ._logging import log_tool_error
-from .
+from ._mcp_manager import MCPSessionManager
+from ._provider import Provider, StandardModelParams, SubmitInputArgsT
+from ._tokens import compute_cost, get_token_pricing
 from ._tools import Tool, ToolRejectError
 from ._turn import Turn, user_turn
-from ._typing_extensions import TypedDict
-from ._utils import html_escape, wrap_async
+from ._typing_extensions import TypedDict, TypeGuard
+from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
 
 
-class
-
+class TokensDict(TypedDict):
+    """
+    A TypedDict representing the token counts for a turn in the chat.
+    This is used to represent the token counts for each turn in the chat.
+    `role` represents the role of the turn (i.e., "user" or "assistant").
+    `tokens` represents the new tokens used in the turn.
+    `tokens_total` represents the total tokens used in the turn.
+    Ex. A new user input of 2 tokens is sent, plus 10 tokens of context from prior turns (input and output).
+    This would have a `tokens_total` of 12.
+    """
 
+    role: Literal["user", "assistant"]
+    tokens: int
+    tokens_total: int
 
-SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
-"""
-A TypedDict representing the arguments that can be passed to the `.chat()`
-method of a [](`~chatlas.Chat`) instance.
-"""
 
 CompletionT = TypeVar("CompletionT")
 
 EchoOptions = Literal["output", "all", "none", "text"]
 
+T = TypeVar("T")
+
+
+def is_present(value: T | None | MISSING_TYPE) -> TypeGuard[T]:
+    return value is not None and not isinstance(value, MISSING_TYPE)
+
 
 class Chat(Generic[SubmitInputArgsT, CompletionT]):
     """
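The new `is_present` helper is a `TypeGuard`, so type checkers can narrow a `T | None | MISSING_TYPE` union down to `T` after the check. A minimal standalone sketch of the same pattern (the `_MissingType` sentinel below is a stand-in for chatlas's internal `MISSING`, not an import from the package):

```python
from __future__ import annotations

from typing import TypeGuard, TypeVar

T = TypeVar("T")


class _MissingType:
    """Stand-in for chatlas's MISSING sentinel."""


MISSING = _MissingType()


def is_present(value: T | None | _MissingType) -> TypeGuard[T]:
    # True only for real values: filters out both None and the MISSING sentinel.
    return value is not None and not isinstance(value, _MissingType)


temperature: float | None | _MissingType = 0.7
if is_present(temperature):
    # Inside this branch, type checkers treat `temperature` as a plain float.
    print(temperature * 2)
```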
@@ -82,7 +96,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
     def __init__(
         self,
         provider: Provider,
-
+        system_prompt: Optional[str] = None,
     ):
         """
         Create a new chat object.
@@ -91,11 +105,13 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         ----------
         provider
             A [](`~chatlas.Provider`) object.
-
-            A
+        system_prompt
+            A system prompt to set the behavior of the assistant.
         """
         self.provider = provider
-        self._turns: list[Turn] =
+        self._turns: list[Turn] = []
+        self.system_prompt = system_prompt
+
         self._tools: dict[str, Tool] = {}
         self._on_tool_request_callbacks = CallbackManager()
         self._on_tool_result_callbacks = CallbackManager()
@@ -105,6 +121,11 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             "rich_console": {},
             "css_styles": {},
         }
+        self._mcp_manager = MCPSessionManager()
+
+        # Chat input parameters from `set_model_params()`
+        self._standard_model_params: StandardModelParams = {}
+        self._submit_input_kwargs: Optional[SubmitInputArgsT] = None
 
     def get_turns(
         self,
@@ -149,8 +170,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         """
         Set the turns of the chat.
 
-
-
+        Replaces the current chat history state (i.e., turns) with the provided turns.
+        This can be useful for:
+        * Clearing (or trimming) the chat history (i.e., `.set_turns([])`).
+        * Restoring context from a previous chat.
 
         Parameters
         ----------
@@ -165,7 +188,28 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 "Consider removing this turn and setting the `.system_prompt` separately "
                 "if you want to change the system prompt."
             )
-
+
+        turns_list = list(turns)
+        # Preserve the system prompt if it exists
+        if self._turns and self._turns[0].role == "system":
+            turns_list.insert(0, self._turns[0])
+        self._turns = turns_list
+
+    def add_turn(self, turn: Turn):
+        """
+        Add a turn to the chat.
+
+        Parameters
+        ----------
+        turn
+            The turn to add. Turns with the role "system" are not allowed.
+        """
+        if turn.role == "system":
+            raise ValueError(
+                "Turns with the role 'system' are not allowed. "
+                "The system prompt must be set separately using the `.system_prompt` property."
+            )
+        self._turns.append(turn)
 
     @property
     def system_prompt(self) -> str | None:
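Taken together, `set_turns()` and the new `add_turn()` cover the two common history manipulations: wholesale replacement (with the system prompt preserved automatically) and appending a single non-system turn. A usage sketch against the API shown in this diff (provider credentials assumed to be configured):

```python
from chatlas import ChatOpenAI, Turn

chat = ChatOpenAI(system_prompt="Be terse.")

# Restore context from a previous chat; the system prompt is kept in place.
chat.set_turns(
    [
        Turn("user", "What is 2 + 2?"),
        Turn("assistant", "4"),
    ]
)

# Append one turn; a role of "system" here would raise ValueError.
chat.add_turn(Turn("user", "And doubled?"))

# Clear (or trim) the history entirely.
chat.set_turns([])
```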
@@ -188,43 +232,14 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-
-    def tokens(self) -> list[tuple[int, int] | None]: ...
-
-    @overload
-    def tokens(
-        self,
-        values: Literal["cumulative"],
-    ) -> list[tuple[int, int] | None]: ...
-
-    @overload
-    def tokens(
-        self,
-        values: Literal["discrete"],
-    ) -> list[int]: ...
-
-    def tokens(
-        self,
-        values: Literal["cumulative", "discrete"] = "discrete",
-    ) -> list[int] | list[tuple[int, int] | None]:
+    def get_tokens(self) -> list[TokensDict]:
         """
         Get the tokens for each turn in the chat.
 
-        Parameters
-        ----------
-        values
-            If "cumulative" (the default), the result can be summed to get the
-            chat's overall token usage (helpful for computing overall cost of
-            the chat). If "discrete", the result can be summed to get the number of
-            tokens the turns will cost to generate the next response (helpful
-            for estimating cost of the next response, or for determining if you
-            are about to exceed the token limit).
-
         Returns
         -------
-        list[
-
-            1st turn includes the tokens count for the system prompt (if any).
+        list[TokensDict]
+            A list of dictionaries with the token counts for each (non-system) turn
 
         Raises
         ------
@@ -238,9 +253,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         turns = self.get_turns(include_system_prompt=False)
 
-        if values == "cumulative":
-            return [turn.tokens for turn in turns]
-
         if len(turns) == 0:
             return []
 
@@ -276,12 +288,21 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 "Expected the 1st assistant turn to contain token counts. " + err_info
             )
 
-        res: list[
+        res: list[TokensDict] = [
             # Implied token count for the 1st user input
-
+            {
+                "role": "user",
+                "tokens": turns[1].tokens[0],
+                "tokens_total": turns[1].tokens[0],
+            },
             # The token count for the 1st assistant response
-
+            {
+                "role": "assistant",
+                "tokens": turns[1].tokens[1],
+                "tokens_total": turns[1].tokens[1],
+            },
         ]
+
        for i in range(1, len(turns) - 1, 2):
            ti = turns[i]
            tj = turns[i + 2]
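The loop that follows derives each turn's "new" user tokens by subtraction: `tj.tokens[0]` is the full input count reported by a later assistant turn (new input plus all prior context), while `sum(ti.tokens)` is the input-plus-output total of the previous assistant turn, i.e. the context already accumulated. A worked example matching the `TokensDict` docstring above (numbers illustrative):

```python
# Previous assistant turn reported (input, output) = (8, 2):
# 10 tokens of accumulated context.
ti_tokens = (8, 2)

# The next assistant turn reports input = 12: the 10 tokens of prior
# context plus 2 genuinely new user tokens.
tj_tokens = (12, 5)

new_user_tokens = tj_tokens[0] - sum(ti_tokens)  # 12 - 10 = 2
user_tokens_total = tj_tokens[0]  # 12, the `tokens_total` from the docstring
assert (new_user_tokens, user_tokens_total) == (2, 12)
```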
@@ -296,15 +317,102 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             )
             res.extend(
                 [
-
-
-
-
+                    {
+                        "role": "user",
+                        # Implied token count for the user input
+                        "tokens": tj.tokens[0] - sum(ti.tokens),
+                        # Total tokens = Total User Tokens for the Turn = Distinct new tokens + context sent
+                        "tokens_total": tj.tokens[0],
+                    },
+                    {
+                        "role": "assistant",
+                        # The token count for the assistant response
+                        "tokens": tj.tokens[1],
+                        # Total tokens = Total Assistant tokens used in the turn
+                        "tokens_total": tj.tokens[1],
+                    },
                 ]
             )
 
         return res
 
+    def get_cost(
+        self,
+        options: Literal["all", "last"] = "all",
+        token_price: Optional[tuple[float, float]] = None,
+    ) -> float:
+        """
+        Estimate the cost of the chat.
+
+        Note
+        ----
+        This is a rough estimate, treat it as such. Providers may change their
+        pricing frequently and without notice.
+
+        Parameters
+        ----------
+        options
+            One of the following (default is "all"):
+            - `"all"`: Return the total cost of all turns in the chat.
+            - `"last"`: Return the cost of the last turn in the chat.
+        token_price
+            An optional tuple in the format of (input_token_cost,
+            output_token_cost) for bringing your own cost information.
+            - `"input_token_cost"`: The cost per user token in USD per
+              million tokens.
+            - `"output_token_cost"`: The cost per assistant token in USD
+              per million tokens.
+
+        Returns
+        -------
+        float
+            The cost of the chat, in USD.
+        """
+
+        # Look up token cost for user and input tokens based on the provider and model
+        turns_tokens = self.get_tokens()
+        if token_price:
+            input_token_price = token_price[0] / 1e6
+            output_token_price = token_price[1] / 1e6
+        else:
+            price_token = get_token_pricing(self.provider.name, self.provider.model)
+            if not price_token:
+                raise KeyError(
+                    f"We could not locate pricing information for model '{self.provider.model}' from provider '{self.provider.name}'. "
+                    "If you know the pricing for this model, specify it in `token_price`."
+                )
+            input_token_price = price_token["input"] / 1e6
+            output_token_price = price_token["output"] / 1e6
+
+        if len(turns_tokens) == 0:
+            return 0.0
+
+        if options not in ("all", "last"):
+            raise ValueError(
+                f"Expected `options` to be one of 'all' or 'last', not '{options}'"
+            )
+
+        if options == "all":
+            asst_tokens = sum(
+                u["tokens_total"] for u in turns_tokens if u["role"] == "assistant"
+            )
+            user_tokens = sum(
+                u["tokens_total"] for u in turns_tokens if u["role"] == "user"
+            )
+            cost = (asst_tokens * output_token_price) + (
+                user_tokens * input_token_price
+            )
+            return cost
+
+        last_turn = turns_tokens[-1]
+        if last_turn["role"] == "assistant":
+            return last_turn["tokens"] * output_token_price
+        if last_turn["role"] == "user":
+            return last_turn["tokens_total"] * input_token_price
+        raise ValueError(
+            f"Expected last turn to have a role of 'user' or 'assistant', not '{last_turn['role']}'"
+        )
+
     def token_count(
         self,
         *args: Content | str,
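A usage sketch for the new `get_cost()` method; the `token_price` override bypasses the lookup in the bundled `prices.json` (the prices below are made up for illustration):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()
chat.chat("Tell me a joke.")

# Total cost of the chat so far, using the bundled pricing data
# (raises KeyError if the provider/model pair has no pricing entry).
print(chat.get_cost())

# Bring your own prices: (input, output) in USD per million tokens.
print(chat.get_cost(options="last", token_price=(2.50, 10.00)))
```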
@@ -397,6 +505,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         *,
         stream: bool = True,
         port: int = 0,
+        host: str = "127.0.0.1",
         launch_browser: bool = True,
         bg_thread: Optional[bool] = None,
         echo: Optional[EchoOptions] = None,
@@ -412,6 +521,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             Whether to stream the response (i.e., have the response appear in chunks).
         port
             The port to run the app on (the default is 0, which will choose a random port).
+        host
+            The host to run the app on (the default is "127.0.0.1").
         launch_browser
             Whether to launch a browser window.
         bg_thread
@@ -479,7 +590,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         app = App(app_ui, server)
 
         def _run_app():
-            run_app(app, launch_browser=launch_browser, port=port)
+            run_app(app, launch_browser=launch_browser, port=port, host=host)
 
         # Use bg_thread by default in Jupyter and Positron
         if bg_thread is None:
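With `host` now threaded through to `run_app()`, the browser-based chat app can bind to an interface other than loopback. A sketch, assuming the surrounding method is the `Chat.app()` entry point documented by chatlas:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()

# Bind to all interfaces on a fixed port instead of the 127.0.0.1 default.
chat.app(host="0.0.0.0", port=8080)
```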
@@ -910,10 +1021,422 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         json = res[0]
         return json.value
 
+    def set_model_params(
+        self,
+        *,
+        temperature: float | None | MISSING_TYPE = MISSING,
+        top_p: float | None | MISSING_TYPE = MISSING,
+        top_k: int | None | MISSING_TYPE = MISSING,
+        frequency_penalty: float | None | MISSING_TYPE = MISSING,
+        presence_penalty: float | None | MISSING_TYPE = MISSING,
+        seed: int | None | MISSING_TYPE = MISSING,
+        max_tokens: int | None | MISSING_TYPE = MISSING,
+        log_probs: bool | None | MISSING_TYPE = MISSING,
+        stop_sequences: list[str] | None | MISSING_TYPE = MISSING,
+        kwargs: SubmitInputArgsT | None | MISSING_TYPE = MISSING,
+    ):
+        """
+        Set common model parameters for the chat.
+
+        A unified interface for setting common model parameters
+        across different providers. This method is useful for setting
+        parameters that are commonly supported by most providers, such as
+        temperature, top_p, etc.
+
+        By default, if the parameter is not set (i.e., set to `MISSING`),
+        the provider's default value is used. If you want to reset a
+        parameter to its default value, set it to `None`.
+
+        Parameters
+        ----------
+        temperature
+            Temperature of the sampling distribution.
+        top_p
+            The cumulative probability for token selection.
+        top_k
+            The number of highest probability vocabulary tokens to keep.
+        frequency_penalty
+            Frequency penalty for generated tokens.
+        presence_penalty
+            Presence penalty for generated tokens.
+        seed
+            Seed for random number generator.
+        max_tokens
+            Maximum number of tokens to generate.
+        log_probs
+            Include the log probabilities in the output?
+        stop_sequences
+            A character vector of tokens to stop generation on.
+        kwargs
+            Additional keyword arguments to use when submitting input to the
+            model. When calling this method repeatedly with different parameters,
+            only the parameters from the last call will be used.
+        """
+
+        params: StandardModelParams = {}
+
+        # Collect specified parameters
+        if is_present(temperature):
+            params["temperature"] = temperature
+        if is_present(top_p):
+            params["top_p"] = top_p
+        if is_present(top_k):
+            params["top_k"] = top_k
+        if is_present(frequency_penalty):
+            params["frequency_penalty"] = frequency_penalty
+        if is_present(presence_penalty):
+            params["presence_penalty"] = presence_penalty
+        if is_present(seed):
+            params["seed"] = seed
+        if is_present(max_tokens):
+            params["max_tokens"] = max_tokens
+        if is_present(log_probs):
+            params["log_probs"] = log_probs
+        if is_present(stop_sequences):
+            params["stop_sequences"] = stop_sequences
+
+        # Warn about un-supported parameters
+        supported = self.provider.supported_model_params()
+        unsupported = set(params.keys()) - set(supported)
+        if unsupported:
+            warnings.warn(
+                f"The following parameters are not supported by the provider: {unsupported}. "
+                "Please check the provider's documentation for supported parameters.",
+                UserWarning,
+            )
+        # Drop the unsupported parameters
+        for key in unsupported:
+            del params[key]
+
+        # Drop parameters that are set to None
+        discard = []
+        if temperature is None:
+            discard.append("temperature")
+        if top_p is None:
+            discard.append("top_p")
+        if top_k is None:
+            discard.append("top_k")
+        if frequency_penalty is None:
+            discard.append("frequency_penalty")
+        if presence_penalty is None:
+            discard.append("presence_penalty")
+        if seed is None:
+            discard.append("seed")
+        if max_tokens is None:
+            discard.append("max_tokens")
+        if log_probs is None:
+            discard.append("log_probs")
+        if stop_sequences is None:
+            discard.append("stop_sequences")
+
+        for key in discard:
+            if key in self._standard_model_params:
+                del self._standard_model_params[key]
+
+        # Update the standard model parameters
+        self._standard_model_params.update(params)
+
+        # Update the submit input kwargs
+        if kwargs is None:
+            self._submit_input_kwargs = None
+
+        if is_present(kwargs):
+            self._submit_input_kwargs = kwargs
+
+    async def register_mcp_tools_http_stream_async(
+        self,
+        *,
+        url: str,
+        include_tools: Sequence[str] = (),
+        exclude_tools: Sequence[str] = (),
+        name: Optional[str] = None,
+        namespace: Optional[str] = None,
+        transport_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        """
+        Register tools from an MCP server using streamable HTTP transport.
+
+        Connects to an MCP server (that communicates over a streamable HTTP
+        transport) and registers the available tools. This is useful for
+        utilizing tools provided by an MCP server running on a remote server (or
+        locally) over HTTP.
+
+        Pre-requisites
+        --------------
+
+        ::: {.callout-note}
+        Requires the `mcp` package to be installed. Install it with:
+
+        ```bash
+        pip install mcp
+        ```
+        :::
+
+        Parameters
+        ----------
+        url
+            URL endpoint where the Streamable HTTP server is mounted (e.g.,
+            `http://localhost:8000/mcp`)
+        name
+            A unique name for the MCP server session. If not provided, the name
+            is derived from the MCP server information. This name is primarily
+            useful for cleanup purposes (i.e., to close a particular MCP
+            session).
+        include_tools
+            List of tool names to include. By default, all available tools are
+            included.
+        exclude_tools
+            List of tool names to exclude. This parameter and `include_tools`
+            are mutually exclusive.
+        namespace
+            A namespace to prepend to tool names (i.e., `namespace.tool_name`)
+            from this MCP server. This is primarily useful to avoid name
+            collisions with other tools already registered with the chat. This
+            namespace applies when tools are advertised to the LLM, so try
+            to use a meaningful name that describes the server and/or the tools
+            it provides. For example, if you have a server that provides tools
+            for mathematical operations, you might use `math` as the namespace.
+        transport_kwargs
+            Additional keyword arguments for the transport layer (i.e.,
+            `mcp.client.streamable_http.streamablehttp_client`).
+
+        Returns
+        -------
+        None
+
+        See Also
+        --------
+        * `.cleanup_mcp_tools_async()` : Cleanup registered MCP tools.
+        * `.register_mcp_tools_stdio_async()` : Register tools from an MCP server using stdio transport.
+
+        Note
+        ----
+        Unlike the `.register_mcp_tools_stdio_async()` method, this method does
+        not launch an MCP server. Instead, it assumes an HTTP server is already
+        running at the specified URL. This is useful for connecting to an
+        existing MCP server that is already running and serving tools.
+
+        Examples
+        --------
+
+        Assuming you have a Python script `my_mcp_server.py` that implements an
+        MCP server like so:
+
+        ```python
+        from mcp.server.fastmcp import FastMCP
+
+        app = FastMCP("my_server")
+
+        @app.tool(description="Add two numbers.")
+        def add(x: int, y: int) -> int:
+            return x + y
+
+        app.run(transport="streamable-http")
+        ```
+
+        You can launch this server like so:
+
+        ```bash
+        python my_mcp_server.py
+        ```
+
+        Then, you can register this server with the chat as follows:
+
+        ```python
+        await chat.register_mcp_tools_http_stream_async(
+            url="http://localhost:8080/mcp"
+        )
+        ```
+        """
+        if isinstance(exclude_tools, str):
+            exclude_tools = [exclude_tools]
+        if isinstance(include_tools, str):
+            include_tools = [include_tools]
+
+        session_info = await self._mcp_manager.register_http_stream_tools(
+            name=name,
+            url=url,
+            include_tools=include_tools,
+            exclude_tools=exclude_tools,
+            namespace=namespace,
+            transport_kwargs=transport_kwargs or {},
+        )
+
+        overlapping_tools = set(self._tools.keys()) & set(session_info.tools)
+        if overlapping_tools:
+            await self._mcp_manager.close_sessions([session_info.name])
+            raise ValueError(
+                f"The following tools are already registered: {overlapping_tools}. "
+                "Consider providing a namespace when registering this MCP server "
+                "to avoid name collisions."
+            )
+
+        self._tools.update(session_info.tools)
+
+    async def register_mcp_tools_stdio_async(
+        self,
+        *,
+        command: str,
+        args: list[str],
+        name: Optional[str] = None,
+        include_tools: Sequence[str] = (),
+        exclude_tools: Sequence[str] = (),
+        namespace: Optional[str] = None,
+        transport_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        """
+        Register tools from an MCP server using stdio (standard input/output) transport.
+
+        Useful for launching an MCP server and registering its tools with the chat -- all
+        from the same Python process.
+
+        In more detail, this method:
+
+        1. Executes the given `command` with the provided `args`.
+            * This should start an MCP server that communicates via stdio.
+        2. Establishes a client connection to the MCP server using the `mcp` package.
+        3. Registers the available tools from the MCP server with the chat.
+        4. Returns a cleanup callback to close the MCP session and remove the tools.
+
+        Pre-requisites
+        --------------
+
+        ::: {.callout-note}
+        Requires the `mcp` package to be installed. Install it with:
+
+        ```bash
+        pip install mcp
+        ```
+        :::
+
+        Parameters
+        ----------
+        command
+            System command to execute to start the MCP server (e.g., `python`).
+        args
+            Arguments to pass to the system command (e.g., `["-m",
+            "my_mcp_server"]`).
+        name
+            A unique name for the MCP server session. If not provided, the name
+            is derived from the MCP server information. This name is primarily
+            useful for cleanup purposes (i.e., to close a particular MCP
+            session).
+        include_tools
+            List of tool names to include. By default, all available tools are
+            included.
+        exclude_tools
+            List of tool names to exclude. This parameter and `include_tools`
+            are mutually exclusive.
+        namespace
+            A namespace to prepend to tool names (i.e., `namespace.tool_name`)
+            from this MCP server. This is primarily useful to avoid name
+            collisions with other tools already registered with the chat. This
+            namespace applies when tools are advertised to the LLM, so try
+            to use a meaningful name that describes the server and/or the tools
+            it provides. For example, if you have a server that provides tools
+            for mathematical operations, you might use `math` as the namespace.
+        transport_kwargs
+            Additional keyword arguments for the stdio transport layer (i.e.,
+            `mcp.client.stdio.stdio_client`).
+
+        Returns
+        -------
+        None
+
+        See Also
+        --------
+        * `.cleanup_mcp_tools_async()` : Cleanup registered MCP tools.
+        * `.register_mcp_tools_http_stream_async()` : Register tools from an MCP server using streamable HTTP transport.
+
+        Examples
+        --------
+
+        Assuming you have a Python script `my_mcp_server.py` that implements an
+        MCP server like so:
+
+        ```python
+        from mcp.server.fastmcp import FastMCP
+
+        app = FastMCP("my_server")
+
+        @app.tool(description="Add two numbers.")
+        def add(y: int, z: int) -> int:
+            return y + z
+
+        app.run(transport="stdio")
+        ```
+
+        You can register this server with the chat as follows:
+
+        ```python
+        from chatlas import ChatOpenAI
+
+        chat = ChatOpenAI()
+
+        await chat.register_mcp_tools_stdio_async(
+            command="python",
+            args=["-m", "my_mcp_server"],
+        )
+        ```
+        """
+        if isinstance(exclude_tools, str):
+            exclude_tools = [exclude_tools]
+        if isinstance(include_tools, str):
+            include_tools = [include_tools]
+
+        session_info = await self._mcp_manager.register_stdio_tools(
+            command=command,
+            args=args,
+            name=name,
+            include_tools=include_tools,
+            exclude_tools=exclude_tools,
+            namespace=namespace,
+            transport_kwargs=transport_kwargs or {},
+        )
+
+        overlapping_tools = set(self._tools.keys()) & set(session_info.tools)
+        if overlapping_tools:
+            await self._mcp_manager.close_sessions([session_info.name])
+            raise ValueError(
+                f"The following tools are already registered: {overlapping_tools}. "
+                "Consider providing a namespace when registering this MCP server "
+                "to avoid name collisions."
+            )
+
+        self._tools.update(session_info.tools)
+
+    async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
+        """
+        Close MCP server connections (and their corresponding tools).
+
+        This method closes the MCP client sessions and removes the tools registered
+        from the MCP servers. If a specific `name` is provided, it will only clean
+        up the tools and session associated with that name. If no name is provided,
+        it will clean up all registered MCP tools and sessions.
+
+        Parameters
+        ----------
+        names
+            If provided, only clean up the tools and session associated
+            with these names. If not provided, clean up all registered MCP tools and sessions.
+
+        Returns
+        -------
+        None
+        """
+        closed_sessions = await self._mcp_manager.close_sessions(names)
+
+        # Remove relevant MCP tools from the main tools registry
+        for session in closed_sessions:
+            for tool_name in session.tools:
+                if tool_name in self._tools:
+                    del self._tools[tool_name]
+
     def register_tool(
         self,
         func: Callable[..., Any] | Callable[..., Awaitable[Any]],
         *,
+        force: bool = False,
         model: Optional[type[BaseModel]] = None,
     ):
         """
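The `MISSING` sentinel is what lets `set_model_params()` tell "leave this parameter alone" (argument omitted) apart from "reset it to the provider default" (explicit `None`). A usage sketch:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()

# Set a couple of common parameters; anything the provider does not
# support triggers a UserWarning and is dropped before submission.
chat.set_model_params(temperature=0.2, seed=42)

# Later calls only touch the named parameters, so seed=42 stays in effect.
chat.set_model_params(temperature=0.8)

# An explicit None resets a parameter back to the provider's default.
chat.set_model_params(temperature=None)
```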
@@ -930,7 +1453,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         recommended):
 
         ```python
-        from chatlas import ChatOpenAI
+        from chatlas import ChatOpenAI
 
 
         def add(a: int, b: int) -> int:
@@ -958,7 +1481,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         and also more directly document the input parameters:
 
         ```python
-        from chatlas import ChatOpenAI
+        from chatlas import ChatOpenAI
         from pydantic import BaseModel, Field
 
 
@@ -983,16 +1506,62 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         ----------
         func
             The function to be invoked when the tool is called.
+        force
+            If `True`, overwrite any existing tool with the same name. If `False`
+            (the default), raise an error if a tool with the same name already exists.
         model
             A Pydantic model that describes the input parameters for the function.
             If not provided, the model will be inferred from the function's type hints.
             The primary reason why you might want to provide a model in
             Note that the name and docstring of the model takes precedence over the
             name and docstring of the function.
+
+        Raises
+        ------
+        ValueError
+            If a tool with the same name already exists and `force` is `False`.
         """
-        tool = Tool(func, model=model)
+        tool = Tool.from_func(func, model=model)
+        if tool.name in self._tools and not force:
+            raise ValueError(
+                f"Tool with name '{tool.name}' is already registered. "
+                "Set `force=True` to overwrite it."
+            )
         self._tools[tool.name] = tool
 
+    def get_tools(self) -> list[Tool]:
+        """
+        Get the list of registered tools.
+
+        Returns
+        -------
+        list[Tool]
+            A list of `Tool` instances that are currently registered with the chat.
+        """
+        return list(self._tools.values())
+
+    def set_tools(
+        self, tools: list[Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool]
+    ):
+        """
+        Set the tools for the chat.
+
+        This replaces any previously registered tools with the provided list of
+        tools. This is for advanced usage -- typically, you would use
+        `.register_tool()` to register individual tools as needed.
+
+        Parameters
+        ----------
+        tools
+            A list of `Tool` instances to set as the chat's tools.
+        """
+        self._tools = {}
+        for tool in tools:
+            if isinstance(tool, Tool):
+                self._tools[tool.name] = tool
+            else:
+                self.register_tool(tool)
+
     def on_tool_request(self, callback: Callable[[ContentToolRequest], None]):
         """
         Register a callback for a tool request event.
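`register_tool(force=...)`, `get_tools()`, and `set_tools()` round out tool management. A sketch of the duplicate-name behavior introduced here:

```python
from chatlas import ChatOpenAI


def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


chat = ChatOpenAI()
chat.register_tool(add)

# Registering the same name again raises ValueError...
# chat.register_tool(add)

# ...unless force=True is passed to overwrite the existing tool.
chat.register_tool(add, force=True)

print([tool.name for tool in chat.get_tools()])  # ['add']

# Or replace the whole registry in one call (advanced usage).
chat.set_tools([add])
```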
@@ -1257,22 +1826,23 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         assert turn is not None
         user_turn_result = None
 
-
+        all_results: list[ContentToolResult] = []
         for x in turn.contents:
             if isinstance(x, ContentToolRequest):
                 if echo == "output":
                     self._echo_content(f"\n\n{x}\n\n")
                 if content == "all":
                     yield x
-
-
-
-
-
-
+                results = self._invoke_tool(x)
+                for res in results:
+                    if echo == "output":
+                        self._echo_content(f"\n\n{res}\n\n")
+                    if content == "all":
+                        yield res
+                    all_results.append(res)
 
-        if
-            user_turn_result = Turn("user",
+        if all_results:
+            user_turn_result = Turn("user", all_results)
 
     @overload
     def _chat_impl_async(
@@ -1316,24 +1886,25 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         assert turn is not None
         user_turn_result = None
 
-
+        all_results: list[ContentToolResult] = []
         for x in turn.contents:
             if isinstance(x, ContentToolRequest):
                 if echo == "output":
                     self._echo_content(f"\n\n{x}\n\n")
                 if content == "all":
                     yield x
-
-
-
-
-
-
-
-
-
-
+                results = self._invoke_tool_async(x)
+                async for res in results:
+                    if echo == "output":
+                        self._echo_content(f"\n\n{res}\n\n")
+                    if content == "all":
+                        yield res
+                    else:
+                        yield "\n\n"
+                    all_results.append(res)
+
+        if all_results:
+            user_turn_result = Turn("user", all_results)
 
     def _submit_turns(
         self,
@@ -1354,13 +1925,25 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         if echo == "all":
             emit_user_contents(user_turn, emit)
 
+        # Start collecting additional keyword args (from model parameters)
+        all_kwargs = self.provider.translate_model_params(
+            params=self._standard_model_params,
+        )
+
+        # Add any additional kwargs provided by the user
+        if self._submit_input_kwargs:
+            all_kwargs.update(self._submit_input_kwargs)
+
+        if kwargs:
+            all_kwargs.update(kwargs)
+
         if stream:
             response = self.provider.chat_perform(
                 stream=True,
                 turns=[*self._turns, user_turn],
                 tools=self._tools,
                 data_model=data_model,
-                kwargs=
+                kwargs=all_kwargs,
             )
 
             result = None
@@ -1385,7 +1968,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 turns=[*self._turns, user_turn],
                 tools=self._tools,
                 data_model=data_model,
-                kwargs=
+                kwargs=all_kwargs,
             )
 
             turn = self.provider.value_turn(
@@ -1462,54 +2045,56 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         self._turns.extend([user_turn, turn])
 
-    def _invoke_tool(self,
-        tool_def = self._tools.get(
+    def _invoke_tool(self, request: ContentToolRequest):
+        tool_def = self._tools.get(request.name, None)
         func = tool_def.func if tool_def is not None else None
 
         if func is None:
-
-
+            yield self._handle_tool_error_result(
+                request,
+                error=RuntimeError("Unknown tool."),
+            )
+            return
 
         # First, invoke the request callbacks. If a ToolRejectError is raised,
         # treat it like a tool failure (i.e., gracefully handle it).
         result: ContentToolResult | None = None
         try:
-            self._on_tool_request_callbacks.invoke(
+            self._on_tool_request_callbacks.invoke(request)
         except ToolRejectError as e:
-
+            yield self._handle_tool_error_result(request, e)
+            return
 
-
-
-
-
-
-
-
+        try:
+            if isinstance(request.arguments, dict):
+                res = func(**request.arguments)
+            else:
+                res = func(request.arguments)
+
+            # Normalize res as a generator of results.
+            if not inspect.isgenerator(res):
+
+                def _as_generator(res):
+                    yield res
 
-
-
+                res = _as_generator(res)
+
+            for x in res:
+                if isinstance(x, ContentToolResult):
+                    result = x
                 else:
-                    result = ContentToolResult(value=
+                    result = ContentToolResult(value=x)
 
-            result.request =
-        except Exception as e:
-            result = ContentToolResult(value=None, error=e, request=x)
+                result.request = request
 
-
-            warnings.warn(
-                f"Calling tool '{x.name}' led to an error.",
-                ToolFailureWarning,
-                stacklevel=2,
-            )
-            traceback.print_exc()
-            log_tool_error(x.name, str(x.arguments), result.error)
+                self._on_tool_result_callbacks.invoke(result)
+                yield result
 
-
-
+        except Exception as e:
+            yield self._handle_tool_error_result(request, e)
 
-    async def _invoke_tool_async(self,
-        tool_def = self._tools.get(
+    async def _invoke_tool_async(self, request: ContentToolRequest):
+        tool_def = self._tools.get(request.name, None)
         func = None
         if tool_def:
             if tool_def._is_async:
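Because `_invoke_tool()` now normalizes every return value into a generator, a tool function may itself be a generator and yield several results for a single request; each yielded value is wrapped in a `ContentToolResult` and passed through the result callbacks. A sketch of such a tool (whether intermediate yields are surfaced to the model is an inference from this hunk, not documented behavior):

```python
from chatlas import ChatOpenAI


def count_to(n: int):
    """Count upward, reporting each step."""
    for i in range(1, n + 1):
        # Each yielded value becomes its own ContentToolResult.
        yield i


chat = ChatOpenAI()
chat.register_tool(count_to)
chat.chat("Use the count_to tool to count to 3.")
```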
@@ -1518,45 +2103,59 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             func = wrap_async(tool_def.func)
 
         if func is None:
-
-
+            yield self._handle_tool_error_result(
+                request,
+                error=RuntimeError("Unknown tool."),
+            )
+            return
 
         # First, invoke the request callbacks. If a ToolRejectError is raised,
         # treat it like a tool failure (i.e., gracefully handle it).
         result: ContentToolResult | None = None
         try:
-            await self._on_tool_request_callbacks.invoke_async(
+            await self._on_tool_request_callbacks.invoke_async(request)
         except ToolRejectError as e:
-
+            yield self._handle_tool_error_result(request, e)
+            return
 
         # Invoke the tool (if it hasn't been rejected).
-
-
-
-
-
-
+        try:
+            if isinstance(request.arguments, dict):
+                res = await func(**request.arguments)
+            else:
+                res = await func(request.arguments)
+
+            # Normalize res into a generator of results.
+            if not inspect.isasyncgen(res):
+
+                async def _as_async_generator(res):
+                    yield res
 
-
-
+                res = _as_async_generator(res)
+
+            async for x in res:
+                if isinstance(x, ContentToolResult):
+                    result = x
                 else:
-                    result = ContentToolResult(value=
+                    result = ContentToolResult(value=x)
 
-            result.request =
-
-            result
+                result.request = request
+                await self._on_tool_result_callbacks.invoke_async(result)
+                yield result
 
-
-
-            warnings.warn(
-                f"Calling tool '{x.name}' led to an error.",
-                ToolFailureWarning,
-                stacklevel=2,
-            )
-            traceback.print_exc()
-            log_tool_error(x.name, str(x.arguments), result.error)
+        except Exception as e:
+            yield self._handle_tool_error_result(request, e)
 
-
+    def _handle_tool_error_result(self, request: ContentToolRequest, error: Exception):
+        warnings.warn(
+            f"Calling tool '{request.name}' led to an error: {error}",
+            ToolFailureWarning,
+            stacklevel=2,
+        )
+        traceback.print_exc()
+        log_tool_error(request.name, str(request.arguments), error)
+        result = ContentToolResult(value=None, error=error, request=request)
+        self._on_tool_result_callbacks.invoke(result)
         return result
 
     def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
@@ -1571,10 +2170,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             return ChatMarkdownDisplay(MockMarkdownDisplay(), self)
 
         # rich does a lot to detect a notebook environment, but it doesn't
-        # detect Quarto
+        # detect Quarto, or a Positron notebook
        from rich.console import Console
 
-        is_web = Console().is_jupyter or
+        is_web = Console().is_jupyter or is_quarto() or is_positron_notebook()
 
         opts = self._echo_options
 
@@ -1622,8 +2221,23 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
     def __repr__(self):
         turns = self.get_turns(include_system_prompt=True)
-        tokens =
-
+        tokens = self.get_tokens()
+        tokens_asst = sum(u["tokens_total"] for u in tokens if u["role"] == "assistant")
+        tokens_user = sum(u["tokens_total"] for u in tokens if u["role"] == "user")
+
+        res = f"<Chat {self.provider.name}/{self.provider.model} turns={len(turns)} tokens={tokens_user}/{tokens_asst}"
+
+        # Add cost info only if we can compute it
+        cost = compute_cost(
+            self.provider.name,
+            self.provider.model,
+            tokens_user,
+            tokens_asst,
+        )
+        if cost is not None:
+            res += f" ${round(cost, ndigits=2)}"
+
+        res += ">"
         for turn in turns:
             res += "\n" + turn.__repr__(indent=2)
         return res + "\n"
@@ -1818,3 +2432,15 @@ class ToolFailureWarning(RuntimeWarning):
 
 # By default warnings are shown once; we want to always show them.
 warnings.simplefilter("always", ToolFailureWarning)
+
+
+def is_quarto():
+    return os.getenv("QUARTO_PYTHON", None) is not None
+
+
+def is_positron_notebook():
+    try:
+        mode = get_ipython().session_mode  # noqa: F821 # type: ignore
+        return mode == "notebook"
+    except Exception:
+        return False