chatlas 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chatlas might be problematic; consult the package registry's advisory page for more details.

chatlas/_chat.py CHANGED
@@ -3,6 +3,8 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import os
5
5
  import sys
6
+ import traceback
7
+ import warnings
6
8
  from pathlib import Path
7
9
  from threading import Thread
8
10
  from typing import (
@@ -31,7 +33,7 @@ from ._content import (
31
33
  ContentToolResult,
32
34
  )
33
35
  from ._display import (
34
- EchoOptions,
36
+ EchoDisplayOptions,
35
37
  IPyMarkdownDisplay,
36
38
  LiveMarkdownDisplay,
37
39
  MarkdownDisplay,
@@ -57,6 +59,8 @@ method of a [](`~chatlas.Chat`) instance.
57
59
 
58
60
  CompletionT = TypeVar("CompletionT")
59
61
 
62
+ EchoOptions = Literal["output", "all", "none", "text"]
63
+
60
64
 
61
65
  class Chat(Generic[SubmitInputArgsT, CompletionT]):
62
66
  """
@@ -91,7 +95,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
91
95
  self.provider = provider
92
96
  self._turns: list[Turn] = list(turns or [])
93
97
  self._tools: dict[str, Tool] = {}
94
- self._echo_options: EchoOptions = {
98
+ self._current_display: Optional[MarkdownDisplay] = None
99
+ self._echo_options: EchoDisplayOptions = {
95
100
  "rich_markdown": {},
96
101
  "rich_console": {},
97
102
  "css_styles": {},
@@ -390,7 +395,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
390
395
  port: int = 0,
391
396
  launch_browser: bool = True,
392
397
  bg_thread: Optional[bool] = None,
393
- echo: Optional[Literal["text", "all", "none"]] = None,
398
+ echo: Optional[EchoOptions] = None,
399
+ content: Literal["text", "all"] = "all",
394
400
  kwargs: Optional[SubmitInputArgsT] = None,
395
401
  ):
396
402
  """
@@ -411,6 +417,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
411
417
  Whether to echo text content, all content (i.e., tool calls), or no
412
418
  content. Defaults to `"none"` when `stream=True` and `"text"` when
413
419
  `stream=False`.
420
+ content
421
+ Whether to display text content or all content (i.e., tool calls).
414
422
  kwargs
415
423
  Additional keyword arguments to pass to the method used for requesting
416
424
  the response.
@@ -439,16 +447,14 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
439
447
  )
440
448
 
441
449
  @chat.on_user_submit
442
- async def _():
443
- user_input = chat.user_input()
444
- if user_input is None:
445
- return
450
+ async def _(user_input: str):
446
451
  if stream:
447
452
  await chat.append_message_stream(
448
453
  await self.stream_async(
449
454
  user_input,
450
455
  kwargs=kwargs,
451
456
  echo=echo or "none",
457
+ content=content,
452
458
  )
453
459
  )
454
460
  else:
@@ -457,6 +463,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
457
463
  self.chat(
458
464
  user_input,
459
465
  kwargs=kwargs,
466
+ stream=False,
460
467
  echo=echo or "text",
461
468
  )
462
469
  )
@@ -485,7 +492,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
485
492
  def console(
486
493
  self,
487
494
  *,
488
- echo: Literal["text", "all", "none"] = "text",
495
+ echo: EchoOptions = "output",
489
496
  stream: bool = True,
490
497
  kwargs: Optional[SubmitInputArgsT] = None,
491
498
  ):
@@ -523,7 +530,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
523
530
  def chat(
524
531
  self,
525
532
  *args: Content | str,
526
- echo: Literal["text", "all", "none"] = "text",
533
+ echo: EchoOptions = "output",
527
534
  stream: bool = True,
528
535
  kwargs: Optional[SubmitInputArgsT] = None,
529
536
  ) -> ChatResponse:
@@ -558,7 +565,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
558
565
  self._chat_impl(
559
566
  turn,
560
567
  echo=echo,
561
- display=display,
568
+ content="text",
562
569
  stream=stream,
563
570
  kwargs=kwargs,
564
571
  )
