chatlas 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +104 -6
- chatlas/_chat.py +246 -24
- chatlas/_content.py +20 -7
- chatlas/_google.py +312 -161
- chatlas/_merge.py +1 -1
- chatlas/_ollama.py +8 -0
- chatlas/_openai.py +64 -7
- chatlas/_provider.py +16 -8
- chatlas/py.typed +0 -0
- chatlas/types/__init__.py +5 -1
- chatlas/types/anthropic/_client.py +0 -8
- chatlas/types/anthropic/_submit.py +2 -3
- chatlas/types/google/_client.py +12 -91
- chatlas/types/google/_submit.py +40 -87
- chatlas/types/openai/_client.py +1 -0
- chatlas/types/openai/_client_azure.py +1 -0
- chatlas/types/openai/_submit.py +10 -2
- {chatlas-0.2.0.dist-info → chatlas-0.4.0.dist-info}/METADATA +25 -11
- chatlas-0.4.0.dist-info/RECORD +38 -0
- {chatlas-0.2.0.dist-info → chatlas-0.4.0.dist-info}/WHEEL +1 -1
- chatlas-0.2.0.dist-info/RECORD +0 -37
chatlas/__init__.py
CHANGED
@@ -3,7 +3,7 @@ from ._anthropic import ChatAnthropic, ChatBedrockAnthropic
 from ._chat import Chat
 from ._content_image import content_image_file, content_image_plot, content_image_url
 from ._github import ChatGithub
-from ._google import ChatGoogle
+from ._google import ChatGoogle, ChatVertex
 from ._groq import ChatGroq
 from ._interpolate import interpolate, interpolate_file
 from ._ollama import ChatOllama
@@ -24,6 +24,7 @@ __all__ = (
     "ChatOpenAI",
     "ChatAzureOpenAI",
     "ChatPerplexity",
+    "ChatVertex",
     "Chat",
     "content_image_file",
     "content_image_plot",
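Downstream, the new Vertex entry point is imported the same way as the existing constructors. A minimal sketch (the model name is illustrative, and the constructor's exact parameters are an assumption, not shown in this diff):

```python
from chatlas import ChatVertex

# Hypothetical usage: "gemini-1.5-flash" is an illustrative model name, and
# Vertex credentials/project are assumed to come from the environment.
chat = ChatVertex(model="gemini-1.5-flash")
chat.chat("Hello!")
```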
chatlas/_anthropic.py
CHANGED
@@ -20,7 +20,7 @@ from ._logging import log_model_default
 from ._provider import Provider
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
     from anthropic.types import (
@@ -311,7 +311,8 @@ class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]):
         if stream:
             stream = False
             warnings.warn(
-                "Anthropic does not support structured data extraction in streaming mode."
+                "Anthropic does not support structured data extraction in streaming mode.",
+                stacklevel=2,
             )
 
         kwargs_full: "SubmitInputArgs" = {
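The only behavioral change in this hunk is `stacklevel=2`, which makes Python report the warning at the caller of the provider method instead of at the `warnings.warn()` line inside the library. A standalone illustration of that stdlib behavior:

```python
import warnings

def extract_data(stream: bool = True):
    if stream:
        # stacklevel=2 attributes the warning to whoever called
        # extract_data(), which is far more useful in a library.
        warnings.warn("streaming not supported; falling back", stacklevel=2)

extract_data()  # the warning is reported at this line
```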
@@ -371,15 +372,65 @@ class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]):
 
         return completion
 
-    def stream_turn(self, completion, has_data_model, stream) -> Turn:
-        return self._as_turn(completion, has_data_model)
-
-    async def stream_turn_async(self, completion, has_data_model, stream) -> Turn:
+    def stream_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+        res = self._client.messages.count_tokens(**kwargs)
+        return res.input_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+        res = await self._async_client.messages.count_tokens(**kwargs)
+        return res.input_tokens
+
+    def _token_count_args(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> dict[str, Any]:
+        turn = user_turn(*args)
+
+        kwargs = self._chat_perform_args(
+            stream=False,
+            turns=[turn],
+            tools=tools,
+            data_model=data_model,
+        )
+
+        args_to_keep = [
+            "messages",
+            "model",
+            "system",
+            "tools",
+            "tool_choice",
+        ]
+
+        return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
+
     def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]:
         messages: list["MessageParam"] = []
         for turn in turns:
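Per the new methods above, token counting delegates to Anthropic's token-counting endpoint and returns its `input_tokens` field. A minimal sketch of the underlying SDK call (the model name is illustrative; an `ANTHROPIC_API_KEY` is assumed):

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
res = client.messages.count_tokens(
    model="claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(res.input_tokens)
```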
@@ -575,6 +626,53 @@ def ChatBedrockAnthropic(
         Additional arguments to pass to the `anthropic.AnthropicBedrock()`
         client constructor.
 
+    Troubleshooting
+    ---------------
+
+    If you encounter 400 or 403 errors when trying to use the model, keep the
+    following in mind:
+
+    ::: {.callout-note}
+    #### Incorrect model name
+
+    If the model name is completely incorrect, you'll see an error like
+    `Error code: 400 - {'message': 'The provided model identifier is invalid.'}`
+
+    Make sure the model name is correct and active in the specified region.
+    :::
+
+    ::: {.callout-note}
+    #### Models are region specific
+
+    If you encounter errors similar to `Error code: 403 - {'message': "You don't
+    have access to the model with the specified model ID."}`, make sure your
+    model is active in the relevant `aws_region`.
+
+    Keep in mind, if `aws_region` is not specified, and AWS_REGION is not set,
+    the region defaults to us-east-1, which may not match your AWS config's
+    default region.
+    :::
+
+    ::: {.callout-note}
+    #### Cross region inference ID
+
+    In some cases, even if you have the right model and the right region, you
+    may still encounter an error like `Error code: 400 - {'message':
+    'Invocation of model ID anthropic.claude-3-5-sonnet-20240620-v1:0 with
+    on-demand throughput isn't supported. Retry your request with the ID or ARN
+    of an inference profile that contains this model.'}`
+
+    In this case, you'll need to look up the 'cross region inference ID' for
+    your model. This might require opening your `aws-console` and navigating to
+    the 'Anthropic Bedrock' service page. From there, go to the 'cross region
+    inference' tab and copy the relevant ID.
+
+    For example, if the desired model ID is
+    `anthropic.claude-3-5-sonnet-20240620-v1:0`, the cross region ID might look
+    something like `us.anthropic.claude-3-5-sonnet-20240620-v1:0`.
+    :::
+
+
     Returns
     -------
     Chat
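Applying the cross-region advice from the docstring above, a hedged sketch (the `model`/`aws_region` argument names follow the docstring; the values are the docstring's own example):

```python
from chatlas import ChatBedrockAnthropic

# If the plain model ID fails with the "on-demand throughput" 400 error,
# retry with the cross region inference ID (note the "us." prefix).
chat = ChatBedrockAnthropic(
    model="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
    aws_region="us-east-1",
)
```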
chatlas/_chat.py
CHANGED
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import inspect
 import os
+import sys
 from pathlib import Path
 from threading import Thread
 from typing import (
@@ -16,6 +18,7 @@ from typing import (
     Optional,
     Sequence,
     TypeVar,
+    overload,
 )
 
 from pydantic import BaseModel
@@ -39,7 +42,7 @@ from ._provider import Provider
 from ._tools import Tool
 from ._turn import Turn, user_turn
 from ._typing_extensions import TypedDict
-from ._utils import html_escape
+from ._utils import html_escape, wrap_async
 
 
 class AnyTypeDict(TypedDict, total=False):
@@ -176,17 +179,209 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-
+    @overload
+    def tokens(self) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["cumulative"],
+    ) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["discrete"],
+    ) -> list[int]: ...
+
+    def tokens(
+        self,
+        values: Literal["cumulative", "discrete"] = "discrete",
+    ) -> list[int] | list[tuple[int, int] | None]:
         """
         Get the tokens for each turn in the chat.
 
+        Parameters
+        ----------
+        values
+            If "cumulative", the result can be summed to get the
+            chat's overall token usage (helpful for computing overall cost of
+            the chat). If "discrete" (the default), the result can be summed to get
+            the number of tokens the turns will cost to generate the next response
+            (helpful for estimating cost of the next response, or for determining
+            if you are about to exceed the token limit).
+
+        Returns
+        -------
+        list[int]
+            A list of token counts for each (non-system) turn in the chat. The
+            1st turn includes the token count for the system prompt (if any).
+
+        Raises
+        ------
+        ValueError
+            If the chat's turns (i.e., `.get_turns()`) are not in an expected
+            format. This may happen if the chat history is manually set (i.e.,
+            `.set_turns()`). In this case, you can inspect the "raw" token
+            values via the `.get_turns()` method (each turn has a `.tokens`
+            attribute).
+        """
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        if values == "cumulative":
+            return [turn.tokens for turn in turns]
+
+        if len(turns) == 0:
+            return []
+
+        err_info = (
+            "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
+            "Consider getting the 'raw' token values via the `.get_turns()` method "
+            "(each turn has a `.tokens` attribute)."
+        )
+
+        # Sanity checks for the assumptions made to figure out user token counts
+        if len(turns) == 1:
+            raise ValueError(
+                "Expected at least two turns in the chat history. " + err_info
+            )
+
+        if len(turns) % 2 != 0:
+            raise ValueError(
+                "Expected an even number of turns in the chat history. " + err_info
+            )
+
+        if turns[0].role != "user":
+            raise ValueError(
+                "Expected the 1st non-system turn to have role='user'. " + err_info
+            )
+
+        if turns[1].role != "assistant":
+            raise ValueError(
+                "Expected the 2nd non-system turn to have role='assistant'. " + err_info
+            )
+
+        if turns[1].tokens is None:
+            raise ValueError(
+                "Expected the 1st assistant turn to contain token counts. " + err_info
+            )
+
+        res: list[int] = [
+            # Implied token count for the 1st user input
+            turns[1].tokens[0],
+            # The token count for the 1st assistant response
+            turns[1].tokens[1],
+        ]
+        for i in range(1, len(turns) - 1, 2):
+            ti = turns[i]
+            tj = turns[i + 2]
+            if ti.role != "assistant" or tj.role != "assistant":
+                raise ValueError(
+                    "Expected even turns to have role='assistant'. " + err_info
+                )
+            if ti.tokens is None or tj.tokens is None:
+                raise ValueError(
+                    "Expected role='assistant' turns to contain token counts. "
+                    + err_info
+                )
+            res.extend(
+                [
+                    # Implied token count for the user input
+                    tj.tokens[0] - sum(ti.tokens),
+                    # The token count for the assistant response
+                    tj.tokens[1],
+                ]
+            )
+
+        return res
+
+    def token_count(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> int:
+        """
+        Get an estimated token count for the given input.
+
+        Estimate the token size of input content. This can help determine whether input(s)
+        and/or conversation history (i.e., `.get_turns()`) should be reduced in size before
+        sending it to the model.
+
+        Parameters
+        ----------
+        args
+            The input to get a token count for.
+        data_model
+            If the input is meant for data extraction (i.e., `.extract_data()`), then
+            this should be the Pydantic model that describes the structure of the data to
+            extract.
+
+        Returns
+        -------
+        int
+            The token count for the input.
+
+        Note
+        ----
+        Remember that the token count is an estimate. Also, models based on
+        `ChatOpenAI()` currently do not take tools into account when
+        estimating token counts.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic()
+        # Estimate the token count before sending the input
+        print(chat.token_count("What is 2 + 2?"))
+
+        # Once input is sent, you can get the actual input and output
+        # token counts from the chat object
+        chat.chat("What is 2 + 2?", echo="none")
+        print(chat.token_usage())
+        ```
+        """
+
+        return self.provider.token_count(
+            *args,
+            tools=self._tools,
+            data_model=data_model,
+        )
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> int:
+        """
+        Get an estimated token count for the given input asynchronously.
+
+        Estimate the token size of input content. This can help determine whether input(s)
+        and/or conversation history (i.e., `.get_turns()`) should be reduced in size before
+        sending it to the model.
+
+        Parameters
+        ----------
+        args
+            The input to get a token count for.
+        data_model
+            If this input is meant for data extraction (i.e., `.extract_data_async()`),
+            then this should be the Pydantic model that describes the structure of the data
+            to extract.
+
         Returns
         -------
-
-
-            indices for a turn.
+        int
+            The token count for the input.
         """
-
+
+        return await self.provider.token_count_async(
+            *args,
+            tools=self._tools,
+            data_model=data_model,
+        )
 
     def app(
         self,
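To make the `"discrete"`/`"cumulative"` distinction concrete, a sketch based on the docstring above (all token numbers are made up):

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic()
chat.chat("What is 2 + 2?", echo="none")
chat.chat("And times 3?", echo="none")

# One count per non-system turn; the sum estimates what the current
# history will cost on the *next* request.
print(chat.tokens("discrete"))    # e.g. [14, 5, 10, 7]

# Raw (input, output) pairs for assistant turns (None for user turns);
# the sum reflects the chat's overall usage so far.
print(chat.tokens("cumulative"))  # e.g. [None, (14, 5), None, (29, 7)]
```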
@@ -195,6 +390,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         port: int = 0,
         launch_browser: bool = True,
         bg_thread: Optional[bool] = None,
+        echo: Optional[Literal["text", "all", "none"]] = None,
         kwargs: Optional[SubmitInputArgsT] = None,
     ):
         """
@@ -211,6 +407,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         bg_thread
             Whether to run the app in a background thread. If `None`, the app will
             run in a background thread if the current environment is a notebook.
+        echo
+            Whether to echo text content, all content (i.e., tool calls), or no content. Defaults to `"none"` when `stream=True` and `"text"` when `stream=False`.
         kwargs
             Additional keyword arguments to pass to the method used for requesting
             the response.
@@ -245,10 +443,22 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 return
             if stream:
                 await chat.append_message_stream(
-                    self.
+                    await self.stream_async(
+                        user_input,
+                        kwargs=kwargs,
+                        echo=echo or "none",
+                    )
                 )
             else:
-                await chat.append_message(
+                await chat.append_message(
+                    str(
+                        self.chat(
+                            user_input,
+                            kwargs=kwargs,
+                            echo=echo or "text",
+                        )
+                    )
+                )
 
         app = App(app_ui, server)
 
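The new `echo` parameter flows straight through to `.chat()`/`.stream_async()`, so the browser app can now surface tool calls. For example (the constructor choice here is arbitrary):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()
# Echo all content, including tool calls; with the default echo=None,
# streaming uses "none" and non-streaming uses "text".
chat.app(echo="all")
```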
@@ -755,11 +965,11 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         is_html = filename.suffix == ".html"
 
         # Get contents from each turn
-
+        content_arr: list[str] = []
         for turn in turns:
             turn_content = "\n\n".join(
                 [
-                    str(content)
+                    str(content).strip()
                     for content in turn.contents
                     if include == "all" or isinstance(content, ContentText)
                 ]
@@ -770,7 +980,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 turn_content = f"<shiny-{msg_type}-message content='{content_attr}'></shiny-{msg_type}-message>"
             else:
                 turn_content = f"## {turn.role.capitalize()}\n\n{turn_content}"
-
+            content_arr.append(turn_content)
+        contents = "\n\n".join(content_arr)
 
         # Shiny chat message components requires container elements
         if is_html:
@@ -900,7 +1111,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             turn = self.provider.stream_turn(
                 result,
                 has_data_model=data_model is not None,
-                stream=response,
             )
 
             if echo == "all":
@@ -961,10 +1171,9 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 yield text
             result = self.provider.stream_merge_chunks(result, chunk)
 
-        turn = await self.provider.stream_turn_async(
+        turn = self.provider.stream_turn(
             result,
             has_data_model=data_model is not None,
-            stream=response,
         )
 
         if echo == "all":
@@ -1017,7 +1226,12 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         for x in turn.contents:
             if isinstance(x, ContentToolRequest):
                 tool_def = self._tools.get(x.name, None)
-                func =
+                func = None
+                if tool_def:
+                    if tool_def._is_async:
+                        func = tool_def.func
+                    else:
+                        func = wrap_async(tool_def.func)
                 results.append(await self._invoke_tool_async(func, x.arguments, x.id))
 
         if not results:
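Tool dispatch is now uniformly async: async tools are awaited as-is, and sync tools are wrapped with `wrap_async`. The helper's implementation isn't part of this diff; a minimal sketch of what such a wrapper typically looks like (an assumption, not chatlas' actual code):

```python
import asyncio
from typing import Any, Awaitable, Callable

def wrap_async(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
    # Run a synchronous tool in a worker thread so it doesn't block the
    # event loop while being awaited.
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        return await asyncio.to_thread(func, *args, **kwargs)

    return wrapper
```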
@@ -1032,7 +1246,9 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         id_: str,
     ) -> ContentToolResult:
         if func is None:
-            return ContentToolResult(id_, None, "Unknown tool")
+            return ContentToolResult(id_, value=None, error="Unknown tool")
+
+        name = func.__name__
 
         try:
             if isinstance(arguments, dict):
@@ -1040,10 +1256,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             else:
                 result = func(arguments)
 
-            return ContentToolResult(id_, result, None)
+            return ContentToolResult(id_, value=result, error=None, name=name)
         except Exception as e:
-            log_tool_error(func.__name__, str(arguments), e)
-            return ContentToolResult(id_, None, str(e))
+            log_tool_error(name, str(arguments), e)
+            return ContentToolResult(id_, value=None, error=str(e), name=name)
 
     @staticmethod
     async def _invoke_tool_async(
@@ -1052,7 +1268,9 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         id_: str,
     ) -> ContentToolResult:
         if func is None:
-            return ContentToolResult(id_, None, "Unknown tool")
+            return ContentToolResult(id_, value=None, error="Unknown tool")
+
+        name = func.__name__
 
         try:
             if isinstance(arguments, dict):
@@ -1060,10 +1278,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             else:
                 result = await func(arguments)
 
-            return ContentToolResult(id_, result, None)
+            return ContentToolResult(id_, value=result, error=None, name=name)
         except Exception as e:
             log_tool_error(func.__name__, str(arguments), e)
-            return ContentToolResult(id_, None, str(e))
+            return ContentToolResult(id_, value=None, error=str(e), name=name)
 
     def _markdown_display(
         self, echo: Literal["text", "all", "none"]
@@ -1180,7 +1398,7 @@ class ChatResponse:
 
     @property
     def consumed(self) -> bool:
-        return self._generator
+        return inspect.getgeneratorstate(self._generator) == inspect.GEN_CLOSED
 
     def __str__(self) -> str:
         return self.get_content()
@@ -1230,7 +1448,11 @@ class ChatResponseAsync:
 
     @property
     def consumed(self) -> bool:
-
+        if sys.version_info < (3, 12):
+            raise NotImplementedError(
+                "Checking for consumed state is only supported in Python 3.12+"
+            )
+        return inspect.getasyncgenstate(self._generator) == inspect.AGEN_CLOSED
 
 
 # ----------------------------------------------------------------------------
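Both `consumed` properties now inspect real generator state rather than returning a sentinel. `inspect.getasyncgenstate()` only exists on Python 3.12+, hence the version guard above; the sync counterpart works everywhere:

```python
import inspect

def gen():
    yield "hello"

g = gen()
print(inspect.getgeneratorstate(g) == inspect.GEN_CLOSED)  # False
for _ in g:
    pass
print(inspect.getgeneratorstate(g) == inspect.GEN_CLOSED)  # True
```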
chatlas/_content.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import json
 from dataclasses import dataclass
+from pprint import pformat
 from typing import Any, Literal, Optional
 
 ImageContentTypes = Literal[
@@ -154,7 +155,7 @@ class ContentToolRequest(Content):
         args_str = self._arguments_str()
         func_call = f"{self.name}({args_str})"
         comment = f"# tool request ({self.id})"
-        return f"
+        return f"```python\n{comment}\n{func_call}\n```\n"
 
     def _repr_markdown_(self):
         return self.__str__()
@@ -187,18 +188,31 @@ class ContentToolResult(Content):
         The unique identifier of the tool request.
     value
         The value returned by the tool/function.
+    name
+        The name of the tool/function that was called.
     error
         An error message if the tool/function call failed.
     """
 
     id: str
     value: Any = None
+    name: Optional[str] = None
     error: Optional[str] = None
 
+    def _get_value_and_language(self) -> tuple[str, str]:
+        if self.error:
+            return f"Tool calling failed with error: '{self.error}'", ""
+        try:
+            json_val = json.loads(self.value)
+            return pformat(json_val, indent=2, sort_dicts=False), "python"
+        except:  # noqa: E722
+            return str(self.value), ""
+
     def __str__(self):
         comment = f"# tool result ({self.id})"
-
-
+        value, language = self._get_value_and_language()
+
+        return f"""```{language}\n{comment}\n{value}\n```"""
 
     def _repr_markdown_(self):
         return self.__str__()
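The new `_get_value_and_language()` helper pretty-prints values that parse as JSON and falls back to `str()` (or an error message) otherwise. A quick sketch of the resulting behavior (field values are illustrative):

```python
from chatlas._content import ContentToolResult

# A JSON-string value is parsed and pretty-printed (and __str__() renders
# it inside a python-tagged fenced block)
ok = ContentToolResult("abc123", value='{"temp": 21, "unit": "C"}', name="get_weather")
print(ok.get_final_value())  # {'temp': 21, 'unit': 'C'}

# An error short-circuits to a plain message with no language tag
bad = ContentToolResult("abc123", name="get_weather", error="boom")
print(bad.get_final_value())  # Tool calling failed with error: 'boom'
```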
@@ -211,9 +225,8 @@ class ContentToolResult(Content):
         return res + ">"
 
     def get_final_value(self) -> str:
-
-
-        return str(self.value)
+        value, _language = self._get_value_and_language()
+        return value
 
 
 @dataclass
@@ -236,7 +249,7 @@ class ContentJson(Content):
         return json.dumps(self.value, indent=2)
 
     def _repr_markdown_(self):
-        return f"""
+        return f"""```json\n{self.__str__()}\n```"""
 
     def __repr__(self, indent: int = 0):
         return " " * indent + f"<ContentJson value={self.value}>"