chatlas 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chatlas might be problematic. Click here for more details.
- chatlas/__init__.py +5 -0
- chatlas/_anthropic.py +12 -6
- chatlas/_auto.py +7 -3
- chatlas/_chat.py +339 -120
- chatlas/_content.py +230 -32
- chatlas/_databricks.py +139 -0
- chatlas/_display.py +13 -7
- chatlas/_github.py +0 -6
- chatlas/_google.py +11 -6
- chatlas/_groq.py +0 -6
- chatlas/_ollama.py +3 -8
- chatlas/_openai.py +13 -38
- chatlas/_perplexity.py +0 -6
- chatlas/_snowflake.py +46 -23
- chatlas/_utils.py +36 -1
- chatlas/_version.py +2 -2
- chatlas/types/anthropic/_submit.py +2 -0
- chatlas/types/openai/_submit.py +11 -1
- {chatlas-0.6.1.dist-info → chatlas-0.7.1.dist-info}/METADATA +14 -10
- {chatlas-0.6.1.dist-info → chatlas-0.7.1.dist-info}/RECORD +22 -20
- chatlas-0.7.1.dist-info/licenses/LICENSE +21 -0
- {chatlas-0.6.1.dist-info → chatlas-0.7.1.dist-info}/WHEEL +0 -0
chatlas/_chat.py
CHANGED
|
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import os
|
|
5
5
|
import sys
|
|
6
|
+
import traceback
|
|
7
|
+
import warnings
|
|
6
8
|
from pathlib import Path
|
|
7
9
|
from threading import Thread
|
|
8
10
|
from typing import (
|
|
@@ -31,7 +33,7 @@ from ._content import (
|
|
|
31
33
|
ContentToolResult,
|
|
32
34
|
)
|
|
33
35
|
from ._display import (
|
|
34
|
-
|
|
36
|
+
EchoDisplayOptions,
|
|
35
37
|
IPyMarkdownDisplay,
|
|
36
38
|
LiveMarkdownDisplay,
|
|
37
39
|
MarkdownDisplay,
|
|
@@ -57,6 +59,8 @@ method of a [](`~chatlas.Chat`) instance.
|
|
|
57
59
|
|
|
58
60
|
CompletionT = TypeVar("CompletionT")
|
|
59
61
|
|
|
62
|
+
EchoOptions = Literal["output", "all", "none", "text"]
|
|
63
|
+
|
|
60
64
|
|
|
61
65
|
class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
62
66
|
"""
|
|
@@ -91,7 +95,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
91
95
|
self.provider = provider
|
|
92
96
|
self._turns: list[Turn] = list(turns or [])
|
|
93
97
|
self._tools: dict[str, Tool] = {}
|
|
94
|
-
self.
|
|
98
|
+
self._current_display: Optional[MarkdownDisplay] = None
|
|
99
|
+
self._echo_options: EchoDisplayOptions = {
|
|
95
100
|
"rich_markdown": {},
|
|
96
101
|
"rich_console": {},
|
|
97
102
|
"css_styles": {},
|
|
@@ -390,7 +395,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
390
395
|
port: int = 0,
|
|
391
396
|
launch_browser: bool = True,
|
|
392
397
|
bg_thread: Optional[bool] = None,
|
|
393
|
-
echo: Optional[
|
|
398
|
+
echo: Optional[EchoOptions] = None,
|
|
399
|
+
content: Literal["text", "all"] = "all",
|
|
394
400
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
395
401
|
):
|
|
396
402
|
"""
|
|
@@ -411,6 +417,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
411
417
|
Whether to echo text content, all content (i.e., tool calls), or no
|
|
412
418
|
content. Defaults to `"none"` when `stream=True` and `"text"` when
|
|
413
419
|
`stream=False`.
|
|
420
|
+
content
|
|
421
|
+
Whether to display text content or all content (i.e., tool calls).
|
|
414
422
|
kwargs
|
|
415
423
|
Additional keyword arguments to pass to the method used for requesting
|
|
416
424
|
the response.
|
|
@@ -439,16 +447,14 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
439
447
|
)
|
|
440
448
|
|
|
441
449
|
@chat.on_user_submit
|
|
442
|
-
async def _():
|
|
443
|
-
user_input = chat.user_input()
|
|
444
|
-
if user_input is None:
|
|
445
|
-
return
|
|
450
|
+
async def _(user_input: str):
|
|
446
451
|
if stream:
|
|
447
452
|
await chat.append_message_stream(
|
|
448
453
|
await self.stream_async(
|
|
449
454
|
user_input,
|
|
450
455
|
kwargs=kwargs,
|
|
451
456
|
echo=echo or "none",
|
|
457
|
+
content=content,
|
|
452
458
|
)
|
|
453
459
|
)
|
|
454
460
|
else:
|
|
@@ -457,6 +463,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
457
463
|
self.chat(
|
|
458
464
|
user_input,
|
|
459
465
|
kwargs=kwargs,
|
|
466
|
+
stream=False,
|
|
460
467
|
echo=echo or "text",
|
|
461
468
|
)
|
|
462
469
|
)
|
|
@@ -485,7 +492,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
485
492
|
def console(
|
|
486
493
|
self,
|
|
487
494
|
*,
|
|
488
|
-
echo:
|
|
495
|
+
echo: EchoOptions = "output",
|
|
489
496
|
stream: bool = True,
|
|
490
497
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
491
498
|
):
|
|
@@ -523,7 +530,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
523
530
|
def chat(
|
|
524
531
|
self,
|
|
525
532
|
*args: Content | str,
|
|
526
|
-
echo:
|
|
533
|
+
echo: EchoOptions = "output",
|
|
527
534
|
stream: bool = True,
|
|
528
535
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
529
536
|
) -> ChatResponse:
|
|
@@ -558,7 +565,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
558
565
|
self._chat_impl(
|
|
559
566
|
turn,
|
|
560
567
|
echo=echo,
|
|
561
|
-
|
|
568
|
+
content="text",
|
|
562
569
|
stream=stream,
|
|
563
570
|
kwargs=kwargs,
|
|
564
571
|
)
|
|
@@ -573,7 +580,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
573
580
|
async def chat_async(
|
|
574
581
|
self,
|
|
575
582
|
*args: Content | str,
|
|
576
|
-
echo:
|
|
583
|
+
echo: EchoOptions = "output",
|
|
577
584
|
stream: bool = True,
|
|
578
585
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
579
586
|
) -> ChatResponseAsync:
|
|
@@ -608,7 +615,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
608
615
|
self._chat_impl_async(
|
|
609
616
|
turn,
|
|
610
617
|
echo=echo,
|
|
611
|
-
|
|
618
|
+
content="text",
|
|
612
619
|
stream=stream,
|
|
613
620
|
kwargs=kwargs,
|
|
614
621
|
),
|
|
@@ -620,12 +627,44 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
620
627
|
|
|
621
628
|
return response
|
|
622
629
|
|
|
630
|
+
@overload
|
|
631
|
+
def stream(
|
|
632
|
+
self,
|
|
633
|
+
*args: Content | str,
|
|
634
|
+
) -> Generator[str, None, None]: ...
|
|
635
|
+
|
|
636
|
+
@overload
|
|
637
|
+
def stream(
|
|
638
|
+
self,
|
|
639
|
+
*args: Content | str,
|
|
640
|
+
echo: EchoOptions,
|
|
641
|
+
) -> Generator[str, None, None]: ...
|
|
642
|
+
|
|
643
|
+
@overload
|
|
644
|
+
def stream(
|
|
645
|
+
self,
|
|
646
|
+
*args: Content | str,
|
|
647
|
+
echo: EchoOptions,
|
|
648
|
+
content: Literal["text"],
|
|
649
|
+
kwargs: Optional[SubmitInputArgsT],
|
|
650
|
+
) -> Generator[str, None, None]: ...
|
|
651
|
+
|
|
652
|
+
@overload
|
|
623
653
|
def stream(
|
|
624
654
|
self,
|
|
625
655
|
*args: Content | str,
|
|
626
|
-
echo:
|
|
656
|
+
echo: EchoOptions,
|
|
657
|
+
content: Literal["all"],
|
|
658
|
+
kwargs: Optional[SubmitInputArgsT],
|
|
659
|
+
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
|
|
660
|
+
|
|
661
|
+
def stream(
|
|
662
|
+
self,
|
|
663
|
+
*args: Content | str,
|
|
664
|
+
echo: EchoOptions = "none",
|
|
665
|
+
content: Literal["text", "all"] = "text",
|
|
627
666
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
628
|
-
) ->
|
|
667
|
+
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
|
|
629
668
|
"""
|
|
630
669
|
Generate a response from the chat in a streaming fashion.
|
|
631
670
|
|
|
@@ -636,6 +675,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
636
675
|
echo
|
|
637
676
|
Whether to echo text content, all content (i.e., tool calls), or no
|
|
638
677
|
content.
|
|
678
|
+
content
|
|
679
|
+
Whether to yield just text content, or all content (i.e., tool calls).
|
|
639
680
|
kwargs
|
|
640
681
|
Additional keyword arguments to pass to the method used for requesting
|
|
641
682
|
the response.
|
|
@@ -653,24 +694,58 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
653
694
|
generator = self._chat_impl(
|
|
654
695
|
turn,
|
|
655
696
|
stream=True,
|
|
656
|
-
display=display,
|
|
657
697
|
echo=echo,
|
|
698
|
+
content=content,
|
|
658
699
|
kwargs=kwargs,
|
|
659
700
|
)
|
|
660
701
|
|
|
661
|
-
def wrapper() -> Generator[
|
|
702
|
+
def wrapper() -> Generator[
|
|
703
|
+
str | ContentToolRequest | ContentToolResult, None, None
|
|
704
|
+
]:
|
|
662
705
|
with display:
|
|
663
706
|
for chunk in generator:
|
|
664
707
|
yield chunk
|
|
665
708
|
|
|
666
|
-
return
|
|
709
|
+
return wrapper()
|
|
710
|
+
|
|
711
|
+
@overload
|
|
712
|
+
async def stream_async(
|
|
713
|
+
self,
|
|
714
|
+
*args: Content | str,
|
|
715
|
+
) -> AsyncGenerator[str, None]: ...
|
|
716
|
+
|
|
717
|
+
@overload
|
|
718
|
+
async def stream_async(
|
|
719
|
+
self,
|
|
720
|
+
*args: Content | str,
|
|
721
|
+
echo: EchoOptions,
|
|
722
|
+
) -> AsyncGenerator[str, None]: ...
|
|
667
723
|
|
|
724
|
+
@overload
|
|
668
725
|
async def stream_async(
|
|
669
726
|
self,
|
|
670
727
|
*args: Content | str,
|
|
671
|
-
echo:
|
|
728
|
+
echo: EchoOptions,
|
|
729
|
+
content: Literal["text"],
|
|
730
|
+
kwargs: Optional[SubmitInputArgsT],
|
|
731
|
+
) -> AsyncGenerator[str, None]: ...
|
|
732
|
+
|
|
733
|
+
@overload
|
|
734
|
+
async def stream_async(
|
|
735
|
+
self,
|
|
736
|
+
*args: Content | str,
|
|
737
|
+
echo: EchoOptions,
|
|
738
|
+
content: Literal["all"],
|
|
739
|
+
kwargs: Optional[SubmitInputArgsT],
|
|
740
|
+
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
|
|
741
|
+
|
|
742
|
+
async def stream_async(
|
|
743
|
+
self,
|
|
744
|
+
*args: Content | str,
|
|
745
|
+
echo: EchoOptions = "none",
|
|
746
|
+
content: Literal["text", "all"] = "text",
|
|
672
747
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
673
|
-
) ->
|
|
748
|
+
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
|
|
674
749
|
"""
|
|
675
750
|
Generate a response from the chat in a streaming fashion asynchronously.
|
|
676
751
|
|
|
@@ -681,6 +756,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
681
756
|
echo
|
|
682
757
|
Whether to echo text content, all content (i.e., tool calls), or no
|
|
683
758
|
content.
|
|
759
|
+
content
|
|
760
|
+
Whether to yield just text content, or all content (i.e., tool calls).
|
|
684
761
|
kwargs
|
|
685
762
|
Additional keyword arguments to pass to the method used for requesting
|
|
686
763
|
the response.
|
|
@@ -695,24 +772,26 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
695
772
|
|
|
696
773
|
display = self._markdown_display(echo=echo)
|
|
697
774
|
|
|
698
|
-
async def wrapper() -> AsyncGenerator[
|
|
775
|
+
async def wrapper() -> AsyncGenerator[
|
|
776
|
+
str | ContentToolRequest | ContentToolResult, None
|
|
777
|
+
]:
|
|
699
778
|
with display:
|
|
700
779
|
async for chunk in self._chat_impl_async(
|
|
701
780
|
turn,
|
|
702
781
|
stream=True,
|
|
703
|
-
display=display,
|
|
704
782
|
echo=echo,
|
|
783
|
+
content=content,
|
|
705
784
|
kwargs=kwargs,
|
|
706
785
|
):
|
|
707
786
|
yield chunk
|
|
708
787
|
|
|
709
|
-
return
|
|
788
|
+
return wrapper()
|
|
710
789
|
|
|
711
790
|
def extract_data(
|
|
712
791
|
self,
|
|
713
792
|
*args: Content | str,
|
|
714
793
|
data_model: type[BaseModel],
|
|
715
|
-
echo:
|
|
794
|
+
echo: EchoOptions = "none",
|
|
716
795
|
stream: bool = False,
|
|
717
796
|
) -> dict[str, Any]:
|
|
718
797
|
"""
|
|
@@ -742,7 +821,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
742
821
|
user_turn(*args),
|
|
743
822
|
data_model=data_model,
|
|
744
823
|
echo=echo,
|
|
745
|
-
display=display,
|
|
746
824
|
stream=stream,
|
|
747
825
|
)
|
|
748
826
|
)
|
|
@@ -771,7 +849,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
771
849
|
self,
|
|
772
850
|
*args: Content | str,
|
|
773
851
|
data_model: type[BaseModel],
|
|
774
|
-
echo:
|
|
852
|
+
echo: EchoOptions = "none",
|
|
775
853
|
stream: bool = False,
|
|
776
854
|
) -> dict[str, Any]:
|
|
777
855
|
"""
|
|
@@ -802,7 +880,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
802
880
|
user_turn(*args),
|
|
803
881
|
data_model=data_model,
|
|
804
882
|
echo=echo,
|
|
805
|
-
display=display,
|
|
806
883
|
stream=stream,
|
|
807
884
|
)
|
|
808
885
|
)
|
|
@@ -910,13 +987,65 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
910
987
|
tool = Tool(func, model=model)
|
|
911
988
|
self._tools[tool.name] = tool
|
|
912
989
|
|
|
990
|
+
@property
|
|
991
|
+
def current_display(self) -> Optional[MarkdownDisplay]:
|
|
992
|
+
"""
|
|
993
|
+
Get the currently active markdown display, if any.
|
|
994
|
+
|
|
995
|
+
The display represents the place where `.chat(echo)` content is
|
|
996
|
+
being displayed. In a notebook/Quarto, this is a wrapper around
|
|
997
|
+
`IPython.display`. Otherwise, it is a wrapper around a
|
|
998
|
+
`rich.live.Live()` console.
|
|
999
|
+
|
|
1000
|
+
This is primarily useful if you want to add custom content to the
|
|
1001
|
+
display while the chat is running, but currently blocked by something
|
|
1002
|
+
like a tool call.
|
|
1003
|
+
|
|
1004
|
+
Example
|
|
1005
|
+
-------
|
|
1006
|
+
```python
|
|
1007
|
+
import requests
|
|
1008
|
+
from chatlas import ChatOpenAI
|
|
1009
|
+
|
|
1010
|
+
chat = ChatOpenAI()
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
def get_current_weather(latitude: float, longitude: float):
|
|
1014
|
+
"Get the current weather given a latitude and longitude."
|
|
1015
|
+
|
|
1016
|
+
lat_lng = f"latitude={latitude}&longitude={longitude}"
|
|
1017
|
+
url = f"https://api.open-meteo.com/v1/forecast?{lat_lng}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m"
|
|
1018
|
+
response = requests.get(url)
|
|
1019
|
+
json = response.json()
|
|
1020
|
+
if chat.current_display:
|
|
1021
|
+
chat.current_display.echo("My custom tool display!!!")
|
|
1022
|
+
return json["current"]
|
|
1023
|
+
|
|
1024
|
+
|
|
1025
|
+
chat.register_tool(get_current_weather)
|
|
1026
|
+
|
|
1027
|
+
chat.chat("What's the current temperature in Duluth, MN?", echo="text")
|
|
1028
|
+
```
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
Returns
|
|
1032
|
+
-------
|
|
1033
|
+
Optional[MarkdownDisplay]
|
|
1034
|
+
The currently active markdown display, if any.
|
|
1035
|
+
"""
|
|
1036
|
+
return self._current_display
|
|
1037
|
+
|
|
1038
|
+
def _echo_content(self, x: str):
|
|
1039
|
+
if self._current_display:
|
|
1040
|
+
self._current_display.echo(x)
|
|
1041
|
+
|
|
913
1042
|
def export(
|
|
914
1043
|
self,
|
|
915
1044
|
filename: str | Path,
|
|
916
1045
|
*,
|
|
917
1046
|
turns: Optional[Sequence[Turn]] = None,
|
|
918
1047
|
title: Optional[str] = None,
|
|
919
|
-
|
|
1048
|
+
content: Literal["text", "all"] = "text",
|
|
920
1049
|
include_system_prompt: bool = True,
|
|
921
1050
|
overwrite: bool = False,
|
|
922
1051
|
):
|
|
@@ -935,7 +1064,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
935
1064
|
A title to place at the top of the exported file.
|
|
936
1065
|
overwrite
|
|
937
1066
|
Whether to overwrite the file if it already exists.
|
|
938
|
-
|
|
1067
|
+
content
|
|
939
1068
|
Whether to include text content, all content (i.e., tool calls), or no
|
|
940
1069
|
content.
|
|
941
1070
|
include_system_prompt
|
|
@@ -971,9 +1100,9 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
971
1100
|
for turn in turns:
|
|
972
1101
|
turn_content = "\n\n".join(
|
|
973
1102
|
[
|
|
974
|
-
str(
|
|
975
|
-
for
|
|
976
|
-
if
|
|
1103
|
+
str(x).strip()
|
|
1104
|
+
for x in turn.contents
|
|
1105
|
+
if content == "all" or isinstance(x, ContentText)
|
|
977
1106
|
]
|
|
978
1107
|
)
|
|
979
1108
|
if is_html:
|
|
@@ -1033,51 +1162,130 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1033
1162
|
</html>
|
|
1034
1163
|
"""
|
|
1035
1164
|
|
|
1165
|
+
@overload
|
|
1036
1166
|
def _chat_impl(
|
|
1037
1167
|
self,
|
|
1038
1168
|
user_turn: Turn,
|
|
1039
|
-
echo:
|
|
1040
|
-
|
|
1169
|
+
echo: EchoOptions,
|
|
1170
|
+
content: Literal["text"],
|
|
1041
1171
|
stream: bool,
|
|
1042
1172
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1043
|
-
) -> Generator[str, None, None]:
|
|
1173
|
+
) -> Generator[str, None, None]: ...
|
|
1174
|
+
|
|
1175
|
+
@overload
|
|
1176
|
+
def _chat_impl(
|
|
1177
|
+
self,
|
|
1178
|
+
user_turn: Turn,
|
|
1179
|
+
echo: EchoOptions,
|
|
1180
|
+
content: Literal["all"],
|
|
1181
|
+
stream: bool,
|
|
1182
|
+
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1183
|
+
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
|
|
1184
|
+
|
|
1185
|
+
def _chat_impl(
|
|
1186
|
+
self,
|
|
1187
|
+
user_turn: Turn,
|
|
1188
|
+
echo: EchoOptions,
|
|
1189
|
+
content: Literal["text", "all"],
|
|
1190
|
+
stream: bool,
|
|
1191
|
+
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1192
|
+
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
|
|
1044
1193
|
user_turn_result: Turn | None = user_turn
|
|
1045
1194
|
while user_turn_result is not None:
|
|
1046
1195
|
for chunk in self._submit_turns(
|
|
1047
1196
|
user_turn_result,
|
|
1048
1197
|
echo=echo,
|
|
1049
|
-
display=display,
|
|
1050
1198
|
stream=stream,
|
|
1051
1199
|
kwargs=kwargs,
|
|
1052
1200
|
):
|
|
1053
1201
|
yield chunk
|
|
1054
|
-
|
|
1202
|
+
|
|
1203
|
+
turn = self.get_last_turn(role="assistant")
|
|
1204
|
+
assert turn is not None
|
|
1205
|
+
user_turn_result = None
|
|
1206
|
+
|
|
1207
|
+
results: list[ContentToolResult] = []
|
|
1208
|
+
for x in turn.contents:
|
|
1209
|
+
if isinstance(x, ContentToolRequest):
|
|
1210
|
+
if echo == "output":
|
|
1211
|
+
self._echo_content(f"\n\n{x}\n\n")
|
|
1212
|
+
if content == "all":
|
|
1213
|
+
yield x
|
|
1214
|
+
res = self._invoke_tool(x)
|
|
1215
|
+
if echo == "output":
|
|
1216
|
+
self._echo_content(f"\n\n{res}\n\n")
|
|
1217
|
+
if content == "all":
|
|
1218
|
+
yield res
|
|
1219
|
+
results.append(res)
|
|
1220
|
+
|
|
1221
|
+
if results:
|
|
1222
|
+
user_turn_result = Turn("user", results)
|
|
1223
|
+
|
|
1224
|
+
@overload
|
|
1225
|
+
def _chat_impl_async(
|
|
1226
|
+
self,
|
|
1227
|
+
user_turn: Turn,
|
|
1228
|
+
echo: EchoOptions,
|
|
1229
|
+
content: Literal["text"],
|
|
1230
|
+
stream: bool,
|
|
1231
|
+
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1232
|
+
) -> AsyncGenerator[str, None]: ...
|
|
1233
|
+
|
|
1234
|
+
@overload
|
|
1235
|
+
def _chat_impl_async(
|
|
1236
|
+
self,
|
|
1237
|
+
user_turn: Turn,
|
|
1238
|
+
echo: EchoOptions,
|
|
1239
|
+
content: Literal["all"],
|
|
1240
|
+
stream: bool,
|
|
1241
|
+
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1242
|
+
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
|
|
1055
1243
|
|
|
1056
1244
|
async def _chat_impl_async(
|
|
1057
1245
|
self,
|
|
1058
1246
|
user_turn: Turn,
|
|
1059
|
-
echo:
|
|
1060
|
-
|
|
1247
|
+
echo: EchoOptions,
|
|
1248
|
+
content: Literal["text", "all"],
|
|
1061
1249
|
stream: bool,
|
|
1062
1250
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1063
|
-
) -> AsyncGenerator[str, None]:
|
|
1251
|
+
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
|
|
1064
1252
|
user_turn_result: Turn | None = user_turn
|
|
1065
1253
|
while user_turn_result is not None:
|
|
1066
1254
|
async for chunk in self._submit_turns_async(
|
|
1067
1255
|
user_turn_result,
|
|
1068
1256
|
echo=echo,
|
|
1069
|
-
display=display,
|
|
1070
1257
|
stream=stream,
|
|
1071
1258
|
kwargs=kwargs,
|
|
1072
1259
|
):
|
|
1073
1260
|
yield chunk
|
|
1074
|
-
|
|
1261
|
+
|
|
1262
|
+
turn = self.get_last_turn(role="assistant")
|
|
1263
|
+
assert turn is not None
|
|
1264
|
+
user_turn_result = None
|
|
1265
|
+
|
|
1266
|
+
results: list[ContentToolResult] = []
|
|
1267
|
+
for x in turn.contents:
|
|
1268
|
+
if isinstance(x, ContentToolRequest):
|
|
1269
|
+
if echo == "output":
|
|
1270
|
+
self._echo_content(f"\n\n{x}\n\n")
|
|
1271
|
+
if content == "all":
|
|
1272
|
+
yield x
|
|
1273
|
+
res = await self._invoke_tool_async(x)
|
|
1274
|
+
if echo == "output":
|
|
1275
|
+
self._echo_content(f"\n\n{res}\n\n")
|
|
1276
|
+
if content == "all":
|
|
1277
|
+
yield res
|
|
1278
|
+
else:
|
|
1279
|
+
yield "\n\n"
|
|
1280
|
+
results.append(res)
|
|
1281
|
+
|
|
1282
|
+
if results:
|
|
1283
|
+
user_turn_result = Turn("user", results)
|
|
1075
1284
|
|
|
1076
1285
|
def _submit_turns(
|
|
1077
1286
|
self,
|
|
1078
1287
|
user_turn: Turn,
|
|
1079
|
-
echo:
|
|
1080
|
-
display: MarkdownDisplay,
|
|
1288
|
+
echo: EchoOptions,
|
|
1081
1289
|
stream: bool,
|
|
1082
1290
|
data_model: type[BaseModel] | None = None,
|
|
1083
1291
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
@@ -1086,7 +1294,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1086
1294
|
raise ValueError("Cannot use async tools in a synchronous chat")
|
|
1087
1295
|
|
|
1088
1296
|
def emit(text: str | Content):
|
|
1089
|
-
|
|
1297
|
+
self._echo_content(str(text))
|
|
1090
1298
|
|
|
1091
1299
|
emit("<br>\n\n")
|
|
1092
1300
|
|
|
@@ -1142,14 +1350,13 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1142
1350
|
async def _submit_turns_async(
|
|
1143
1351
|
self,
|
|
1144
1352
|
user_turn: Turn,
|
|
1145
|
-
echo:
|
|
1146
|
-
display: MarkdownDisplay,
|
|
1353
|
+
echo: EchoOptions,
|
|
1147
1354
|
stream: bool,
|
|
1148
1355
|
data_model: type[BaseModel] | None = None,
|
|
1149
1356
|
kwargs: Optional[SubmitInputArgsT] = None,
|
|
1150
1357
|
) -> AsyncGenerator[str, None]:
|
|
1151
1358
|
def emit(text: str | Content):
|
|
1152
|
-
|
|
1359
|
+
self._echo_content(str(text))
|
|
1153
1360
|
|
|
1154
1361
|
emit("<br>\n\n")
|
|
1155
1362
|
|
|
@@ -1202,92 +1409,74 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1202
1409
|
|
|
1203
1410
|
self._turns.extend([user_turn, turn])
|
|
1204
1411
|
|
|
1205
|
-
def
|
|
1206
|
-
|
|
1207
|
-
if
|
|
1208
|
-
return None
|
|
1209
|
-
|
|
1210
|
-
results: list[ContentToolResult] = []
|
|
1211
|
-
for x in turn.contents:
|
|
1212
|
-
if isinstance(x, ContentToolRequest):
|
|
1213
|
-
tool_def = self._tools.get(x.name, None)
|
|
1214
|
-
func = tool_def.func if tool_def is not None else None
|
|
1215
|
-
results.append(self._invoke_tool(func, x.arguments, x.id))
|
|
1216
|
-
|
|
1217
|
-
if not results:
|
|
1218
|
-
return None
|
|
1219
|
-
|
|
1220
|
-
return Turn("user", results)
|
|
1221
|
-
|
|
1222
|
-
async def _invoke_tools_async(self) -> Turn | None:
|
|
1223
|
-
turn = self.get_last_turn()
|
|
1224
|
-
if turn is None:
|
|
1225
|
-
return None
|
|
1226
|
-
|
|
1227
|
-
results: list[ContentToolResult] = []
|
|
1228
|
-
for x in turn.contents:
|
|
1229
|
-
if isinstance(x, ContentToolRequest):
|
|
1230
|
-
tool_def = self._tools.get(x.name, None)
|
|
1231
|
-
func = None
|
|
1232
|
-
if tool_def:
|
|
1233
|
-
if tool_def._is_async:
|
|
1234
|
-
func = tool_def.func
|
|
1235
|
-
else:
|
|
1236
|
-
func = wrap_async(tool_def.func)
|
|
1237
|
-
results.append(await self._invoke_tool_async(func, x.arguments, x.id))
|
|
1238
|
-
|
|
1239
|
-
if not results:
|
|
1240
|
-
return None
|
|
1412
|
+
def _invoke_tool(self, x: ContentToolRequest) -> ContentToolResult:
|
|
1413
|
+
tool_def = self._tools.get(x.name, None)
|
|
1414
|
+
func = tool_def.func if tool_def is not None else None
|
|
1241
1415
|
|
|
1242
|
-
return Turn("user", results)
|
|
1243
|
-
|
|
1244
|
-
@staticmethod
|
|
1245
|
-
def _invoke_tool(
|
|
1246
|
-
func: Callable[..., Any] | None,
|
|
1247
|
-
arguments: object,
|
|
1248
|
-
id_: str,
|
|
1249
|
-
) -> ContentToolResult:
|
|
1250
1416
|
if func is None:
|
|
1251
|
-
|
|
1417
|
+
e = RuntimeError(f"Unknown tool: {x.name}")
|
|
1418
|
+
return ContentToolResult(value=None, error=e, request=x)
|
|
1252
1419
|
|
|
1253
|
-
|
|
1420
|
+
args = x.arguments
|
|
1254
1421
|
|
|
1255
1422
|
try:
|
|
1256
|
-
if isinstance(
|
|
1257
|
-
result = func(**
|
|
1423
|
+
if isinstance(args, dict):
|
|
1424
|
+
result = func(**args)
|
|
1258
1425
|
else:
|
|
1259
|
-
result = func(
|
|
1426
|
+
result = func(args)
|
|
1260
1427
|
|
|
1261
|
-
|
|
1428
|
+
if not isinstance(result, ContentToolResult):
|
|
1429
|
+
result = ContentToolResult(value=result)
|
|
1430
|
+
|
|
1431
|
+
result.request = x
|
|
1432
|
+
return result
|
|
1262
1433
|
except Exception as e:
|
|
1263
|
-
|
|
1264
|
-
|
|
1434
|
+
warnings.warn(
|
|
1435
|
+
f"Calling tool '{x.name}' led to an error.",
|
|
1436
|
+
ToolFailureWarning,
|
|
1437
|
+
stacklevel=2,
|
|
1438
|
+
)
|
|
1439
|
+
traceback.print_exc()
|
|
1440
|
+
log_tool_error(x.name, str(args), e)
|
|
1441
|
+
return ContentToolResult(value=None, error=e, request=x)
|
|
1442
|
+
|
|
1443
|
+
async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
|
|
1444
|
+
tool_def = self._tools.get(x.name, None)
|
|
1445
|
+
func = None
|
|
1446
|
+
if tool_def:
|
|
1447
|
+
if tool_def._is_async:
|
|
1448
|
+
func = tool_def.func
|
|
1449
|
+
else:
|
|
1450
|
+
func = wrap_async(tool_def.func)
|
|
1265
1451
|
|
|
1266
|
-
@staticmethod
|
|
1267
|
-
async def _invoke_tool_async(
|
|
1268
|
-
func: Callable[..., Awaitable[Any]] | None,
|
|
1269
|
-
arguments: object,
|
|
1270
|
-
id_: str,
|
|
1271
|
-
) -> ContentToolResult:
|
|
1272
1452
|
if func is None:
|
|
1273
|
-
|
|
1453
|
+
e = RuntimeError(f"Unknown tool: {x.name}")
|
|
1454
|
+
return ContentToolResult(value=None, error=e, request=x)
|
|
1274
1455
|
|
|
1275
|
-
|
|
1456
|
+
args = x.arguments
|
|
1276
1457
|
|
|
1277
1458
|
try:
|
|
1278
|
-
if isinstance(
|
|
1279
|
-
result = await func(**
|
|
1459
|
+
if isinstance(args, dict):
|
|
1460
|
+
result = await func(**args)
|
|
1280
1461
|
else:
|
|
1281
|
-
result = await func(
|
|
1462
|
+
result = await func(args)
|
|
1282
1463
|
|
|
1283
|
-
|
|
1464
|
+
if not isinstance(result, ContentToolResult):
|
|
1465
|
+
result = ContentToolResult(value=result)
|
|
1466
|
+
|
|
1467
|
+
result.request = x
|
|
1468
|
+
return result
|
|
1284
1469
|
except Exception as e:
|
|
1285
|
-
|
|
1286
|
-
|
|
1470
|
+
warnings.warn(
|
|
1471
|
+
f"Calling tool '{x.name}' led to an error.",
|
|
1472
|
+
ToolFailureWarning,
|
|
1473
|
+
stacklevel=2,
|
|
1474
|
+
)
|
|
1475
|
+
traceback.print_exc()
|
|
1476
|
+
log_tool_error(x.name, str(args), e)
|
|
1477
|
+
return ContentToolResult(value=None, error=e, request=x)
|
|
1287
1478
|
|
|
1288
|
-
def _markdown_display(
|
|
1289
|
-
self, echo: Literal["text", "all", "none"]
|
|
1290
|
-
) -> MarkdownDisplay:
|
|
1479
|
+
def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
|
|
1291
1480
|
"""
|
|
1292
1481
|
Get a markdown display object based on the echo option.
|
|
1293
1482
|
|
|
@@ -1296,7 +1485,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1296
1485
|
screen sizes.
|
|
1297
1486
|
"""
|
|
1298
1487
|
if echo == "none":
|
|
1299
|
-
return MockMarkdownDisplay()
|
|
1488
|
+
return ChatMarkdownDisplay(MockMarkdownDisplay(), self)
|
|
1300
1489
|
|
|
1301
1490
|
# rich does a lot to detect a notebook environment, but it doesn't
|
|
1302
1491
|
# detect Quarto (at least not yet).
|
|
@@ -1305,10 +1494,13 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1305
1494
|
is_web = Console().is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None
|
|
1306
1495
|
|
|
1307
1496
|
opts = self._echo_options
|
|
1497
|
+
|
|
1308
1498
|
if is_web:
|
|
1309
|
-
|
|
1499
|
+
display = IPyMarkdownDisplay(opts)
|
|
1310
1500
|
else:
|
|
1311
|
-
|
|
1501
|
+
display = LiveMarkdownDisplay(opts)
|
|
1502
|
+
|
|
1503
|
+
return ChatMarkdownDisplay(display, self)
|
|
1312
1504
|
|
|
1313
1505
|
def set_echo_options(
|
|
1314
1506
|
self,
|
|
@@ -1331,7 +1523,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
|
|
|
1331
1523
|
A dictionary of CSS styles to apply to `IPython.display.Markdown()`.
|
|
1332
1524
|
This is only relevant when outputing to the browser.
|
|
1333
1525
|
"""
|
|
1334
|
-
self._echo_options:
|
|
1526
|
+
self._echo_options: EchoDisplayOptions = {
|
|
1335
1527
|
"rich_markdown": rich_markdown or {},
|
|
1336
1528
|
"rich_console": rich_console or {},
|
|
1337
1529
|
"css_styles": css_styles or {},
|
|
@@ -1501,3 +1693,30 @@ def emit_other_contents(
|
|
|
1501
1693
|
to_emit.reverse()
|
|
1502
1694
|
|
|
1503
1695
|
emit("\n\n".join(to_emit))
|
|
1696
|
+
|
|
1697
|
+
|
|
1698
|
+
# Helper/wrapper class to let Chat know about the currently active display
|
|
1699
|
+
class ChatMarkdownDisplay:
|
|
1700
|
+
def __init__(self, display: MarkdownDisplay, chat: Chat):
|
|
1701
|
+
self._display = display
|
|
1702
|
+
self._chat = chat
|
|
1703
|
+
|
|
1704
|
+
def __enter__(self):
|
|
1705
|
+
self._chat._current_display = self._display
|
|
1706
|
+
return self._display.__enter__()
|
|
1707
|
+
|
|
1708
|
+
def __exit__(self, *args, **kwargs):
|
|
1709
|
+
result = self._display.__exit__(*args, **kwargs)
|
|
1710
|
+
self._chat._current_display = None
|
|
1711
|
+
return result
|
|
1712
|
+
|
|
1713
|
+
def append(self, content):
|
|
1714
|
+
return self._display.echo(content)
|
|
1715
|
+
|
|
1716
|
+
|
|
1717
|
+
class ToolFailureWarning(RuntimeWarning):
|
|
1718
|
+
pass
|
|
1719
|
+
|
|
1720
|
+
|
|
1721
|
+
# By default warnings are shown once; we want to always show them.
|
|
1722
|
+
warnings.simplefilter("always", ToolFailureWarning)
|