@@ -573,7 +580,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
573
580
  async def chat_async(
574
581
  self,
575
582
  *args: Content | str,
576
- echo: Literal["text", "all", "none"] = "text",
583
+ echo: EchoOptions = "output",
577
584
  stream: bool = True,
578
585
  kwargs: Optional[SubmitInputArgsT] = None,
579
586
  ) -> ChatResponseAsync:
@@ -608,7 +615,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
608
615
  self._chat_impl_async(
609
616
  turn,
610
617
  echo=echo,
611
- display=display,
618
+ content="text",
612
619
  stream=stream,
613
620
  kwargs=kwargs,
614
621
  ),
@@ -620,12 +627,44 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
620
627
 
621
628
  return response
622
629
 
630
+ @overload
631
+ def stream(
632
+ self,
633
+ *args: Content | str,
634
+ ) -> Generator[str, None, None]: ...
635
+
636
+ @overload
637
+ def stream(
638
+ self,
639
+ *args: Content | str,
640
+ echo: EchoOptions,
641
+ ) -> Generator[str, None, None]: ...
642
+
643
+ @overload
644
+ def stream(
645
+ self,
646
+ *args: Content | str,
647
+ echo: EchoOptions,
648
+ content: Literal["text"],
649
+ kwargs: Optional[SubmitInputArgsT],
650
+ ) -> Generator[str, None, None]: ...
651
+
652
+ @overload
623
653
  def stream(
624
654
  self,
625
655
  *args: Content | str,
626
- echo: Literal["text", "all", "none"] = "none",
656
+ echo: EchoOptions,
657
+ content: Literal["all"],
658
+ kwargs: Optional[SubmitInputArgsT],
659
+ ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
660
+
661
+ def stream(
662
+ self,
663
+ *args: Content | str,
664
+ echo: EchoOptions = "none",
665
+ content: Literal["text", "all"] = "text",
627
666
  kwargs: Optional[SubmitInputArgsT] = None,
628
- ) -> ChatResponse:
667
+ ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
629
668
  """
630
669
  Generate a response from the chat in a streaming fashion.
631
670
 
@@ -636,6 +675,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
636
675
  echo
637
676
  Whether to echo text content, all content (i.e., tool calls), or no
638
677
  content.
678
+ content
679
+ Whether to yield just text content, or all content (i.e., tool calls).
639
680
  kwargs
640
681
  Additional keyword arguments to pass to the method used for requesting
641
682
  the response.
@@ -653,24 +694,58 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
653
694
  generator = self._chat_impl(
654
695
  turn,
655
696
  stream=True,
656
- display=display,
657
697
  echo=echo,
698
+ content=content,
658
699
  kwargs=kwargs,
659
700
  )
660
701
 
661
- def wrapper() -> Generator[str, None, None]:
702
+ def wrapper() -> Generator[
703
+ str | ContentToolRequest | ContentToolResult, None, None
704
+ ]:
662
705
  with display:
663
706
  for chunk in generator:
664
707
  yield chunk
665
708
 
666
- return ChatResponse(wrapper())
709
+ return wrapper()
710
+
711
+ @overload
712
+ async def stream_async(
713
+ self,
714
+ *args: Content | str,
715
+ ) -> AsyncGenerator[str, None]: ...
716
+
717
+ @overload
718
+ async def stream_async(
719
+ self,
720
+ *args: Content | str,
721
+ echo: EchoOptions,
722
+ ) -> AsyncGenerator[str, None]: ...
667
723
 
724
+ @overload
668
725
  async def stream_async(
669
726
  self,
670
727
  *args: Content | str,
671
- echo: Literal["text", "all", "none"] = "none",
728
+ echo: EchoOptions,
729
+ content: Literal["text"],
730
+ kwargs: Optional[SubmitInputArgsT],
731
+ ) -> AsyncGenerator[str, None]: ...
732
+
733
+ @overload
734
+ async def stream_async(
735
+ self,
736
+ *args: Content | str,
737
+ echo: EchoOptions,
738
+ content: Literal["all"],
739
+ kwargs: Optional[SubmitInputArgsT],
740
+ ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
741
+
742
+ async def stream_async(
743
+ self,
744
+ *args: Content | str,
745
+ echo: EchoOptions = "none",
746
+ content: Literal["text", "all"] = "text",
672
747
  kwargs: Optional[SubmitInputArgsT] = None,
673
- ) -> ChatResponseAsync:
748
+ ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
674
749
  """
675
750
  Generate a response from the chat in a streaming fashion asynchronously.
676
751
 
@@ -681,6 +756,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
681
756
  echo
682
757
  Whether to echo text content, all content (i.e., tool calls), or no
683
758
  content.
759
+ content
760
+ Whether to yield just text content, or all content (i.e., tool calls).
684
761
  kwargs
685
762
  Additional keyword arguments to pass to the method used for requesting
686
763
  the response.
@@ -695,24 +772,26 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
695
772
 
696
773
  display = self._markdown_display(echo=echo)
697
774
 
698
- async def wrapper() -> AsyncGenerator[str, None]:
775
+ async def wrapper() -> AsyncGenerator[
776
+ str | ContentToolRequest | ContentToolResult, None
777
+ ]:
699
778
  with display:
700
779
  async for chunk in self._chat_impl_async(
701
780
  turn,
702
781
  stream=True,
703
- display=display,
704
782
  echo=echo,
783
+ content=content,
705
784
  kwargs=kwargs,
706
785
  ):
707
786
  yield chunk
708
787
 
709
- return ChatResponseAsync(wrapper())
788
+ return wrapper()
710
789
 
711
790
  def extract_data(
712
791
  self,
713
792
  *args: Content | str,
714
793
  data_model: type[BaseModel],
715
- echo: Literal["text", "all", "none"] = "none",
794
+ echo: EchoOptions = "none",
716
795
  stream: bool = False,
717
796
  ) -> dict[str, Any]:
718
797
  """
@@ -742,7 +821,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
742
821
  user_turn(*args),
743
822
  data_model=data_model,
744
823
  echo=echo,
745
- display=display,
746
824
  stream=stream,
747
825
  )
748
826
  )
@@ -771,7 +849,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
771
849
  self,
772
850
  *args: Content | str,
773
851
  data_model: type[BaseModel],
774
- echo: Literal["text", "all", "none"] = "none",
852
+ echo: EchoOptions = "none",
775
853
  stream: bool = False,
776
854
  ) -> dict[str, Any]:
777
855
  """
@@ -802,7 +880,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
802
880
  user_turn(*args),
803
881
  data_model=data_model,
804
882
  echo=echo,
805
- display=display,
806
883
  stream=stream,
807
884
  )
