chatlas 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +38 -0
- chatlas/_anthropic.py +643 -0
- chatlas/_chat.py +1279 -0
- chatlas/_content.py +242 -0
- chatlas/_content_image.py +272 -0
- chatlas/_display.py +139 -0
- chatlas/_github.py +147 -0
- chatlas/_google.py +456 -0
- chatlas/_groq.py +143 -0
- chatlas/_interpolate.py +133 -0
- chatlas/_logging.py +61 -0
- chatlas/_merge.py +103 -0
- chatlas/_ollama.py +125 -0
- chatlas/_openai.py +654 -0
- chatlas/_perplexity.py +148 -0
- chatlas/_provider.py +143 -0
- chatlas/_tokens.py +87 -0
- chatlas/_tokens_old.py +148 -0
- chatlas/_tools.py +134 -0
- chatlas/_turn.py +147 -0
- chatlas/_typing_extensions.py +26 -0
- chatlas/_utils.py +106 -0
- chatlas/types/__init__.py +32 -0
- chatlas/types/anthropic/__init__.py +14 -0
- chatlas/types/anthropic/_client.py +29 -0
- chatlas/types/anthropic/_client_bedrock.py +23 -0
- chatlas/types/anthropic/_submit.py +57 -0
- chatlas/types/google/__init__.py +12 -0
- chatlas/types/google/_client.py +101 -0
- chatlas/types/google/_submit.py +113 -0
- chatlas/types/openai/__init__.py +14 -0
- chatlas/types/openai/_client.py +22 -0
- chatlas/types/openai/_client_azure.py +25 -0
- chatlas/types/openai/_submit.py +135 -0
- chatlas-0.2.0.dist-info/METADATA +319 -0
- chatlas-0.2.0.dist-info/RECORD +37 -0
- chatlas-0.2.0.dist-info/WHEEL +4 -0
chatlas/_chat.py
ADDED
@@ -0,0 +1,1279 @@

````python
from __future__ import annotations

import os
from pathlib import Path
from threading import Thread
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Awaitable,
    Callable,
    Generator,
    Generic,
    Iterator,
    Literal,
    Optional,
    Sequence,
    TypeVar,
)

from pydantic import BaseModel

from ._content import (
    Content,
    ContentJson,
    ContentText,
    ContentToolRequest,
    ContentToolResult,
)
from ._display import (
    EchoOptions,
    IPyMarkdownDisplay,
    LiveMarkdownDisplay,
    MarkdownDisplay,
    MockMarkdownDisplay,
)
from ._logging import log_tool_error
from ._provider import Provider
from ._tools import Tool
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict
from ._utils import html_escape


class AnyTypeDict(TypedDict, total=False):
    pass


SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
"""
A TypedDict representing the arguments that can be passed to the `.chat()`
method of a [](`~chatlas.Chat`) instance.
"""

CompletionT = TypeVar("CompletionT")


class Chat(Generic[SubmitInputArgsT, CompletionT]):
    """
    A chat object that can be used to interact with a language model.

    A `Chat` is a sequence of user and assistant
    [](`~chatlas.Turn`)s sent to a specific [](`~chatlas.Provider`). A `Chat`
    takes care of managing the state associated with the chat; i.e. it records
    the messages that you send to the server, and the messages that you receive
    back. If you register a tool (i.e. a function that the assistant can call
    on your behalf), it also takes care of the tool loop.

    You should generally not create this object yourself, but instead call
    [](`~chatlas.ChatOpenAI`) or friends.
    """

    def __init__(
        self,
        provider: Provider,
        turns: Optional[Sequence[Turn]] = None,
    ):
        """
        Create a new chat object.

        Parameters
        ----------
        provider
            A [](`~chatlas.Provider`) object.
        turns
            A list of [](`~chatlas.Turn`) objects to initialize the chat with.
        """
        self.provider = provider
        self._turns: list[Turn] = list(turns or [])
        self._tools: dict[str, Tool] = {}
        self._echo_options: EchoOptions = {
            "rich_markdown": {},
            "rich_console": {},
            "css_styles": {},
        }
````
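
For orientation, here is a minimal usage sketch of the class defined above. It is not part of the package; it assumes `chatlas` is installed with an OpenAI API key configured in the environment, and uses `ChatOpenAI()` (the documented entry point) rather than constructing `Chat` directly:

```python
from chatlas import ChatOpenAI

# ChatOpenAI() builds a Chat wired to an OpenAI Provider; per the docstring
# above, you generally don't call Chat(provider=...) yourself.
chat = ChatOpenAI()
chat.system_prompt = "You are a terse assistant."  # stored as the first turn
chat.chat("What is the capital of France?")
```

The listing continues with the state-management methods: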
````python
    def get_turns(
        self,
        *,
        include_system_prompt: bool = False,
    ) -> list[Turn[CompletionT]]:
        """
        Get all the turns (i.e., message contents) in the chat.

        Parameters
        ----------
        include_system_prompt
            Whether to include the system prompt in the turns.
        """

        if not self._turns:
            return self._turns

        if not include_system_prompt and self._turns[0].role == "system":
            return self._turns[1:]
        return self._turns

    def get_last_turn(
        self,
        *,
        role: Literal["assistant", "user", "system"] = "assistant",
    ) -> Turn[CompletionT] | None:
        """
        Get the last turn in the chat with a specific role.

        Parameters
        ----------
        role
            The role of the turn to return.
        """
        for turn in reversed(self._turns):
            if turn.role == role:
                return turn
        return None

    def set_turns(self, turns: Sequence[Turn]):
        """
        Set the turns of the chat.

        This method is primarily useful for clearing or truncating the turns
        of the chat (e.g., to limit the context window).

        Parameters
        ----------
        turns
            The turns to set. Turns with the role "system" are not allowed.
        """
        if any(x.role == "system" for x in turns):
            idx = next(i for i, x in enumerate(turns) if x.role == "system")
            raise ValueError(
                f"Turn {idx} has a role 'system', which is not allowed. "
                "The system prompt must be set separately using the `.system_prompt` property. "
                "Consider removing this turn and setting the `.system_prompt` separately "
                "if you want to change the system prompt."
            )
        self._turns = list(turns)

    @property
    def system_prompt(self) -> str | None:
        """
        A property to get (or set) the system prompt for the chat.

        Returns
        -------
        str | None
            The system prompt (if any).
        """
        if self._turns and self._turns[0].role == "system":
            return self._turns[0].text
        return None

    @system_prompt.setter
    def system_prompt(self, value: str | None):
        if self._turns and self._turns[0].role == "system":
            self._turns.pop(0)
        if value is not None:
            self._turns.insert(0, Turn("system", value))

    def tokens(self) -> list[tuple[int, int] | None]:
        """
        Get the tokens for each turn in the chat.

        Returns
        -------
        list[tuple[int, int] | None]
            A list of tuples, where each tuple contains the input and output
            token counts for a turn.
        """
        return [turn.tokens for turn in self._turns]
````
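
A short sketch (not from the package) of the state-management API above, assuming `chat` is the instance from the earlier example and has accumulated a few turns:

```python
turns = chat.get_turns()     # system prompt excluded by default
last = chat.get_last_turn()  # most recent assistant turn, or None
print(chat.tokens())         # per-turn (input, output) token counts

# Truncate the history (and hence the context window) to the last two turns;
# set_turns() rejects any turn with role "system".
chat.set_turns(turns[-2:])
```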
````python
    def app(
        self,
        *,
        stream: bool = True,
        port: int = 0,
        launch_browser: bool = True,
        bg_thread: Optional[bool] = None,
        kwargs: Optional[SubmitInputArgsT] = None,
    ):
        """
        Enter a web-based chat app to interact with the LLM.

        Parameters
        ----------
        stream
            Whether to stream the response (i.e., have the response appear in chunks).
        port
            The port to run the app on (the default is 0, which will choose a random port).
        launch_browser
            Whether to launch a browser window.
        bg_thread
            Whether to run the app in a background thread. If `None`, the app will
            run in a background thread if the current environment is a notebook.
        kwargs
            Additional keyword arguments to pass to the method used for requesting
            the response.
        """

        try:
            from shiny import App, run_app, ui
        except ImportError:
            raise ImportError(
                "The `shiny` package is required for the `app` method. "
                "Install it with `pip install shiny`."
            )

        app_ui = ui.page_fillable(
            ui.chat_ui("chat"),
            fillable_mobile=True,
        )

        def server(input):  # noqa: A002
            chat = ui.Chat(
                "chat",
                messages=[
                    {"role": turn.role, "content": turn.text}
                    for turn in self.get_turns()
                ],
            )

            @chat.on_user_submit
            async def _():
                user_input = chat.user_input()
                if user_input is None:
                    return
                if stream:
                    await chat.append_message_stream(
                        self.stream(user_input, kwargs=kwargs)
                    )
                else:
                    await chat.append_message(str(self.chat(user_input, kwargs=kwargs)))

        app = App(app_ui, server)

        def _run_app():
            run_app(app, launch_browser=launch_browser, port=port)

        # Use bg_thread by default in Jupyter and Positron
        if bg_thread is None:
            from rich.console import Console

            console = Console()
            bg_thread = console.is_jupyter or (os.getenv("POSITRON") == "1")

        if bg_thread:
            thread = Thread(target=_run_app, daemon=True)
            thread.start()
        else:
            _run_app()

        return None

    def console(
        self,
        *,
        echo: Literal["text", "all", "none"] = "text",
        stream: bool = True,
        kwargs: Optional[SubmitInputArgsT] = None,
    ):
        """
        Enter a chat console to interact with the LLM.

        To quit, input 'exit' or press Ctrl+C.

        Parameters
        ----------
        echo
            Whether to echo text content, all content (i.e., tool calls), or no
            content.
        stream
            Whether to stream the response (i.e., have the response appear in chunks).
        kwargs
            Additional keyword arguments to pass to the method used for requesting
            the response.

        Returns
        -------
        None
        """

        print("\nEntering chat console. To quit, input 'exit' or press Ctrl+C.\n")

        while True:
            user_input = input("?> ")
            if user_input.strip().lower() in ("exit", "exit()"):
                break
            print("")
            self.chat(user_input, echo=echo, stream=stream, kwargs=kwargs)
            print("")
````
````python
    def chat(
        self,
        *args: Content | str,
        echo: Literal["text", "all", "none"] = "text",
        stream: bool = True,
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> ChatResponse:
        """
        Generate a response from the chat.

        Parameters
        ----------
        args
            The user input(s) to generate a response from.
        echo
            Whether to echo text content, all content (i.e., tool calls), or no
            content.
        stream
            Whether to stream the response (i.e., have the response appear in
            chunks).
        kwargs
            Additional keyword arguments to pass to the method used for
            requesting the response.

        Returns
        -------
        ChatResponse
            A (consumed) response from the chat. Apply `str()` to this object to
            get the text content of the response.
        """
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)

        response = ChatResponse(
            self._chat_impl(
                turn,
                echo=echo,
                display=display,
                stream=stream,
                kwargs=kwargs,
            )
        )

        with display:
            for _ in response:
                pass

        return response

    async def chat_async(
        self,
        *args: Content | str,
        echo: Literal["text", "all", "none"] = "text",
        stream: bool = True,
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> ChatResponseAsync:
        """
        Generate a response from the chat asynchronously.

        Parameters
        ----------
        args
            The user input(s) to generate a response from.
        echo
            Whether to echo text content, all content (i.e., tool calls, images,
            etc.), or no content.
        stream
            Whether to stream the response (i.e., have the response appear in
            chunks).
        kwargs
            Additional keyword arguments to pass to the method used for
            requesting the response.

        Returns
        -------
        ChatResponseAsync
            A (consumed) response from the chat. Apply `str()` to this object to
            get the text content of the response.
        """
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)

        response = ChatResponseAsync(
            self._chat_impl_async(
                turn,
                echo=echo,
                display=display,
                stream=stream,
                kwargs=kwargs,
            ),
        )

        with display:
            async for _ in response:
                pass

        return response
````
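
The synchronous and asynchronous entry points behave the same way: both fully consume the response before returning it, so `str()` on the returned object yields the accumulated text. A sketch (again assuming an OpenAI key is configured):

```python
import asyncio

from chatlas import ChatOpenAI

chat = ChatOpenAI()

# echo="none" suppresses live display, so the print below shows the text once.
reply = chat.chat("Briefly, what is a monad?", echo="none")
print(str(reply))


# The async variant, e.g. for use inside an existing event loop.
async def main():
    reply = await chat.chat_async("And a functor?", echo="none")
    print(str(reply))


asyncio.run(main())
```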
````python
    def stream(
        self,
        *args: Content | str,
        echo: Literal["text", "all", "none"] = "none",
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> ChatResponse:
        """
        Generate a response from the chat in a streaming fashion.

        Parameters
        ----------
        args
            The user input(s) to generate a response from.
        echo
            Whether to echo text content, all content (i.e., tool calls), or no
            content.
        kwargs
            Additional keyword arguments to pass to the method used for requesting
            the response.

        Returns
        -------
        ChatResponse
            An (unconsumed) response from the chat. Iterate over this object to
            consume the response.
        """
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)

        generator = self._chat_impl(
            turn,
            stream=True,
            display=display,
            echo=echo,
            kwargs=kwargs,
        )

        def wrapper() -> Generator[str, None, None]:
            with display:
                for chunk in generator:
                    yield chunk

        return ChatResponse(wrapper())

    async def stream_async(
        self,
        *args: Content | str,
        echo: Literal["text", "all", "none"] = "none",
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> ChatResponseAsync:
        """
        Generate a response from the chat in a streaming fashion asynchronously.

        Parameters
        ----------
        args
            The user input(s) to generate a response from.
        echo
            Whether to echo text content, all content (i.e., tool calls), or no
            content.
        kwargs
            Additional keyword arguments to pass to the method used for requesting
            the response.

        Returns
        -------
        ChatResponseAsync
            An (unconsumed) response from the chat. Iterate over this object to
            consume the response.
        """
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)

        async def wrapper() -> AsyncGenerator[str, None]:
            with display:
                async for chunk in self._chat_impl_async(
                    turn,
                    stream=True,
                    display=display,
                    echo=echo,
                    kwargs=kwargs,
                ):
                    yield chunk

        return ChatResponseAsync(wrapper())
````
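
Unlike `.chat()`, the streaming variants return an *unconsumed* response: nothing is requested until you start iterating, and each chunk arrives as plain text. A sketch:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()

response = chat.stream("Write a haiku about the sea.")
for chunk in response:
    print(chunk, end="", flush=True)
print()
print(response.consumed)  # True once the generator is exhausted
```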
````python
    def extract_data(
        self,
        *args: Content | str,
        data_model: type[BaseModel],
        echo: Literal["text", "all", "none"] = "none",
        stream: bool = False,
    ) -> dict[str, Any]:
        """
        Extract structured data from the given input.

        Parameters
        ----------
        args
            The input to extract data from.
        data_model
            A Pydantic model describing the structure of the data to extract.
        echo
            Whether to echo text content, all content (i.e., tool calls), or no content.
        stream
            Whether to stream the response (i.e., have the response appear in chunks).

        Returns
        -------
        dict[str, Any]
            The extracted data.
        """

        display = self._markdown_display(echo=echo)

        response = ChatResponse(
            self._submit_turns(
                user_turn(*args),
                data_model=data_model,
                echo=echo,
                display=display,
                stream=stream,
            )
        )

        with display:
            for _ in response:
                pass

        turn = self.get_last_turn()
        assert turn is not None

        res: list[ContentJson] = []
        for x in turn.contents:
            if isinstance(x, ContentJson):
                res.append(x)

        if len(res) != 1:
            raise ValueError(
                f"Data extraction failed: {len(res)} data results received."
            )

        json = res[0]
        return json.value

    async def extract_data_async(
        self,
        *args: Content | str,
        data_model: type[BaseModel],
        echo: Literal["text", "all", "none"] = "none",
        stream: bool = False,
    ) -> dict[str, Any]:
        """
        Extract structured data from the given input asynchronously.

        Parameters
        ----------
        args
            The input to extract data from.
        data_model
            A Pydantic model describing the structure of the data to extract.
        echo
            Whether to echo text content, all content (i.e., tool calls), or no content.
        stream
            Whether to stream the response (i.e., have the response appear in chunks).

        Returns
        -------
        dict[str, Any]
            The extracted data.
        """

        display = self._markdown_display(echo=echo)

        response = ChatResponseAsync(
            self._submit_turns_async(
                user_turn(*args),
                data_model=data_model,
                echo=echo,
                display=display,
                stream=stream,
            )
        )

        with display:
            async for _ in response:
                pass

        turn = self.get_last_turn()
        assert turn is not None

        res: list[ContentJson] = []
        for x in turn.contents:
            if isinstance(x, ContentJson):
                res.append(x)

        if len(res) != 1:
            raise ValueError(
                f"Data extraction failed: {len(res)} data results received."
            )

        json = res[0]
        return json.value
````
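
A sketch of structured data extraction; the `Person` model here is illustrative, not part of the package:

```python
from pydantic import BaseModel, Field

from chatlas import ChatOpenAI


class Person(BaseModel):
    """A person mentioned in the text."""

    name: str = Field(description="The person's full name.")
    age: int = Field(description="The person's age in years.")


chat = ChatOpenAI()
data = chat.extract_data(
    "My name is Susan and I'm 13 years old.",
    data_model=Person,
)
print(data)  # e.g. {'name': 'Susan', 'age': 13}
```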
````python
    def register_tool(
        self,
        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
        *,
        model: Optional[type[BaseModel]] = None,
    ):
        """
        Register a tool (function) with the chat.

        The function will always be invoked in the current Python process.

        Examples
        --------

        If your tool has straightforward input parameters, you can just
        register the function directly (type hints and a docstring explaining
        both what the function does and what the parameters are for are
        strongly recommended):

        ```python
        from chatlas import ChatOpenAI, Tool


        def add(a: int, b: int) -> int:
            '''
            Add two numbers together.

            Parameters
            ----------
            a : int
                The first number to add.
            b : int
                The second number to add.
            '''
            return a + b


        chat = ChatOpenAI()
        chat.register_tool(add)
        chat.chat("What is 2 + 2?")
        ```

        If your tool has more complex input parameters, you can provide a Pydantic
        model that corresponds to the input parameters for the function. This way, you
        can have fields that hold other model(s) (for more complex input parameters),
        and also more directly document the input parameters:

        ```python
        from chatlas import ChatOpenAI, Tool
        from pydantic import BaseModel, Field


        class AddParams(BaseModel):
            '''Add two numbers together.'''

            a: int = Field(description="The first number to add.")

            b: int = Field(description="The second number to add.")


        def add(a: int, b: int) -> int:
            return a + b


        chat = ChatOpenAI()
        chat.register_tool(add, model=AddParams)
        chat.chat("What is 2 + 2?")
        ```

        Parameters
        ----------
        func
            The function to be invoked when the tool is called.
        model
            A Pydantic model that describes the input parameters for the function.
            If not provided, the model will be inferred from the function's type hints.
            The primary reason to provide a model is to document more complex
            input parameters (as in the example above).
            Note that the name and docstring of the model take precedence over the
            name and docstring of the function.
        """
        tool = Tool(func, model=model)
        self._tools[tool.name] = tool
````
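
One subtlety that follows from `_submit_turns()` further down: a synchronous `.chat()` raises `ValueError` if any registered tool is an `async` function, so async tools must be paired with the async methods. An illustrative sketch (the `current_temperature` tool is made up):

```python
import asyncio

from chatlas import ChatOpenAI


async def current_temperature(city: str) -> float:
    """Return the current temperature (in degrees Celsius) for a city."""
    await asyncio.sleep(0.1)  # stand-in for a real async lookup
    return 21.5


chat = ChatOpenAI()
chat.register_tool(current_temperature)


async def main():
    # chat.chat(...) would raise "Cannot use async tools in a synchronous chat"
    print(str(await chat.chat_async("How warm is it in Oslo right now?")))


asyncio.run(main())
```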
````python
    def export(
        self,
        filename: str | Path,
        *,
        turns: Optional[Sequence[Turn]] = None,
        title: Optional[str] = None,
        include: Literal["text", "all"] = "text",
        include_system_prompt: bool = True,
        overwrite: bool = False,
    ):
        """
        Export the chat history to a file.

        Parameters
        ----------
        filename
            The filename to export the chat to. Currently this must
            be a `.md` or `.html` file.
        turns
            The turns to export (i.e., the result of `.get_turns()`). If not
            provided, the chat's current turns will be used.
        title
            A title to place at the top of the exported file.
        include
            Whether to include text content only, or all content (i.e., tool
            calls as well).
        include_system_prompt
            Whether to include the system prompt in a <details> tag.
        overwrite
            Whether to overwrite the file if it already exists.

        Returns
        -------
        Path
            The path to the exported file.
        """
        if not turns:
            turns = self.get_turns(include_system_prompt=False)
        if not turns:
            raise ValueError("No turns to export.")

        if isinstance(filename, str):
            filename = Path(filename)

        filename = filename.resolve()
        if filename.exists() and not overwrite:
            raise ValueError(
                f"File {filename} already exists. Set `overwrite=True` to overwrite."
            )

        if filename.suffix not in {".md", ".html"}:
            raise ValueError("The filename must have a `.md` or `.html` extension.")

        # When exporting to HTML, we lean on shiny's chat component for rendering markdown and styling
        is_html = filename.suffix == ".html"

        # Get contents from each turn
        contents = ""
        for turn in turns:
            turn_content = "\n\n".join(
                [
                    str(content)
                    for content in turn.contents
                    if include == "all" or isinstance(content, ContentText)
                ]
            )
            if is_html:
                msg_type = "user" if turn.role == "user" else "chat"
                content_attr = html_escape(turn_content)
                turn_content = f"<shiny-{msg_type}-message content='{content_attr}'></shiny-{msg_type}-message>"
            else:
                turn_content = f"## {turn.role.capitalize()}\n\n{turn_content}"
            contents += f"{turn_content}\n\n"

        # Shiny chat message components require container elements
        if is_html:
            contents = f"<shiny-chat-messages>\n{contents}\n</shiny-chat-messages>"
            contents = f"<shiny-chat-container>{contents}</shiny-chat-container>"

        # Add title to the top
        if title:
            if is_html:
                contents = f"<h1>{title}</h1>\n\n{contents}"
            else:
                contents = f"# {title}\n\n{contents}"

        # Add system prompt to the bottom
        if include_system_prompt and self.system_prompt:
            contents += f"\n<br><br>\n<details><summary>System prompt</summary>\n\n{self.system_prompt}\n\n</details>"

        # Wrap in HTML template if exporting to HTML
        if is_html:
            contents = self._html_template(contents)

        with open(filename, "w") as f:
            f.write(contents)

        return filename

    @staticmethod
    def _html_template(contents: str) -> str:
        version = "1.2.1"
        shiny_www = (
            f"https://cdn.jsdelivr.net/gh/posit-dev/py-shiny@{version}/shiny/www"
        )

        return f"""
<!DOCTYPE html>
<html>
<head>
<script src="{shiny_www}/py-shiny/chat/chat.js"></script>
<link rel="stylesheet" href="{shiny_www}/py-shiny/chat/chat.css">
<link rel="stylesheet" href="{shiny_www}/shared/bootstrap/bootstrap.min.css">
</head>
<body>
<div style="max-width:700px; margin:0 auto; padding-top:20px;">
{contents}
</div>
</body>
</html>
"""
````
````python
    def _chat_impl(
        self,
        user_turn: Turn,
        echo: Literal["text", "all", "none"],
        display: MarkdownDisplay,
        stream: bool,
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> Generator[str, None, None]:
        # The tool loop: submit the user turn, then, as long as the assistant
        # responds with tool requests, invoke the tools and resubmit their
        # results (as a new user turn).
        user_turn_result: Turn | None = user_turn
        while user_turn_result is not None:
            for chunk in self._submit_turns(
                user_turn_result,
                echo=echo,
                display=display,
                stream=stream,
                kwargs=kwargs,
            ):
                yield chunk
            user_turn_result = self._invoke_tools()

    async def _chat_impl_async(
        self,
        user_turn: Turn,
        echo: Literal["text", "all", "none"],
        display: MarkdownDisplay,
        stream: bool,
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> AsyncGenerator[str, None]:
        # Same tool loop as _chat_impl(), but with async submission and tools.
        user_turn_result: Turn | None = user_turn
        while user_turn_result is not None:
            async for chunk in self._submit_turns_async(
                user_turn_result,
                echo=echo,
                display=display,
                stream=stream,
                kwargs=kwargs,
            ):
                yield chunk
            user_turn_result = await self._invoke_tools_async()

    def _submit_turns(
        self,
        user_turn: Turn,
        echo: Literal["text", "all", "none"],
        display: MarkdownDisplay,
        stream: bool,
        data_model: type[BaseModel] | None = None,
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> Generator[str, None, None]:
        if any(x._is_async for x in self._tools.values()):
            raise ValueError("Cannot use async tools in a synchronous chat")

        def emit(text: str | Content):
            display.update(str(text))

        emit("<br>\n\n")

        if echo == "all":
            emit_user_contents(user_turn, emit)

        if stream:
            response = self.provider.chat_perform(
                stream=True,
                turns=[*self._turns, user_turn],
                tools=self._tools,
                data_model=data_model,
                kwargs=kwargs,
            )

            # Merge the streamed chunks into a single result, from which the
            # provider constructs the assistant turn.
            result = None
            for chunk in response:
                text = self.provider.stream_text(chunk)
                if text:
                    emit(text)
                    yield text
                result = self.provider.stream_merge_chunks(result, chunk)

            turn = self.provider.stream_turn(
                result,
                has_data_model=data_model is not None,
                stream=response,
            )

            if echo == "all":
                emit_other_contents(turn, emit)

        else:
            response = self.provider.chat_perform(
                stream=False,
                turns=[*self._turns, user_turn],
                tools=self._tools,
                data_model=data_model,
                kwargs=kwargs,
            )

            turn = self.provider.value_turn(
                response, has_data_model=data_model is not None
            )
            if turn.text:
                emit(turn.text)
                yield turn.text

            if echo == "all":
                emit_other_contents(turn, emit)

        self._turns.extend([user_turn, turn])

    async def _submit_turns_async(
        self,
        user_turn: Turn,
        echo: Literal["text", "all", "none"],
        display: MarkdownDisplay,
        stream: bool,
        data_model: type[BaseModel] | None = None,
        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> AsyncGenerator[str, None]:
        def emit(text: str | Content):
            display.update(str(text))

        emit("<br>\n\n")

        if echo == "all":
            emit_user_contents(user_turn, emit)

        if stream:
            response = await self.provider.chat_perform_async(
                stream=True,
                turns=[*self._turns, user_turn],
                tools=self._tools,
                data_model=data_model,
                kwargs=kwargs,
            )

            result = None
            async for chunk in response:
                text = self.provider.stream_text(chunk)
                if text:
                    emit(text)
                    yield text
                result = self.provider.stream_merge_chunks(result, chunk)

            turn = await self.provider.stream_turn_async(
                result,
                has_data_model=data_model is not None,
                stream=response,
            )

            if echo == "all":
                emit_other_contents(turn, emit)

        else:
            response = await self.provider.chat_perform_async(
                stream=False,
                turns=[*self._turns, user_turn],
                tools=self._tools,
                data_model=data_model,
                kwargs=kwargs,
            )

            turn = self.provider.value_turn(
                response, has_data_model=data_model is not None
            )
            if turn.text:
                emit(turn.text)
                yield turn.text

            if echo == "all":
                emit_other_contents(turn, emit)

        self._turns.extend([user_turn, turn])

    def _invoke_tools(self) -> Turn | None:
        # Collect a result for each tool request in the last assistant turn;
        # returns None when there is nothing to send back.
        turn = self.get_last_turn()
        if turn is None:
            return None

        results: list[ContentToolResult] = []
        for x in turn.contents:
            if isinstance(x, ContentToolRequest):
                tool_def = self._tools.get(x.name, None)
                func = tool_def.func if tool_def is not None else None
                results.append(self._invoke_tool(func, x.arguments, x.id))

        if not results:
            return None

        return Turn("user", results)

    async def _invoke_tools_async(self) -> Turn | None:
        turn = self.get_last_turn()
        if turn is None:
            return None

        results: list[ContentToolResult] = []
        for x in turn.contents:
            if isinstance(x, ContentToolRequest):
                tool_def = self._tools.get(x.name, None)
                func = tool_def.func if tool_def is not None else None
                results.append(await self._invoke_tool_async(func, x.arguments, x.id))

        if not results:
            return None

        return Turn("user", results)

    @staticmethod
    def _invoke_tool(
        func: Callable[..., Any] | None,
        arguments: object,
        id_: str,
    ) -> ContentToolResult:
        if func is None:
            return ContentToolResult(id_, None, "Unknown tool")

        try:
            if isinstance(arguments, dict):
                result = func(**arguments)
            else:
                result = func(arguments)

            return ContentToolResult(id_, result, None)
        except Exception as e:
            log_tool_error(func.__name__, str(arguments), e)
            return ContentToolResult(id_, None, str(e))

    @staticmethod
    async def _invoke_tool_async(
        func: Callable[..., Awaitable[Any]] | None,
        arguments: object,
        id_: str,
    ) -> ContentToolResult:
        if func is None:
            return ContentToolResult(id_, None, "Unknown tool")

        try:
            if isinstance(arguments, dict):
                result = await func(**arguments)
            else:
                result = await func(arguments)

            return ContentToolResult(id_, result, None)
        except Exception as e:
            log_tool_error(func.__name__, str(arguments), e)
            return ContentToolResult(id_, None, str(e))

    def _markdown_display(
        self, echo: Literal["text", "all", "none"]
    ) -> MarkdownDisplay:
        """
        Get a markdown display object based on the echo option.

        The idea here is to use rich for consoles and IPython.display.Markdown
        for notebooks, since the latter is much more responsive to different
        screen sizes.
        """
        if echo == "none":
            return MockMarkdownDisplay()

        # rich does a lot to detect a notebook environment, but it doesn't
        # detect Quarto (at least not yet).
        from rich.console import Console

        is_web = Console().is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None

        opts = self._echo_options
        if is_web:
            return IPyMarkdownDisplay(opts)
        else:
            return LiveMarkdownDisplay(opts)

    def set_echo_options(
        self,
        rich_markdown: Optional[dict[str, Any]] = None,
        rich_console: Optional[dict[str, Any]] = None,
        css_styles: Optional[dict[str, str]] = None,
    ):
        """
        Set echo styling options for the chat.

        Parameters
        ----------
        rich_markdown
            A dictionary of options to pass to `rich.markdown.Markdown()`.
            This is only relevant when outputting to the console.
        rich_console
            A dictionary of options to pass to `rich.console.Console()`.
            This is only relevant when outputting to the console.
        css_styles
            A dictionary of CSS styles to apply to `IPython.display.Markdown()`.
            This is only relevant when outputting to the browser.
        """
        self._echo_options: EchoOptions = {
            "rich_markdown": rich_markdown or {},
            "rich_console": rich_console or {},
            "css_styles": css_styles or {},
        }

    def __str__(self):
        turns = self.get_turns(include_system_prompt=False)
        res = ""
        for turn in turns:
            icon = "👤" if turn.role == "user" else "🤖"
            res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n"
        return res

    def __repr__(self):
        turns = self.get_turns(include_system_prompt=True)
        tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens)
        res = f"<Chat turns={len(turns)} tokens={tokens}>"
        for turn in turns:
            res += "\n" + turn.__repr__(indent=2)
        return res + "\n"
````
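
A sketch of customizing echo styling. `code_theme` is an option of `rich.markdown.Markdown()` (the docstring above says the dict is passed through to it); the CSS dict is applied to the notebook (`IPython.display.Markdown`) output:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()
chat.set_echo_options(
    rich_markdown={"code_theme": "monokai"},
    css_styles={"color": "gray"},
)
```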
````python
class ChatResponse:
    """
    Chat response object.

    An object that, when displayed, will simultaneously consume (if not
    already consumed) and display the response in a streaming fashion.

    This is useful for interactive use: if the object is displayed, it can
    be viewed as it is being generated. And, if the object is not displayed,
    it can act like an iterator that can be consumed by something else.

    Attributes
    ----------
    content
        The content of the chat response.

    Properties
    ----------
    consumed
        Whether the response has been consumed. If the response has been fully
        consumed, then it can no longer be iterated over, but the content can
        still be retrieved (via the `content` attribute).
    """

    def __init__(self, generator: Generator[str, None, None]):
        self._generator = generator
        self.content: str = ""

    def __iter__(self) -> Iterator[str]:
        return self

    def __next__(self) -> str:
        chunk = next(self._generator)
        self.content += chunk  # Keep track of accumulated content
        return chunk

    def get_content(self) -> str:
        """
        Get the chat response content as a string.
        """
        for _ in self:
            pass
        return self.content

    @property
    def consumed(self) -> bool:
        # A generator's frame is discarded once it has been exhausted or closed.
        return self._generator.gi_frame is None

    def __str__(self) -> str:
        return self.get_content()


class ChatResponseAsync:
    """
    Chat response (async) object.

    An object that, when displayed, will simultaneously consume (if not
    already consumed) and display the response in a streaming fashion.

    This is useful for interactive use: if the object is displayed, it can
    be viewed as it is being generated. And, if the object is not displayed,
    it can act like an iterator that can be consumed by something else.

    Attributes
    ----------
    content
        The content of the chat response.

    Properties
    ----------
    consumed
        Whether the response has been consumed. If the response has been fully
        consumed, then it can no longer be iterated over, but the content can
        still be retrieved (via the `content` attribute).
    """

    def __init__(self, generator: AsyncGenerator[str, None]):
        self._generator = generator
        self.content: str = ""

    def __aiter__(self) -> AsyncIterator[str]:
        return self

    async def __anext__(self) -> str:
        chunk = await self._generator.__anext__()
        self.content += chunk  # Keep track of accumulated content
        return chunk

    async def get_content(self) -> str:
        "Get the chat response content as a string."
        async for _ in self:
            pass
        return self.content

    @property
    def consumed(self) -> bool:
        return self._generator.ag_frame is None
````
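
The response objects accumulate everything they yield, so the text remains available after consumption. A sketch using the synchronous variant:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()

response = chat.stream("Name three primary colors.")
text = response.get_content()  # drains the generator into .content
assert response.consumed
print(response.content)        # still available after consumption
assert str(response) == text   # __str__ delegates to get_content()
```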
````python
# ----------------------------------------------------------------------------
# Helpers for emitting content
# ----------------------------------------------------------------------------


def emit_user_contents(
    x: Turn,
    emit: Callable[[Content | str], None],
):
    if x.role != "user":
        raise ValueError("Expected a user turn")
    emit(f"## 👤 User turn:\n\n{str(x)}\n\n")
    emit_other_contents(x, emit)
    emit("\n\n## 🤖 Assistant turn:\n\n")


def emit_other_contents(
    x: Turn,
    emit: Callable[[Content | str], None],
):
    # Gather other content to emit in _reverse_ order
    to_emit: list[str] = []

    if x.finish_reason:
        to_emit.append(f"\n\n<< 🤖 finish reason: {x.finish_reason} \\>\\>\n\n")

    has_text = False
    has_other = False
    for content in reversed(x.contents):
        if isinstance(content, ContentText):
            has_text = True
        else:
            has_other = True
            to_emit.append(str(content))

    if has_text and has_other:
        if x.role == "user":
            to_emit.append("<< 👤 other content >>")
        else:
            to_emit.append("<< 🤖 other content >>")

    to_emit.reverse()

    emit("\n\n".join(to_emit))
````