chatlas-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


chatlas/_chat.py ADDED
@@ -0,0 +1,1279 @@
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+ from threading import Thread
+ from typing import (
+     Any,
+     AsyncGenerator,
+     AsyncIterator,
+     Awaitable,
+     Callable,
+     Generator,
+     Generic,
+     Iterator,
+     Literal,
+     Optional,
+     Sequence,
+     TypeVar,
+ )
+
+ from pydantic import BaseModel
+
+ from ._content import (
+     Content,
+     ContentJson,
+     ContentText,
+     ContentToolRequest,
+     ContentToolResult,
+ )
+ from ._display import (
+     EchoOptions,
+     IPyMarkdownDisplay,
+     LiveMarkdownDisplay,
+     MarkdownDisplay,
+     MockMarkdownDisplay,
+ )
+ from ._logging import log_tool_error
+ from ._provider import Provider
+ from ._tools import Tool
+ from ._turn import Turn, user_turn
+ from ._typing_extensions import TypedDict
+ from ._utils import html_escape
+
+
+ class AnyTypeDict(TypedDict, total=False):
+     pass
+
+
+ SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
+ """
+ A TypedDict representing the arguments that can be passed to the `.chat()`
+ method of a [](`~chatlas.Chat`) instance.
+ """
+
+ CompletionT = TypeVar("CompletionT")
+
+
+ class Chat(Generic[SubmitInputArgsT, CompletionT]):
+     """
+     A chat object that can be used to interact with a language model.
+
+     A `Chat` is a sequence of user and assistant [](`~chatlas.Turn`)s sent
+     to a specific [](`~chatlas.Provider`). A `Chat` takes care of managing
+     the state associated with the chat; i.e., it records the messages that
+     you send to the server, and the messages that you receive back. If you
+     register a tool (i.e., a function that the assistant can call on your
+     behalf), it also takes care of the tool loop.
+
+     You should generally not create this object yourself, but instead call
+     [](`~chatlas.ChatOpenAI`) or friends.
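+
+     Examples
+     --------
+     A minimal sketch (assuming an OpenAI API key is available in the
+     environment):
+
+     ```python
+     from chatlas import ChatOpenAI
+
+     chat = ChatOpenAI()
+     chat.chat("What is the capital of France?")
+     ```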
+ """
72
+
+     def __init__(
+         self,
+         provider: Provider,
+         turns: Optional[Sequence[Turn]] = None,
+     ):
+         """
+         Create a new chat object.
+
+         Parameters
+         ----------
+         provider
+             A [](`~chatlas.Provider`) object.
+         turns
+             A list of [](`~chatlas.Turn`) objects to initialize the chat with.
+         """
+         self.provider = provider
+         self._turns: list[Turn] = list(turns or [])
+         self._tools: dict[str, Tool] = {}
+         self._echo_options: EchoOptions = {
+             "rich_markdown": {},
+             "rich_console": {},
+             "css_styles": {},
+         }
+
+     def get_turns(
+         self,
+         *,
+         include_system_prompt: bool = False,
+     ) -> list[Turn[CompletionT]]:
+         """
+         Get all the turns (i.e., message contents) in the chat.
+
+         Parameters
+         ----------
+         include_system_prompt
+             Whether to include the system prompt in the turns.
+         """
+
+         if not self._turns:
+             return self._turns
+
+         if not include_system_prompt and self._turns[0].role == "system":
+             return self._turns[1:]
+         return self._turns
+
+     def get_last_turn(
+         self,
+         *,
+         role: Literal["assistant", "user", "system"] = "assistant",
+     ) -> Turn[CompletionT] | None:
+         """
+         Get the last turn in the chat with a specific role.
+
+         Parameters
+         ----------
+         role
+             The role of the turn to return.
+         """
+         for turn in reversed(self._turns):
+             if turn.role == role:
+                 return turn
+         return None
+
+     def set_turns(self, turns: Sequence[Turn]):
+         """
+         Set the turns of the chat.
+
+         This method is primarily useful for clearing or setting the turns of the
+         chat (i.e., limiting the context window).
+
+         Parameters
+         ----------
+         turns
+             The turns to set. Turns with the role "system" are not allowed.
+         """
+         if any(x.role == "system" for x in turns):
+             idx = next(i for i, x in enumerate(turns) if x.role == "system")
+             raise ValueError(
+                 f"Turn {idx} has a role 'system', which is not allowed. "
+                 "The system prompt must be set separately using the `.system_prompt` property. "
+                 "Consider removing this turn and setting the `.system_prompt` separately "
+                 "if you want to change the system prompt."
+             )
+         self._turns = list(turns)
+
+     @property
+     def system_prompt(self) -> str | None:
+         """
+         A property to get (or set) the system prompt for the chat.
+
+         Returns
+         -------
+         str | None
+             The system prompt (if any).
+         """
+         if self._turns and self._turns[0].role == "system":
+             return self._turns[0].text
+         return None
+
+     @system_prompt.setter
+     def system_prompt(self, value: str | None):
+         if self._turns and self._turns[0].role == "system":
+             self._turns.pop(0)
+         if value is not None:
+             self._turns.insert(0, Turn("system", value))
+
+     def tokens(self) -> list[tuple[int, int] | None]:
+         """
+         Get the tokens for each turn in the chat.
+
+         Returns
+         -------
+         list[tuple[int, int] | None]
+             A list of tuples, where each tuple contains the token counts for
+             a turn (`None` for turns without token information).
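+
+         Examples
+         --------
+         A sketch pairing each turn with its token counts, if any:
+
+         ```python
+         turns = chat.get_turns(include_system_prompt=True)
+         for turn, tokens in zip(turns, chat.tokens()):
+             print(turn.role, tokens)
+         ```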
+ """
189
+ return [turn.tokens for turn in self._turns]
190
+
+     def app(
+         self,
+         *,
+         stream: bool = True,
+         port: int = 0,
+         launch_browser: bool = True,
+         bg_thread: Optional[bool] = None,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ):
+         """
+         Enter a web-based chat app to interact with the LLM.
+
+         Parameters
+         ----------
+         stream
+             Whether to stream the response (i.e., have the response appear in chunks).
+         port
+             The port to run the app on (the default is 0, which will choose a random port).
+         launch_browser
+             Whether to launch a browser window.
+         bg_thread
+             Whether to run the app in a background thread. If `None`, the app will
+             run in a background thread if the current environment is a notebook.
+         kwargs
+             Additional keyword arguments to pass to the method used for requesting
+             the response.
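+
+         Examples
+         --------
+         A minimal sketch (requires the optional `shiny` package):
+
+         ```python
+         chat.app()  # launches a browser-based chat UI
+         ```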
+ """
218
+
219
+ try:
220
+ from shiny import App, run_app, ui
221
+ except ImportError:
222
+ raise ImportError(
223
+ "The `shiny` package is required for the `browser` method. "
224
+ "Install it with `pip install shiny`."
225
+ )
226
+
227
+ app_ui = ui.page_fillable(
228
+ ui.chat_ui("chat"),
229
+ fillable_mobile=True,
230
+ )
231
+
232
+ def server(input): # noqa: A002
233
+ chat = ui.Chat(
234
+ "chat",
235
+ messages=[
236
+ {"role": turn.role, "content": turn.text}
237
+ for turn in self.get_turns()
238
+ ],
239
+ )
240
+
241
+ @chat.on_user_submit
242
+ async def _():
243
+ user_input = chat.user_input()
244
+ if user_input is None:
245
+ return
246
+ if stream:
247
+ await chat.append_message_stream(
248
+ self.stream(user_input, kwargs=kwargs)
249
+ )
250
+ else:
251
+ await chat.append_message(str(self.chat(user_input, kwargs=kwargs)))
252
+
253
+ app = App(app_ui, server)
254
+
255
+ def _run_app():
256
+ run_app(app, launch_browser=launch_browser, port=port)
257
+
258
+ # Use bg_thread by default in Jupyter and Positron
259
+ if bg_thread is None:
260
+ from rich.console import Console
261
+
262
+ console = Console()
263
+ bg_thread = console.is_jupyter or (os.getenv("POSITRON") == "1")
264
+
265
+ if bg_thread:
266
+ thread = Thread(target=_run_app, daemon=True)
267
+ thread.start()
268
+ else:
269
+ _run_app()
270
+
271
+ return None
272
+
+     def console(
+         self,
+         *,
+         echo: Literal["text", "all", "none"] = "text",
+         stream: bool = True,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ):
+         """
+         Enter a chat console to interact with the LLM.
+
+         To quit, input 'exit' or press Ctrl+C.
+
+         Parameters
+         ----------
+         echo
+             Whether to echo text content, all content (i.e., tool calls), or no
+             content.
+         stream
+             Whether to stream the response (i.e., have the response appear in chunks).
+         kwargs
+             Additional keyword arguments to pass to the method used for requesting
+             the response.
+
+         Returns
+         -------
+         None
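+
+         Examples
+         --------
+         ```python
+         chat.console()  # interactive prompt; type 'exit' to quit
+         ```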
+ """
300
+
301
+ print("\nEntering chat console. To quit, input 'exit' or press Ctrl+C.\n")
302
+
303
+ while True:
304
+ user_input = input("?> ")
305
+ if user_input.strip().lower() in ("exit", "exit()"):
306
+ break
307
+ print("")
308
+ self.chat(user_input, echo=echo, stream=stream, kwargs=kwargs)
309
+ print("")
310
+
+     def chat(
+         self,
+         *args: Content | str,
+         echo: Literal["text", "all", "none"] = "text",
+         stream: bool = True,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> ChatResponse:
+         """
+         Generate a response from the chat.
+
+         Parameters
+         ----------
+         args
+             The user input(s) to generate a response from.
+         echo
+             Whether to echo text content, all content (i.e., tool calls), or no
+             content.
+         stream
+             Whether to stream the response (i.e., have the response appear in
+             chunks).
+         kwargs
+             Additional keyword arguments to pass to the method used for
+             requesting the response.
+
+         Returns
+         -------
+         ChatResponse
+             A (consumed) response from the chat. Apply `str()` to this object to
+             get the text content of the response.
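+
+         Examples
+         --------
+         A minimal sketch:
+
+         ```python
+         chat.chat("Tell me a joke.")  # echoes the response as it streams
+         response = chat.chat("Another one!", echo="none")
+         print(str(response))
+         ```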
+ """
341
+ turn = user_turn(*args)
342
+
343
+ display = self._markdown_display(echo=echo)
344
+
345
+ response = ChatResponse(
346
+ self._chat_impl(
347
+ turn,
348
+ echo=echo,
349
+ display=display,
350
+ stream=stream,
351
+ kwargs=kwargs,
352
+ )
353
+ )
354
+
355
+ with display:
356
+ for _ in response:
357
+ pass
358
+
359
+ return response
360
+
+     async def chat_async(
+         self,
+         *args: Content | str,
+         echo: Literal["text", "all", "none"] = "text",
+         stream: bool = True,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> ChatResponseAsync:
+         """
+         Generate a response from the chat asynchronously.
+
+         Parameters
+         ----------
+         args
+             The user input(s) to generate a response from.
+         echo
+             Whether to echo text content, all content (i.e., tool calls, images,
+             etc.), or no content.
+         stream
+             Whether to stream the response (i.e., have the response appear in
+             chunks).
+         kwargs
+             Additional keyword arguments to pass to the method used for
+             requesting the response.
+
+         Returns
+         -------
+         ChatResponseAsync
+             A (consumed) response from the chat. Use the `.content` attribute
+             (or `await .get_content()`) to get the text content of the response.
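+
+         Examples
+         --------
+         A sketch for use inside an async context (e.g., a running event loop):
+
+         ```python
+         response = await chat.chat_async("Tell me a joke.")
+         print(response.content)
+         ```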
+ """
391
+ turn = user_turn(*args)
392
+
393
+ display = self._markdown_display(echo=echo)
394
+
395
+ response = ChatResponseAsync(
396
+ self._chat_impl_async(
397
+ turn,
398
+ echo=echo,
399
+ display=display,
400
+ stream=stream,
401
+ kwargs=kwargs,
402
+ ),
403
+ )
404
+
405
+ with display:
406
+ async for _ in response:
407
+ pass
408
+
409
+ return response
410
+
+     def stream(
+         self,
+         *args: Content | str,
+         echo: Literal["text", "all", "none"] = "none",
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> ChatResponse:
+         """
+         Generate a response from the chat in a streaming fashion.
+
+         Parameters
+         ----------
+         args
+             The user input(s) to generate a response from.
+         echo
+             Whether to echo text content, all content (i.e., tool calls), or no
+             content.
+         kwargs
+             Additional keyword arguments to pass to the method used for requesting
+             the response.
+
+         Returns
+         -------
+         ChatResponse
+             An (unconsumed) response from the chat. Iterate over this object to
+             consume the response.
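+
+         Examples
+         --------
+         ```python
+         for chunk in chat.stream("Tell me a story."):
+             print(chunk, end="")
+         ```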
+ """
437
+ turn = user_turn(*args)
438
+
439
+ display = self._markdown_display(echo=echo)
440
+
441
+ generator = self._chat_impl(
442
+ turn,
443
+ stream=True,
444
+ display=display,
445
+ echo=echo,
446
+ kwargs=kwargs,
447
+ )
448
+
449
+ def wrapper() -> Generator[str, None, None]:
450
+ with display:
451
+ for chunk in generator:
452
+ yield chunk
453
+
454
+ return ChatResponse(wrapper())
455
+
+     async def stream_async(
+         self,
+         *args: Content | str,
+         echo: Literal["text", "all", "none"] = "none",
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> ChatResponseAsync:
+         """
+         Generate a response from the chat in a streaming fashion asynchronously.
+
+         Parameters
+         ----------
+         args
+             The user input(s) to generate a response from.
+         echo
+             Whether to echo text content, all content (i.e., tool calls), or no
+             content.
+         kwargs
+             Additional keyword arguments to pass to the method used for requesting
+             the response.
+
+         Returns
+         -------
+         ChatResponseAsync
+             An (unconsumed) response from the chat. Iterate over this object to
+             consume the response.
+ """
482
+ turn = user_turn(*args)
483
+
484
+ display = self._markdown_display(echo=echo)
485
+
486
+ async def wrapper() -> AsyncGenerator[str, None]:
487
+ with display:
488
+ async for chunk in self._chat_impl_async(
489
+ turn,
490
+ stream=True,
491
+ display=display,
492
+ echo=echo,
493
+ kwargs=kwargs,
494
+ ):
495
+ yield chunk
496
+
497
+ return ChatResponseAsync(wrapper())
498
+
+     def extract_data(
+         self,
+         *args: Content | str,
+         data_model: type[BaseModel],
+         echo: Literal["text", "all", "none"] = "none",
+         stream: bool = False,
+     ) -> dict[str, Any]:
+         """
+         Extract structured data from the given input.
+
+         Parameters
+         ----------
+         args
+             The input to extract data from.
+         data_model
+             A Pydantic model describing the structure of the data to extract.
+         echo
+             Whether to echo text content, all content (i.e., tool calls), or no content.
+         stream
+             Whether to stream the response (i.e., have the response appear in chunks).
+
+         Returns
+         -------
+         dict[str, Any]
+             The extracted data.
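+
+         Examples
+         --------
+         A minimal sketch with a hypothetical `Person` model:
+
+         ```python
+         from pydantic import BaseModel
+
+
+         class Person(BaseModel):
+             name: str
+             age: int
+
+
+         chat.extract_data("Bob is 42 years old.", data_model=Person)
+         # e.g., {'name': 'Bob', 'age': 42}
+         ```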
+ """
525
+
526
+ display = self._markdown_display(echo=echo)
527
+
528
+ response = ChatResponse(
529
+ self._submit_turns(
530
+ user_turn(*args),
531
+ data_model=data_model,
532
+ echo=echo,
533
+ display=display,
534
+ stream=stream,
535
+ )
536
+ )
537
+
538
+ with display:
539
+ for _ in response:
540
+ pass
541
+
542
+ turn = self.get_last_turn()
543
+ assert turn is not None
544
+
545
+ res: list[ContentJson] = []
546
+ for x in turn.contents:
547
+ if isinstance(x, ContentJson):
548
+ res.append(x)
549
+
550
+ if len(res) != 1:
551
+ raise ValueError(
552
+ f"Data extraction failed: {len(res)} data results received."
553
+ )
554
+
555
+ json = res[0]
556
+ return json.value
557
+
+     async def extract_data_async(
+         self,
+         *args: Content | str,
+         data_model: type[BaseModel],
+         echo: Literal["text", "all", "none"] = "none",
+         stream: bool = False,
+     ) -> dict[str, Any]:
+         """
+         Extract structured data from the given input asynchronously.
+
+         Parameters
+         ----------
+         args
+             The input to extract data from.
+         data_model
+             A Pydantic model describing the structure of the data to extract.
+         echo
+             Whether to echo text content, all content (i.e., tool calls), or no content.
+         stream
+             Whether to stream the response (i.e., have the response appear in chunks).
+
+         Returns
+         -------
+         dict[str, Any]
+             The extracted data.
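+
+         Examples
+         --------
+         A sketch for use inside an async context (reusing the hypothetical
+         `Person` model shown for `extract_data`):
+
+         ```python
+         await chat.extract_data_async("Bob is 42 years old.", data_model=Person)
+         ```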
+ """
585
+
586
+ display = self._markdown_display(echo=echo)
587
+
588
+ response = ChatResponseAsync(
589
+ self._submit_turns_async(
590
+ user_turn(*args),
591
+ data_model=data_model,
592
+ echo=echo,
593
+ display=display,
594
+ stream=stream,
595
+ )
596
+ )
597
+
598
+ with display:
599
+ async for _ in response:
600
+ pass
601
+
602
+ turn = self.get_last_turn()
603
+ assert turn is not None
604
+
605
+ res: list[ContentJson] = []
606
+ for x in turn.contents:
607
+ if isinstance(x, ContentJson):
608
+ res.append(x)
609
+
610
+ if len(res) != 1:
611
+ raise ValueError(
612
+ f"Data extraction failed: {len(res)} data results received."
613
+ )
614
+
615
+ json = res[0]
616
+ return json.value
617
+
+     def register_tool(
+         self,
+         func: Callable[..., Any] | Callable[..., Awaitable[Any]],
+         *,
+         model: Optional[type[BaseModel]] = None,
+     ):
+         """
+         Register a tool (function) with the chat.
+
+         The function will always be invoked in the current Python process.
+
+         Examples
+         --------
+
+         If your tool has straightforward input parameters, you can just
+         register the function directly (type hints and a docstring explaining
+         both what the function does and what the parameters are for are strongly
+         recommended):
+
+         ```python
+         from chatlas import ChatOpenAI, Tool
+
+
+         def add(a: int, b: int) -> int:
+             '''
+             Add two numbers together.
+
+             Parameters
+             ----------
+             a : int
+                 The first number to add.
+             b : int
+                 The second number to add.
+             '''
+             return a + b
+
+
+         chat = ChatOpenAI()
+         chat.register_tool(add)
+         chat.chat("What is 2 + 2?")
+         ```
+
+         If your tool has more complex input parameters, you can provide a Pydantic
+         model that corresponds to the input parameters for the function. This way, you
+         can have fields that hold other model(s) (for more complex input parameters),
+         and also more directly document the input parameters:
+
+         ```python
+         from chatlas import ChatOpenAI, Tool
+         from pydantic import BaseModel, Field
+
+
+         class AddParams(BaseModel):
+             '''Add two numbers together.'''
+
+             a: int = Field(description="The first number to add.")
+
+             b: int = Field(description="The second number to add.")
+
+
+         def add(a: int, b: int) -> int:
+             return a + b
+
+
+         chat = ChatOpenAI()
+         chat.register_tool(add, model=AddParams)
+         chat.chat("What is 2 + 2?")
+         ```
+
+         Parameters
+         ----------
+         func
+             The function to be invoked when the tool is called.
+         model
+             A Pydantic model that describes the input parameters for the function.
+             If not provided, the model will be inferred from the function's type hints.
+             Note that the name and docstring of the model take precedence over the
+             name and docstring of the function.
+         """
+         tool = Tool(func, model=model)
+         self._tools[tool.name] = tool
+
+     def export(
+         self,
+         filename: str | Path,
+         *,
+         turns: Optional[Sequence[Turn]] = None,
+         title: Optional[str] = None,
+         include: Literal["text", "all"] = "text",
+         include_system_prompt: bool = True,
+         overwrite: bool = False,
+     ):
+         """
+         Export the chat history to a file.
+
+         Parameters
+         ----------
+         filename
+             The filename to export the chat to. Currently this must
+             be a `.md` or `.html` file.
+         turns
+             The turns to export. If not provided, the chat's current turns
+             will be used.
+         title
+             A title to place at the top of the exported file.
+         include
+             Whether to include text content only, or all content (i.e., tool
+             calls).
+         include_system_prompt
+             Whether to include the system prompt in a <details> tag.
+         overwrite
+             Whether to overwrite the file if it already exists.
+
+         Returns
+         -------
+         Path
+             The path to the exported file.
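+
+         Examples
+         --------
+         ```python
+         chat.export("conversation.html", title="My conversation")
+         ```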
+ """
737
+ if not turns:
738
+ turns = self.get_turns(include_system_prompt=False)
739
+ if not turns:
740
+ raise ValueError("No turns to export.")
741
+
742
+ if isinstance(filename, str):
743
+ filename = Path(filename)
744
+
745
+ filename = filename.resolve()
746
+ if filename.exists() and not overwrite:
747
+ raise ValueError(
748
+ f"File {filename} already exists. Set `overwrite=True` to overwrite."
749
+ )
750
+
751
+ if filename.suffix not in {".md", ".html"}:
752
+ raise ValueError("The filename must have a `.md` or `.html` extension.")
753
+
754
+ # When exporting to HTML, we lean on shiny's chat component for rendering markdown and styling
755
+ is_html = filename.suffix == ".html"
756
+
757
+ # Get contents from each turn
758
+ contents = ""
759
+ for turn in turns:
760
+ turn_content = "\n\n".join(
761
+ [
762
+ str(content)
763
+ for content in turn.contents
764
+ if include == "all" or isinstance(content, ContentText)
765
+ ]
766
+ )
767
+ if is_html:
768
+ msg_type = "user" if turn.role == "user" else "chat"
769
+ content_attr = html_escape(turn_content)
770
+ turn_content = f"<shiny-{msg_type}-message content='{content_attr}'></shiny-{msg_type}-message>"
771
+ else:
772
+ turn_content = f"## {turn.role.capitalize()}\n\n{turn_content}"
773
+ contents += f"{turn_content}\n\n"
774
+
775
+ # Shiny chat message components requires container elements
776
+ if is_html:
777
+ contents = f"<shiny-chat-messages>\n{contents}\n</shiny-chat-messages>"
778
+ contents = f"<shiny-chat-container>{contents}</shiny-chat-container>"
779
+
780
+ # Add title to the top
781
+ if title:
782
+ if is_html:
783
+ contents = f"<h1>{title}</h1>\n\n{contents}"
784
+ else:
785
+ contents = f"# {title}\n\n{contents}"
786
+
787
+ # Add system prompt to the bottom
788
+ if include_system_prompt and self.system_prompt:
789
+ contents += f"\n<br><br>\n<details><summary>System prompt</summary>\n\n{self.system_prompt}\n\n</details>"
790
+
791
+ # Wrap in HTML template if exporting to HTML
792
+ if is_html:
793
+ contents = self._html_template(contents)
794
+
795
+ with open(filename, "w") as f:
796
+ f.write(contents)
797
+
798
+ return filename
799
+
+     @staticmethod
+     def _html_template(contents: str) -> str:
+         version = "1.2.1"
+         shiny_www = (
+             f"https://cdn.jsdelivr.net/gh/posit-dev/py-shiny@{version}/shiny/www"
+         )
+
+         return f"""
+         <!DOCTYPE html>
+         <html>
+         <head>
+             <script src="{shiny_www}/py-shiny/chat/chat.js"></script>
+             <link rel="stylesheet" href="{shiny_www}/py-shiny/chat/chat.css">
+             <link rel="stylesheet" href="{shiny_www}/shared/bootstrap/bootstrap.min.css">
+         </head>
+         <body>
+             <div style="max-width:700px; margin:0 auto; padding-top:20px;">
+                 {contents}
+             </div>
+         </body>
+         </html>
+         """
+
+     def _chat_impl(
+         self,
+         user_turn: Turn,
+         echo: Literal["text", "all", "none"],
+         display: MarkdownDisplay,
+         stream: bool,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> Generator[str, None, None]:
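+         # The tool loop: submit the user turn, and if the assistant responds
+         # with tool requests, invoke them and feed the results back as a new
+         # user turn, repeating until no more tool requests are made.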
+         user_turn_result: Turn | None = user_turn
+         while user_turn_result is not None:
+             for chunk in self._submit_turns(
+                 user_turn_result,
+                 echo=echo,
+                 display=display,
+                 stream=stream,
+                 kwargs=kwargs,
+             ):
+                 yield chunk
+             user_turn_result = self._invoke_tools()
+
+     async def _chat_impl_async(
+         self,
+         user_turn: Turn,
+         echo: Literal["text", "all", "none"],
+         display: MarkdownDisplay,
+         stream: bool,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> AsyncGenerator[str, None]:
+         user_turn_result: Turn | None = user_turn
+         while user_turn_result is not None:
+             async for chunk in self._submit_turns_async(
+                 user_turn_result,
+                 echo=echo,
+                 display=display,
+                 stream=stream,
+                 kwargs=kwargs,
+             ):
+                 yield chunk
+             user_turn_result = await self._invoke_tools_async()
+
+     def _submit_turns(
+         self,
+         user_turn: Turn,
+         echo: Literal["text", "all", "none"],
+         display: MarkdownDisplay,
+         stream: bool,
+         data_model: type[BaseModel] | None = None,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> Generator[str, None, None]:
+         if any(x._is_async for x in self._tools.values()):
+             raise ValueError("Cannot use async tools in a synchronous chat")
+
+         def emit(text: str | Content):
+             display.update(str(text))
+
+         emit("<br>\n\n")
+
+         if echo == "all":
+             emit_user_contents(user_turn, emit)
+
+         if stream:
+             response = self.provider.chat_perform(
+                 stream=True,
+                 turns=[*self._turns, user_turn],
+                 tools=self._tools,
+                 data_model=data_model,
+                 kwargs=kwargs,
+             )
+
+             result = None
+             for chunk in response:
+                 text = self.provider.stream_text(chunk)
+                 if text:
+                     emit(text)
+                     yield text
+                 result = self.provider.stream_merge_chunks(result, chunk)
+
+             turn = self.provider.stream_turn(
+                 result,
+                 has_data_model=data_model is not None,
+                 stream=response,
+             )
+
+             if echo == "all":
+                 emit_other_contents(turn, emit)
+
+         else:
+             response = self.provider.chat_perform(
+                 stream=False,
+                 turns=[*self._turns, user_turn],
+                 tools=self._tools,
+                 data_model=data_model,
+                 kwargs=kwargs,
+             )
+
+             turn = self.provider.value_turn(
+                 response, has_data_model=data_model is not None
+             )
+             if turn.text:
+                 emit(turn.text)
+                 yield turn.text
+
+             if echo == "all":
+                 emit_other_contents(turn, emit)
+
+         self._turns.extend([user_turn, turn])
+
+     async def _submit_turns_async(
+         self,
+         user_turn: Turn,
+         echo: Literal["text", "all", "none"],
+         display: MarkdownDisplay,
+         stream: bool,
+         data_model: type[BaseModel] | None = None,
+         kwargs: Optional[SubmitInputArgsT] = None,
+     ) -> AsyncGenerator[str, None]:
+         def emit(text: str | Content):
+             display.update(str(text))
+
+         emit("<br>\n\n")
+
+         if echo == "all":
+             emit_user_contents(user_turn, emit)
+
+         if stream:
+             response = await self.provider.chat_perform_async(
+                 stream=True,
+                 turns=[*self._turns, user_turn],
+                 tools=self._tools,
+                 data_model=data_model,
+                 kwargs=kwargs,
+             )
+
+             result = None
+             async for chunk in response:
+                 text = self.provider.stream_text(chunk)
+                 if text:
+                     emit(text)
+                     yield text
+                 result = self.provider.stream_merge_chunks(result, chunk)
+
+             turn = await self.provider.stream_turn_async(
+                 result,
+                 has_data_model=data_model is not None,
+                 stream=response,
+             )
+
+             if echo == "all":
+                 emit_other_contents(turn, emit)
+
+         else:
+             response = await self.provider.chat_perform_async(
+                 stream=False,
+                 turns=[*self._turns, user_turn],
+                 tools=self._tools,
+                 data_model=data_model,
+                 kwargs=kwargs,
+             )
+
+             turn = self.provider.value_turn(
+                 response, has_data_model=data_model is not None
+             )
+             if turn.text:
+                 emit(turn.text)
+                 yield turn.text
+
+             if echo == "all":
+                 emit_other_contents(turn, emit)
+
+         self._turns.extend([user_turn, turn])
+
+     def _invoke_tools(self) -> Turn | None:
+         turn = self.get_last_turn()
+         if turn is None:
+             return None
+
+         results: list[ContentToolResult] = []
+         for x in turn.contents:
+             if isinstance(x, ContentToolRequest):
+                 tool_def = self._tools.get(x.name, None)
+                 func = tool_def.func if tool_def is not None else None
+                 results.append(self._invoke_tool(func, x.arguments, x.id))
+
+         if not results:
+             return None
+
+         return Turn("user", results)
+
+     async def _invoke_tools_async(self) -> Turn | None:
+         turn = self.get_last_turn()
+         if turn is None:
+             return None
+
+         results: list[ContentToolResult] = []
+         for x in turn.contents:
+             if isinstance(x, ContentToolRequest):
+                 tool_def = self._tools.get(x.name, None)
+                 func = tool_def.func if tool_def is not None else None
+                 results.append(await self._invoke_tool_async(func, x.arguments, x.id))
+
+         if not results:
+             return None
+
+         return Turn("user", results)
+
+     @staticmethod
+     def _invoke_tool(
+         func: Callable[..., Any] | None,
+         arguments: object,
+         id_: str,
+     ) -> ContentToolResult:
+         if func is None:
+             return ContentToolResult(id_, None, "Unknown tool")
+
+         try:
+             if isinstance(arguments, dict):
+                 result = func(**arguments)
+             else:
+                 result = func(arguments)
+
+             return ContentToolResult(id_, result, None)
+         except Exception as e:
+             log_tool_error(func.__name__, str(arguments), e)
+             return ContentToolResult(id_, None, str(e))
+
+     @staticmethod
+     async def _invoke_tool_async(
+         func: Callable[..., Awaitable[Any]] | None,
+         arguments: object,
+         id_: str,
+     ) -> ContentToolResult:
+         if func is None:
+             return ContentToolResult(id_, None, "Unknown tool")
+
+         try:
+             if isinstance(arguments, dict):
+                 result = await func(**arguments)
+             else:
+                 result = await func(arguments)
+
+             return ContentToolResult(id_, result, None)
+         except Exception as e:
+             log_tool_error(func.__name__, str(arguments), e)
+             return ContentToolResult(id_, None, str(e))
+
+     def _markdown_display(
+         self, echo: Literal["text", "all", "none"]
+     ) -> MarkdownDisplay:
+         """
+         Get a markdown display object based on the echo option.
+
+         The idea here is to use rich for consoles and IPython.display.Markdown
+         for notebooks, since the latter is much more responsive to different
+         screen sizes.
+         """
+         if echo == "none":
+             return MockMarkdownDisplay()
+
+         # rich does a lot to detect a notebook environment, but it doesn't
+         # detect Quarto (at least not yet).
+         from rich.console import Console
+
+         is_web = Console().is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None
+
+         opts = self._echo_options
+         if is_web:
+             return IPyMarkdownDisplay(opts)
+         else:
+             return LiveMarkdownDisplay(opts)
+
+     def set_echo_options(
+         self,
+         rich_markdown: Optional[dict[str, Any]] = None,
+         rich_console: Optional[dict[str, Any]] = None,
+         css_styles: Optional[dict[str, str]] = None,
+     ):
+         """
+         Set echo styling options for the chat.
+
+         Parameters
+         ----------
+         rich_markdown
+             A dictionary of options to pass to `rich.markdown.Markdown()`.
+             This is only relevant when outputting to the console.
+         rich_console
+             A dictionary of options to pass to `rich.console.Console()`.
+             This is only relevant when outputting to the console.
+         css_styles
+             A dictionary of CSS styles to apply to `IPython.display.Markdown()`.
+             This is only relevant when outputting to the browser.
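+
+         Examples
+         --------
+         A sketch tweaking the console's code highlighting theme:
+
+         ```python
+         chat.set_echo_options(rich_markdown={"code_theme": "monokai"})
+         ```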
+ """
1114
+ self._echo_options: EchoOptions = {
1115
+ "rich_markdown": rich_markdown or {},
1116
+ "rich_console": rich_console or {},
1117
+ "css_styles": css_styles or {},
1118
+ }
1119
+
+     def __str__(self):
+         turns = self.get_turns(include_system_prompt=False)
+         res = ""
+         for turn in turns:
+             icon = "👤" if turn.role == "user" else "🤖"
+             res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n"
+         return res
+
+     def __repr__(self):
+         turns = self.get_turns(include_system_prompt=True)
+         tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens)
+         res = f"<Chat turns={len(turns)} tokens={tokens}>"
+         for turn in turns:
+             res += "\n" + turn.__repr__(indent=2)
+         return res + "\n"
+
+
+ class ChatResponse:
+     """
+     Chat response object.
+
+     An object that, when displayed, will simultaneously consume (if not
+     already consumed) and display the response in a streaming fashion.
+
+     This is useful for interactive use: if the object is displayed, it can
+     be viewed as it is being generated. And, if the object is not displayed,
+     it can act like an iterator that can be consumed by something else.
+
+     Attributes
+     ----------
+     content
+         The content of the chat response.
+
+     Properties
+     ----------
+     consumed
+         Whether the response has been consumed. If the response has been fully
+         consumed, then it can no longer be iterated over, but the content can
+         still be retrieved (via the `content` attribute).
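+
+     Examples
+     --------
+     A sketch of consuming a streaming response manually:
+
+     ```python
+     response = chat.stream("Tell me a story.")
+     for chunk in response:
+         print(chunk, end="")
+     response.content  # the accumulated text
+     ```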
+ """
1160
+
1161
+ def __init__(self, generator: Generator[str, None]):
1162
+ self._generator = generator
1163
+ self.content: str = ""
1164
+
1165
+ def __iter__(self) -> Iterator[str]:
1166
+ return self
1167
+
1168
+ def __next__(self) -> str:
1169
+ chunk = next(self._generator)
1170
+ self.content += chunk # Keep track of accumulated content
1171
+ return chunk
1172
+
1173
+ def get_content(self) -> str:
1174
+ """
1175
+ Get the chat response content as a string.
1176
+ """
1177
+ for _ in self:
1178
+ pass
1179
+ return self.content
1180
+
1181
+ @property
1182
+ def consumed(self) -> bool:
1183
+ return self._generator.gi_frame is None
1184
+
1185
+ def __str__(self) -> str:
1186
+ return self.get_content()
1187
+
1188
+
+ class ChatResponseAsync:
+     """
+     Chat response (async) object.
+
+     An object that, when displayed, will simultaneously consume (if not
+     already consumed) and display the response in a streaming fashion.
+
+     This is useful for interactive use: if the object is displayed, it can
+     be viewed as it is being generated. And, if the object is not displayed,
+     it can act like an iterator that can be consumed by something else.
+
+     Attributes
+     ----------
+     content
+         The content of the chat response.
+
+     Properties
+     ----------
+     consumed
+         Whether the response has been consumed. If the response has been fully
+         consumed, then it can no longer be iterated over, but the content can
+         still be retrieved (via the `content` attribute).
+     """
+
+     def __init__(self, generator: AsyncGenerator[str, None]):
+         self._generator = generator
+         self.content: str = ""
+
+     def __aiter__(self) -> AsyncIterator[str]:
+         return self
+
+     async def __anext__(self) -> str:
+         chunk = await self._generator.__anext__()
+         self.content += chunk  # Keep track of accumulated content
+         return chunk
+
+     async def get_content(self) -> str:
+         "Get the chat response content as a string."
+         async for _ in self:
+             pass
+         return self.content
+
+     @property
+     def consumed(self) -> bool:
+         return self._generator.ag_frame is None
+
+
+ # ----------------------------------------------------------------------------
+ # Helpers for emitting content
+ # ----------------------------------------------------------------------------
+
+
+ def emit_user_contents(
+     x: Turn,
+     emit: Callable[[Content | str], None],
+ ):
+     if x.role != "user":
+         raise ValueError("Expected a user turn")
+     emit(f"## 👤 User turn:\n\n{str(x)}\n\n")
+     emit_other_contents(x, emit)
+     emit("\n\n## 🤖 Assistant turn:\n\n")
+
+
+ def emit_other_contents(
+     x: Turn,
+     emit: Callable[[Content | str], None],
+ ):
+     # Gather other content to emit in _reverse_ order
+     to_emit: list[str] = []
+
+     if x.finish_reason:
+         to_emit.append(f"\n\n<< 🤖 finish reason: {x.finish_reason} \\>\\>\n\n")
+
+     has_text = False
+     has_other = False
+     for content in reversed(x.contents):
+         if isinstance(content, ContentText):
+             has_text = True
+         else:
+             has_other = True
+             to_emit.append(str(content))
+
+     if has_text and has_other:
+         if x.role == "user":
+             to_emit.append("<< 👤 other content >>")
+         else:
+             to_emit.append("<< 🤖 other content >>")
+
+     to_emit.reverse()
+
+     emit("\n\n".join(to_emit))