808
885
  )
@@ -910,13 +987,65 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
910
987
  tool = Tool(func, model=model)
911
988
  self._tools[tool.name] = tool
912
989
 
990
+ @property
991
+ def current_display(self) -> Optional[MarkdownDisplay]:
992
+ """
993
+ Get the currently active markdown display, if any.
994
+
995
+ The display represents the place where `.chat(echo)` content is
996
+ being displayed. In a notebook/Quarto, this is a wrapper around
997
+ `IPython.display`. Otherwise, it is a wrapper around a
998
+ `rich.live.Live()` console.
999
+
1000
+ This is primarily useful if you want to add custom content to the
1001
+ display while the chat is running, but currently blocked by something
1002
+ like a tool call.
1003
+
1004
+ Example
1005
+ -------
1006
+ ```python
1007
+ import requests
1008
+ from chatlas import ChatOpenAI
1009
+
1010
+ chat = ChatOpenAI()
1011
+
1012
+
1013
+ def get_current_weather(latitude: float, longitude: float):
1014
+ "Get the current weather given a latitude and longitude."
1015
+
1016
+ lat_lng = f"latitude={latitude}&longitude={longitude}"
1017
+ url = f"https://api.open-meteo.com/v1/forecast?{lat_lng}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m"
1018
+ response = requests.get(url)
1019
+ json = response.json()
1020
+ if chat.current_display:
1021
+ chat.current_display.echo("My custom tool display!!!")
1022
+ return json["current"]
1023
+
1024
+
1025
+ chat.register_tool(get_current_weather)
1026
+
1027
+ chat.chat("What's the current temperature in Duluth, MN?", echo="text")
1028
+ ```
1029
+
1030
+
1031
+ Returns
1032
+ -------
1033
+ Optional[MarkdownDisplay]
1034
+ The currently active markdown display, if any.
1035
+ """
1036
+ return self._current_display
1037
+
1038
+ def _echo_content(self, x: str):
1039
+ if self._current_display:
1040
+ self._current_display.echo(x)
1041
+
913
1042
  def export(
914
1043
  self,
915
1044
  filename: str | Path,
916
1045
  *,
917
1046
  turns: Optional[Sequence[Turn]] = None,
918
1047
  title: Optional[str] = None,
919
- include: Literal["text", "all"] = "text",
1048
+ content: Literal["text", "all"] = "text",
920
1049
  include_system_prompt: bool = True,
921
1050
  overwrite: bool = False,
922
1051
  ):
@@ -935,7 +1064,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
935
1064
  A title to place at the top of the exported file.
936
1065
  overwrite
937
1066
  Whether to overwrite the file if it already exists.
938
- include
1067
+ content
939
1068
  Whether to include text content, all content (i.e., tool calls), or no
940
1069
  content.
941
1070
  include_system_prompt
@@ -971,9 +1100,9 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
971
1100
  for turn in turns:
972
1101
  turn_content = "\n\n".join(
973
1102
  [
974
- str(content).strip()
975
- for content in turn.contents
976
- if include == "all" or isinstance(content, ContentText)
1103
+ str(x).strip()
1104
+ for x in turn.contents
1105
+ if content == "all" or isinstance(x, ContentText)
977
1106
  ]
978
1107
  )
979
1108
  if is_html:
@@ -1033,51 +1162,130 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1033
1162
  </html>
1034
1163
  """
1035
1164
 
1165
+ @overload
1036
1166
  def _chat_impl(
1037
1167
  self,
1038
1168
  user_turn: Turn,
1039
- echo: Literal["text", "all", "none"],
1040
- display: MarkdownDisplay,
1169
+ echo: EchoOptions,
1170
+ content: Literal["text"],
1041
1171
  stream: bool,
1042
1172
  kwargs: Optional[SubmitInputArgsT] = None,
1043
- ) -> Generator[str, None, None]:
1173
+ ) -> Generator[str, None, None]: ...
1174
+
1175
+ @overload
1176
+ def _chat_impl(
1177
+ self,
1178
+ user_turn: Turn,
1179
+ echo: EchoOptions,
1180
+ content: Literal["all"],
1181
+ stream: bool,
1182
+ kwargs: Optional[SubmitInputArgsT] = None,
1183
+ ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
1184
+
1185
+ def _chat_impl(
1186
+ self,
1187
+ user_turn: Turn,
1188
+ echo: EchoOptions,
1189
+ content: Literal["text", "all"],
1190
+ stream: bool,
1191
+ kwargs: Optional[SubmitInputArgsT] = None,
1192
+ ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
1044
1193
  user_turn_result: Turn | None = user_turn
1045
1194
  while user_turn_result is not None:
1046
1195
  for chunk in self._submit_turns(
1047
1196
  user_turn_result,
1048
1197
  echo=echo,
1049
- display=display,
1050
1198
  stream=stream,
1051
1199
  kwargs=kwargs,
1052
1200
  ):
1053
1201
  yield chunk
1054
- user_turn_result = self._invoke_tools()
1202
+
1203
+ turn = self.get_last_turn(role="assistant")
1204
+ assert turn is not None
1205
+ user_turn_result = None
1206
+
1207
+ results: list[ContentToolResult] = []
1208
+ for x in turn.contents:
1209
+ if isinstance(x, ContentToolRequest):
1210
+ if echo == "output":
1211
+ self._echo_content(f"\n\n{x}\n\n")
1212
+ if content == "all":
1213
+ yield x
1214
+ res = self._invoke_tool(x)
1215
+ if echo == "output":
1216
+ self._echo_content(f"\n\n{res}\n\n")
1217
+ if content == "all":
1218
+ yield res
1219
+ results.append(res)
1220
+
1221
+ if results:
1222
+ user_turn_result = Turn("user", results)
1223
+
1224
+ @overload
1225
+ def _chat_impl_async(
1226
+ self,
1227
+ user_turn: Turn,
1228
+ echo: EchoOptions,
1229
+ content: Literal["text"],
1230
+ stream: bool,
1231
+ kwargs: Optional[SubmitInputArgsT] = None,
1232
+ ) -> AsyncGenerator[str, None]: ...
1233
+
1234
+ @overload
1235
+ def _chat_impl_async(
1236
+ self,
1237
+ user_turn: Turn,
1238
+ echo: EchoOptions,
1239
+ content: Literal["all"],
1240
+ stream: bool,
1241
+ kwargs: Optional[SubmitInputArgsT] = None,
1242
+ ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
1055
1243
 
1056
1244
  async def _chat_impl_async(
1057
1245
  self,
1058
1246
  user_turn: Turn,
1059
- echo: Literal["text", "all", "none"],
1060
- display: MarkdownDisplay,
1247
+ echo: EchoOptions,
1248
+ content: Literal["text", "all"],
1061
1249
  stream: bool,
1062
1250
  kwargs: Optional[SubmitInputArgsT] = None,
1063
- ) -> AsyncGenerator[str, None]:
1251
+ ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
1064
1252
  user_turn_result: Turn | None = user_turn
1065
1253
  while user_turn_result is not None:
1066
1254
  async for chunk in self._submit_turns_async(
1067
1255
  user_turn_result,
1068
1256
  echo=echo,
1069
- display=display,
1070
1257
  stream=stream,
1071
1258
  kwargs=kwargs,
1072
1259
  ):
1073
1260
  yield chunk
1074
- user_turn_result = await self._invoke_tools_async()
1261
+
1262
+ turn = self.get_last_turn(role="assistant")
1263
+ assert turn is not None
1264
+ user_turn_result = None
1265
+
1266
+ results: list[ContentToolResult] = []
1267
+ for x in turn.contents:
1268
+ if isinstance(x, ContentToolRequest):
1269
+ if echo == "output":
1270
+ self._echo_content(f"\n\n{x}\n\n")
1271
+ if content == "all":
1272
+ yield x
1273
+ res = await self._invoke_tool_async(x)
1274
+ if echo == "output":
1275
+ self._echo_content(f"\n\n{res}\n\n")
1276
+ if content == "all":
1277
+ yield res
1278
+ else:
1279
+ yield "\n\n"
1280
+ results.append(res)
1281
+
1282
+ if results:
1283
+ user_turn_result = Turn("user", results)
1075
1284
 
1076
1285
  def _submit_turns(
1077
1286
  self,
1078
1287
  user_turn: Turn,
1079
- echo: Literal["text", "all", "none"],
1080
- display: MarkdownDisplay,
1288
+ echo: EchoOptions,
1081
1289
  stream: bool,
1082
1290
  data_model: type[BaseModel] | None = None,
1083
1291
  kwargs: Optional[SubmitInputArgsT] = None,
@@ -1086,7 +1294,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1086
1294
  raise ValueError("Cannot use async tools in a synchronous chat")
1087
1295
 
1088
1296
  def emit(text: str | Content):
1089
- display.update(str(text))
1297
+ self._echo_content(str(text))
1090
1298
 
1091
1299
  emit("<br>\n\n")
1092
1300
 
@@ -1142,14 +1350,13 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1142
1350
  async def _submit_turns_async(
1143
1351
  self,
1144
1352
  user_turn: Turn,
1145
- echo: Literal["text", "all", "none"],
1146
- display: MarkdownDisplay,
1353
+ echo: EchoOptions,
1147
1354
  stream: bool,
1148
1355
  data_model: type[BaseModel] | None = None,
1149
1356
  kwargs: Optional[SubmitInputArgsT] = None,
1150
1357
  ) -> AsyncGenerator[str, None]:
1151
1358
  def emit(text: str | Content):
1152
- display.update(str(text))
1359
+ self._echo_content(str(text))
1153
1360
 
1154
1361
  emit("<br>\n\n")
1155
1362
 
@@ -1202,92 +1409,74 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1202
1409
 
1203
1410
  self._turns.extend([user_turn, turn])
1204
1411
 
1205
- def _invoke_tools(self) -> Turn | None:
1206
- turn = self.get_last_turn()
1207
- if turn is None:
1208
- return None
1209
-
1210
- results: list[ContentToolResult] = []
1211
- for x in turn.contents:
1212
- if isinstance(x, ContentToolRequest):
1213
- tool_def = self._tools.get(x.name, None)
1214
- func = tool_def.func if tool_def is not None else None
1215
- results.append(self._invoke_tool(func, x.arguments, x.id))
1216
-
1217
- if not results:
1218
- return None
1219
-
1220
- return Turn("user", results)
1221
-
1222
- async def _invoke_tools_async(self) -> Turn | None:
1223
- turn = self.get_last_turn()
1224
- if turn is None:
1225
- return None
1226
-
1227
- results: list[ContentToolResult] = []
1228
- for x in turn.contents:
1229
- if isinstance(x, ContentToolRequest):
1230
- tool_def = self._tools.get(x.name, None)
1231
- func = None
1232
- if tool_def:
1233
- if tool_def._is_async:
1234
- func = tool_def.func
1235
- else:
1236
- func = wrap_async(tool_def.func)
1237
- results.append(await self._invoke_tool_async(func, x.arguments, x.id))
1238
-
1239
- if not results:
1240
- return None
1412
+ def _invoke_tool(self, x: ContentToolRequest) -> ContentToolResult:
1413
+ tool_def = self._tools.get(x.name, None)
1414
+ func = tool_def.func if tool_def is not None else None
1241
1415
 
1242
- return Turn("user", results)
1243
-
1244
- @staticmethod
1245
- def _invoke_tool(
1246
- func: Callable[..., Any] | None,
1247
- arguments: object,
1248
- id_: str,
1249
- ) -> ContentToolResult:
1250
1416
  if func is None:
1251
- return ContentToolResult(id=id_, value=None, error="Unknown tool")
1417
+ e = RuntimeError(f"Unknown tool: {x.name}")
1418
+ return ContentToolResult(value=None, error=e, request=x)
1252
1419
 
1253
- name = func.__name__
1420
+ args = x.arguments
1254
1421
 
1255
1422
  try:
1256
- if isinstance(arguments, dict):
1257
- result = func(**arguments)
1423
+ if isinstance(args, dict):
1424
+ result = func(**args)
1258
1425
  else:
1259
- result = func(arguments)
1426
+ result = func(args)
1260
1427
 
1261
- return ContentToolResult(id=id_, value=result, error=None, name=name)
1428
+ if not isinstance(result, ContentToolResult):
1429
+ result = ContentToolResult(value=result)
1430
+
1431
+ result.request = x
1432
+ return result
1262
1433
  except Exception as e:
1263
- log_tool_error(name, str(arguments), e)
1264
- return ContentToolResult(id=id_, value=None, error=str(e), name=name)
1434
+ warnings.warn(
1435
+ f"Calling tool '{x.name}' led to an error.",
1436
+ ToolFailureWarning,
1437
+ stacklevel=2,
1438
+ )
1439
+ traceback.print_exc()
1440
+ log_tool_error(x.name, str(args), e)
1441
+ return ContentToolResult(value=None, error=e, request=x)
1442
+
1443
+ async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
1444
+ tool_def = self._tools.get(x.name, None)
1445
+ func = None
1446
+ if tool_def:
1447
+ if tool_def._is_async:
1448
+ func = tool_def.func
1449
+ else:
1450
+ func = wrap_async(tool_def.func)
1265
1451
 
1266
- @staticmethod
1267
- async def _invoke_tool_async(
1268
- func: Callable[..., Awaitable[Any]] | None,
1269
- arguments: object,
1270
- id_: str,
1271
- ) -> ContentToolResult:
1272
1452
  if func is None:
1273
- return ContentToolResult(id=id_, value=None, error="Unknown tool")
1453
+ e = RuntimeError(f"Unknown tool: {x.name}")
1454
+ return ContentToolResult(value=None, error=e, request=x)
1274
1455
 
1275
- name = func.__name__
1456
+ args = x.arguments
1276
1457
 
1277
1458
  try:
1278
- if isinstance(arguments, dict):
1279
- result = await func(**arguments)
1459
+ if isinstance(args, dict):
1460
+ result = await func(**args)
1280
1461
  else:
1281
- result = await func(arguments)
1462
+ result = await func(args)
1282
1463
 
1283
- return ContentToolResult(id=id_, value=result, error=None, name=name)
1464
+ if not isinstance(result, ContentToolResult):
1465
+ result = ContentToolResult(value=result)
1466
+
1467
+ result.request = x
1468
+ return result
1284
1469
  except Exception as e:
1285
- log_tool_error(func.__name__, str(arguments), e)
1286
- return ContentToolResult(id=id_, value=None, error=str(e), name=name)
1470
+ warnings.warn(
1471
+ f"Calling tool '{x.name}' led to an error.",
1472
+ ToolFailureWarning,
1473
+ stacklevel=2,
1474
+ )
1475
+ traceback.print_exc()
1476
+ log_tool_error(x.name, str(args), e)
1477
+ return ContentToolResult(value=None, error=e, request=x)
1287
1478
 
1288
- def _markdown_display(
1289
- self, echo: Literal["text", "all", "none"]
1290
- ) -> MarkdownDisplay:
1479
+ def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
1291
1480
  """
1292
1481
  Get a markdown display object based on the echo option.
1293
1482
 
@@ -1296,7 +1485,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1296
1485
  screen sizes.
1297
1486
  """
1298
1487
  if echo == "none":
1299
- return MockMarkdownDisplay()
1488
+ return ChatMarkdownDisplay(MockMarkdownDisplay(), self)
1300
1489
 
1301
1490
  # rich does a lot to detect a notebook environment, but it doesn't
1302
1491
  # detect Quarto (at least not yet).
@@ -1305,10 +1494,13 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1305
1494
  is_web = Console().is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None
1306
1495
 
1307
1496
  opts = self._echo_options
1497
+
1308
1498
  if is_web:
1309
- return IPyMarkdownDisplay(opts)
1499
+ display = IPyMarkdownDisplay(opts)
1310
1500
  else:
1311
- return LiveMarkdownDisplay(opts)
1501
+ display = LiveMarkdownDisplay(opts)
1502
+
1503
+ return ChatMarkdownDisplay(display, self)
1312
1504
 
1313
1505
  def set_echo_options(
1314
1506
  self,
@@ -1331,7 +1523,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1331
1523
  A dictionary of CSS styles to apply to `IPython.display.Markdown()`.
1332
1524
  This is only relevant when outputing to the browser.
1333
1525
  """
1334
- self._echo_options: EchoOptions = {
1526
+ self._echo_options: EchoDisplayOptions = {
1335
1527
  "rich_markdown": rich_markdown or {},
1336
1528
  "rich_console": rich_console or {},
1337
1529
  "css_styles": css_styles or {},
@@ -1501,3 +1693,30 @@ def emit_other_contents(
1501
1693
  to_emit.reverse()
1502
1694
 
1503
1695
  emit("\n\n".join(to_emit))
1696
+
1697
+
1698
+ # Helper/wrapper class to let Chat know about the currently active display
1699
+ class ChatMarkdownDisplay:
1700
+ def __init__(self, display: MarkdownDisplay, chat: Chat):
1701
+ self._display = display
1702
+ self._chat = chat
1703
+
1704
+ def __enter__(self):
1705
+ self._chat._current_display = self._display
1706
+ return self._display.__enter__()
1707
+
1708
+ def __exit__(self, *args, **kwargs):
1709
+ result = self._display.__exit__(*args, **kwargs)
1710
+ self._chat._current_display = None
1711
+ return result
1712
+
1713
+ def append(self, content):
1714
+ return self._display.echo(content)
1715
+
1716
+
1717
+ class ToolFailureWarning(RuntimeWarning):
1718
+ pass
1719
+
1720
+
1721
+ # By default warnings are shown once; we want to always show them.
1722
+ warnings.simplefilter("always", ToolFailureWarning)