chatlas 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chatlas might be problematic. See the registry's page for this release for more details.

chatlas/__init__.py CHANGED
@@ -16,7 +16,7 @@ from ._perplexity import ChatPerplexity
16
16
  from ._provider import Provider
17
17
  from ._snowflake import ChatSnowflake
18
18
  from ._tokens import token_usage
19
- from ._tools import Tool
19
+ from ._tools import Tool, ToolRejectError
20
20
  from ._turn import Turn
21
21
 
22
22
  try:
@@ -51,6 +51,7 @@ __all__ = (
51
51
  "Provider",
52
52
  "token_usage",
53
53
  "Tool",
54
+ "ToolRejectError",
54
55
  "Turn",
55
56
  "types",
56
57
  )
chatlas/_anthropic.py CHANGED
@@ -451,10 +451,7 @@ class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]):
451
451
  @staticmethod
452
452
  def _as_content_block(content: Content) -> "ContentBlockParam":
453
453
  if isinstance(content, ContentText):
454
- text = content.text
455
- if text == "" or text.isspace():
456
- text = "[empty string]"
457
- return {"type": "text", "text": text}
454
+ return {"text": content.text, "type": "text"}
458
455
  elif isinstance(content, ContentJson):
459
456
  return {"text": "<structured data/>", "type": "text"}
460
457
  elif isinstance(content, ContentPDF):
chatlas/_callbacks.py ADDED
@@ -0,0 +1,56 @@
from collections import OrderedDict
from typing import Any, Callable, Iterator

from ._utils import is_async_callable


class CallbackManager:
    """
    A small registry of callbacks that are invoked together.

    Callbacks are stored under unique string ids in insertion order and are
    invoked in *reverse* insertion order, so the most recently registered
    callback runs first.
    """

    def __init__(self) -> None:
        self._callbacks: dict[str, Callable[..., Any]] = OrderedDict()
        self._id: int = 1

    def add(self, callback: Callable[..., Any]) -> Callable[[], None]:
        """
        Register `callback` and return a function that unregisters it.

        The returned function is idempotent: calling it more than once (or
        after the callback was already removed) is a no-op.
        """
        callback_id = self._next_id()
        self._callbacks[callback_id] = callback

        def _rm_callback() -> None:
            self._callbacks.pop(callback_id, None)

        return _rm_callback

    def invoke(self, *args: Any, **kwargs: Any) -> None:
        """
        Synchronously invoke every registered callback with `*args`/`**kwargs`.

        Raises
        ------
        RuntimeError
            If any registered callback is async. The check is done before
            any callback runs, so a misconfiguration never leaves the
            callbacks only partially invoked.
        """
        if not self._callbacks:
            return

        if any(is_async_callable(cb) for cb in self._callbacks.values()):
            raise RuntimeError(
                "Can't use async callbacks with `.chat()`/`.stream()`. "
                "Async callbacks can only be used with `.chat_async()`/`.stream_async()`."
            )

        for callback in self._iter_newest_first():
            callback(*args, **kwargs)

    async def invoke_async(self, *args: Any, **kwargs: Any) -> None:
        """
        Invoke every registered callback, awaiting the async ones.

        Sync and async callbacks may be mixed; each is called with
        `*args`/`**kwargs` in reverse insertion order.
        """
        if not self._callbacks:
            return

        for callback in self._iter_newest_first():
            if is_async_callable(callback):
                await callback(*args, **kwargs)
            else:
                callback(*args, **kwargs)

    def count(self) -> int:
        """Return the number of currently registered callbacks."""
        return len(self._callbacks)

    def get_callbacks(self) -> list[Callable[..., Any]]:
        """Return the registered callbacks in insertion order."""
        return list(self._callbacks.values())

    def _iter_newest_first(self) -> Iterator[Callable[..., Any]]:
        """
        Yield registered callbacks newest-first.

        The ids are snapshotted up front, so callbacks added during the
        invocation are not run. Each lookup happens lazily, so a callback
        removed by an earlier callback is skipped instead of raising
        `KeyError`.
        """
        for callback_id in reversed(list(self._callbacks)):
            callback = self._callbacks.get(callback_id)
            if callback is not None:
                yield callback

    def _next_id(self) -> str:
        """Return a fresh, never-reused id for a new callback."""
        current_id = self._id
        self._id += 1
        return str(current_id)
chatlas/_chat.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import inspect
4
5
  import os
5
6
  import sys
@@ -25,6 +26,7 @@ from typing import (
25
26
 
26
27
  from pydantic import BaseModel
27
28
 
29
+ from ._callbacks import CallbackManager
28
30
  from ._content import (
29
31
  Content,
30
32
  ContentJson,
@@ -41,7 +43,7 @@ from ._display import (
41
43
  )
42
44
  from ._logging import log_tool_error
43
45
  from ._provider import Provider
44
- from ._tools import Tool
46
+ from ._tools import Tool, ToolRejectError
45
47
  from ._turn import Turn, user_turn
46
48
  from ._typing_extensions import TypedDict
47
49
  from ._utils import html_escape, wrap_async
@@ -95,6 +97,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
95
97
  self.provider = provider
96
98
  self._turns: list[Turn] = list(turns or [])
97
99
  self._tools: dict[str, Tool] = {}
100
+ self._on_tool_request_callbacks = CallbackManager()
101
+ self._on_tool_result_callbacks = CallbackManager()
98
102
  self._current_display: Optional[MarkdownDisplay] = None
99
103
  self._echo_options: EchoDisplayOptions = {
100
104
  "rich_markdown": {},
@@ -631,31 +635,18 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
631
635
  def stream(
632
636
  self,
633
637
  *args: Content | str,
634
- ) -> Generator[str, None, None]: ...
635
-
636
- @overload
637
- def stream(
638
- self,
639
- *args: Content | str,
640
- echo: EchoOptions,
641
- ) -> Generator[str, None, None]: ...
642
-
643
- @overload
644
- def stream(
645
- self,
646
- *args: Content | str,
647
- echo: EchoOptions,
648
638
  content: Literal["text"],
649
- kwargs: Optional[SubmitInputArgsT],
639
+ echo: EchoOptions = "none",
640
+ kwargs: Optional[SubmitInputArgsT] = None,
650
641
  ) -> Generator[str, None, None]: ...
651
642
 
652
643
  @overload
653
644
  def stream(
654
645
  self,
655
646
  *args: Content | str,
656
- echo: EchoOptions,
657
647
  content: Literal["all"],
658
- kwargs: Optional[SubmitInputArgsT],
648
+ echo: EchoOptions = "none",
649
+ kwargs: Optional[SubmitInputArgsT] = None,
659
650
  ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
660
651
 
661
652
  def stream(
@@ -712,31 +703,18 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
712
703
  async def stream_async(
713
704
  self,
714
705
  *args: Content | str,
715
- ) -> AsyncGenerator[str, None]: ...
716
-
717
- @overload
718
- async def stream_async(
719
- self,
720
- *args: Content | str,
721
- echo: EchoOptions,
722
- ) -> AsyncGenerator[str, None]: ...
723
-
724
- @overload
725
- async def stream_async(
726
- self,
727
- *args: Content | str,
728
- echo: EchoOptions,
729
706
  content: Literal["text"],
730
- kwargs: Optional[SubmitInputArgsT],
707
+ echo: EchoOptions = "none",
708
+ kwargs: Optional[SubmitInputArgsT] = None,
731
709
  ) -> AsyncGenerator[str, None]: ...
732
710
 
733
711
  @overload
734
712
  async def stream_async(
735
713
  self,
736
714
  *args: Content | str,
737
- echo: EchoOptions,
738
715
  content: Literal["all"],
739
- kwargs: Optional[SubmitInputArgsT],
716
+ echo: EchoOptions = "none",
717
+ kwargs: Optional[SubmitInputArgsT] = None,
740
718
  ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
741
719
 
742
720
  async def stream_async(
@@ -987,6 +965,53 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
987
965
  tool = Tool(func, model=model)
988
966
  self._tools[tool.name] = tool
989
967
 
968
def on_tool_request(
    self, callback: Callable[[ContentToolRequest], None]
) -> Callable[[], None]:
    """
    Register a callback for a tool request event.

    A tool request event occurs when the assistant requests a tool to be
    called on its behalf. Before invoking the tool, `on_tool_request`
    handlers are called with the relevant `ContentToolRequest` object. This
    is useful if you want to handle tool requests in a custom way, such as
    logging them or requiring user approval before invoking the tool.

    To reject a request, raise `ToolRejectError` from the callback. The tool
    is then not invoked, and the resulting `ContentToolResult` carries that
    error instead of a value.

    Handlers are called in reverse registration order (most recently
    registered first). Async handlers can only be used with
    `.chat_async()`/`.stream_async()`.

    Parameters
    ----------
    callback
        A function to be called when a tool request event occurs.
        This function must have a single argument, which will be the
        tool request (i.e., a `ContentToolRequest` object).

    Returns
    -------
    A callable that can be used to remove the callback later.
    """
    return self._on_tool_request_callbacks.add(callback)
991
+
992
def on_tool_result(
    self, callback: Callable[[ContentToolResult], None]
) -> Callable[[], None]:
    """
    Register a callback for a tool result event.

    A tool result event occurs when a tool has been invoked and the
    result is ready to be provided to the assistant. After the tool
    has been invoked, `on_tool_result` handlers are called with the
    relevant `ContentToolResult` object. This is useful if you want to
    handle tool results in a custom way, such as logging them.

    Handlers also receive results that carry an error, for example when the
    tool raised an exception or an `on_tool_request` handler rejected the
    request with `ToolRejectError`. Check `result.error` to tell these apart
    from successful results.

    Parameters
    ----------
    callback
        A function to be called when a tool result event occurs.
        This function must have a single argument, which will be the
        tool result (i.e., a `ContentToolResult` object).

    Returns
    -------
    A callable that can be used to remove the callback later.
    """
    return self._on_tool_result_callbacks.add(callback)
1014
+
990
1015
  @property
991
1016
  def current_display(self) -> Optional[MarkdownDisplay]:
992
1017
  """
@@ -1417,28 +1442,43 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1417
1442
  e = RuntimeError(f"Unknown tool: {x.name}")
1418
1443
  return ContentToolResult(value=None, error=e, request=x)
1419
1444
 
1420
- args = x.arguments
1421
-
1445
+ # First, invoke the request callbacks. If a ToolRejectError is raised,
1446
+ # treat it like a tool failure (i.e., gracefully handle it).
1447
+ result: ContentToolResult | None = None
1422
1448
  try:
1423
- if isinstance(args, dict):
1424
- result = func(**args)
1425
- else:
1426
- result = func(args)
1449
+ self._on_tool_request_callbacks.invoke(x)
1450
+ except ToolRejectError as e:
1451
+ result = ContentToolResult(value=None, error=e, request=x)
1452
+
1453
+ # Invoke the tool (if it hasn't been rejected).
1454
+ if result is None:
1455
+ try:
1456
+ if isinstance(x.arguments, dict):
1457
+ res = func(**x.arguments)
1458
+ else:
1459
+ res = func(x.arguments)
1460
+
1461
+ if isinstance(res, ContentToolResult):
1462
+ result = res
1463
+ else:
1464
+ result = ContentToolResult(value=res)
1427
1465
 
1428
- if not isinstance(result, ContentToolResult):
1429
- result = ContentToolResult(value=result)
1466
+ result.request = x
1467
+ except Exception as e:
1468
+ result = ContentToolResult(value=None, error=e, request=x)
1430
1469
 
1431
- result.request = x
1432
- return result
1433
- except Exception as e:
1470
+ # If we've captured an error, notify and log it.
1471
+ if result.error:
1434
1472
  warnings.warn(
1435
1473
  f"Calling tool '{x.name}' led to an error.",
1436
1474
  ToolFailureWarning,
1437
1475
  stacklevel=2,
1438
1476
  )
1439
1477
  traceback.print_exc()
1440
- log_tool_error(x.name, str(args), e)
1441
- return ContentToolResult(value=None, error=e, request=x)
1478
+ log_tool_error(x.name, str(x.arguments), result.error)
1479
+
1480
+ self._on_tool_result_callbacks.invoke(result)
1481
+ return result
1442
1482
 
1443
1483
  async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
1444
1484
  tool_def = self._tools.get(x.name, None)
@@ -1453,28 +1493,43 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1453
1493
  e = RuntimeError(f"Unknown tool: {x.name}")
1454
1494
  return ContentToolResult(value=None, error=e, request=x)
1455
1495
 
1456
- args = x.arguments
1457
-
1496
+ # First, invoke the request callbacks. If a ToolRejectError is raised,
1497
+ # treat it like a tool failure (i.e., gracefully handle it).
1498
+ result: ContentToolResult | None = None
1458
1499
  try:
1459
- if isinstance(args, dict):
1460
- result = await func(**args)
1461
- else:
1462
- result = await func(args)
1500
+ await self._on_tool_request_callbacks.invoke_async(x)
1501
+ except ToolRejectError as e:
1502
+ result = ContentToolResult(value=None, error=e, request=x)
1503
+
1504
+ # Invoke the tool (if it hasn't been rejected).
1505
+ if result is None:
1506
+ try:
1507
+ if isinstance(x.arguments, dict):
1508
+ res = await func(**x.arguments)
1509
+ else:
1510
+ res = await func(x.arguments)
1511
+
1512
+ if isinstance(res, ContentToolResult):
1513
+ result = res
1514
+ else:
1515
+ result = ContentToolResult(value=res)
1463
1516
 
1464
- if not isinstance(result, ContentToolResult):
1465
- result = ContentToolResult(value=result)
1517
+ result.request = x
1518
+ except Exception as e:
1519
+ result = ContentToolResult(value=None, error=e, request=x)
1466
1520
 
1467
- result.request = x
1468
- return result
1469
- except Exception as e:
1521
+ # If we've captured an error, notify and log it.
1522
+ if result.error:
1470
1523
  warnings.warn(
1471
1524
  f"Calling tool '{x.name}' led to an error.",
1472
1525
  ToolFailureWarning,
1473
1526
  stacklevel=2,
1474
1527
  )
1475
1528
  traceback.print_exc()
1476
- log_tool_error(x.name, str(args), e)
1477
- return ContentToolResult(value=None, error=e, request=x)
1529
+ log_tool_error(x.name, str(x.arguments), result.error)
1530
+
1531
+ await self._on_tool_result_callbacks.invoke_async(result)
1532
+ return result
1478
1533
 
1479
1534
  def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
1480
1535
  """
@@ -1545,6 +1600,21 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1545
1600
  res += "\n" + turn.__repr__(indent=2)
1546
1601
  return res + "\n"
1547
1602
 
1603
def __deepcopy__(self, memo):
    """
    Deep-copy this chat while sharing (not copying) its `provider`.

    Every other instance attribute (turns, tools, callback registries, ...)
    is deep-copied through `memo`. The provider is reused by reference,
    presumably because it wraps client objects that do not deep-copy
    cleanly (TODO confirm).
    """
    cls = type(self)
    clone = cls.__new__(cls)
    # Register the clone before recursing so cyclic references resolve to it.
    memo[id(self)] = clone
    for name, attr in self.__dict__.items():
        share_by_reference = name == "provider"
        setattr(clone, name, attr if share_by_reference else copy.deepcopy(attr, memo))
    return clone
1617
+
1548
1618
 
1549
1619
  class ChatResponse:
1550
1620
  """
chatlas/_content.py CHANGED
@@ -60,6 +60,12 @@ class ContentText(Content):
60
60
  text: str
61
61
  content_type: ContentTypeEnum = "text"
62
62
 
63
def __init__(self, **data: Any):
    """
    Build the text content, substituting a placeholder for blank text.

    After normal model construction, a `text` that is empty or consists
    only of whitespace is replaced with `"[empty string]"`, so downstream
    providers never see a blank text block (presumably some APIs reject
    those; TODO confirm).
    """
    super().__init__(**data)
    # str.strip() removes exactly the characters str.isspace() accepts, so an
    # empty result means the text was "" or whitespace-only.
    is_blank = not self.text.strip()
    if is_blank:
        self.text = "[empty string]"
68
+
63
69
  def __str__(self):
64
70
  return self.text
65
71
 
chatlas/_databricks.py CHANGED
@@ -85,7 +85,7 @@ def ChatDatabricks(
85
85
  A chat object that retains the state of the conversation.
86
86
  """
87
87
  if model is None:
88
- model = log_model_default("databricks-dbrx-instruct")
88
+ model = log_model_default("databricks-claude-3-7-sonnet")
89
89
 
90
90
  return Chat(
91
91
  provider=DatabricksProvider(
@@ -111,17 +111,11 @@ class DatabricksProvider(OpenAIProvider):
111
111
  except ImportError:
112
112
  raise ImportError(
113
113
  "`ChatDatabricks()` requires the `databricks-sdk` package. "
114
- "Install it with `pip install databricks-sdk[openai]`."
114
+ "Install it with `pip install databricks-sdk`."
115
115
  )
116
116
 
117
- try:
118
- import httpx
119
- from openai import AsyncOpenAI
120
- except ImportError:
121
- raise ImportError(
122
- "`ChatDatabricks()` requires the `openai` package. "
123
- "Install it with `pip install openai`."
124
- )
117
+ import httpx
118
+ from openai import AsyncOpenAI
125
119
 
126
120
  self._model = model
127
121
  self._seed = None
chatlas/_github.py CHANGED
@@ -40,12 +40,6 @@ def ChatGithub(
40
40
  You may need to apply for and be accepted into a beta access program.
41
41
  :::
42
42
 
43
- ::: {.callout-note}
44
- ## Python requirements
45
-
46
- `ChatGithub` requires the `openai` package: `pip install "chatlas[github]"`.
47
- :::
48
-
49
43
 
50
44
  Examples
51
45
  --------
chatlas/_google.py CHANGED
@@ -291,7 +291,8 @@ class GoogleProvider(
291
291
  GoogleTool(
292
292
  function_declarations=[
293
293
  FunctionDeclaration.from_callable(
294
- client=self._client, callable=tool.func
294
+ client=self._client._api_client,
295
+ callable=tool.func,
295
296
  )
296
297
  for tool in tools.values()
297
298
  ]
chatlas/_groq.py CHANGED
@@ -38,12 +38,6 @@ def ChatGroq(
38
38
  Sign up at <https://groq.com> to get an API key.
39
39
  :::
40
40
 
41
- ::: {.callout-note}
42
- ## Python requirements
43
-
44
- `ChatGroq` requires the `openai` package: `pip install "chatlas[groq]"`.
45
- :::
46
-
47
41
  Examples
48
42
  --------
49
43
 
chatlas/_logging.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  import os
3
3
  import warnings
4
+ from typing import Literal
4
5
 
5
6
  from rich.logging import RichHandler
6
7
 
@@ -12,15 +13,38 @@ def _rich_handler() -> RichHandler:
12
13
  return handler
13
14
 
14
15
 
15
- logger = logging.getLogger("chatlas")
16
-
17
- if os.environ.get("CHATLAS_LOG") == "info":
16
def setup_logger(x: str, level: Literal["debug", "info"]) -> logging.Logger:
    """
    Configure and return the logger named `x`.

    Sets the logger's level to DEBUG or INFO according to `level`. It then
    attaches a RichHandler (only if one is not already present, so repeated
    calls do not duplicate output) and disables propagation to ancestor
    loggers.
    """
    target = logging.getLogger(x)
    if level in ("debug", "info"):
        target.setLevel(logging.DEBUG if level == "debug" else logging.INFO)

    # A RichHandler on the logger itself guarantees its records are never
    # dropped, even when the root logger's handlers are not RichHandlers.
    already_rich = any(isinstance(handler, RichHandler) for handler in target.handlers)
    if not already_rich:
        target.addHandler(_rich_handler())
    target.propagate = False
    return target
29
+
30
+
31
# The package logger. It is only explicitly configured (level, RichHandler,
# no propagation) when the CHATLAS_LOG environment variable is set.
logger = logging.getLogger("chatlas")
log_level = os.environ.get("CHATLAS_LOG")
if log_level:
    if log_level not in ("debug", "info"):
        warnings.warn(
            f"CHATLAS_LOG is set to '{log_level}', but the log level must "
            "be one of 'debug' or 'info'. Defaulting to 'info'.",
        )
        log_level = "info"

    # Manually set up the logger for each dependency we care about. This way,
    # we can ensure that the logs won't get dropped when a rich display is
    # active.
    logger = setup_logger("chatlas", log_level)
    openai_logger = setup_logger("openai", log_level)
    anthropic_logger = setup_logger("anthropic", log_level)
    google_logger = setup_logger("google_genai.models", log_level)
    httpx_logger = setup_logger("httpx", log_level)
24
48
 
25
49
  # Add a RichHandler to the root logger if there are no handlers. Note that
26
50
  # if chatlas is imported before other libraries that set up logging, (like
chatlas/_ollama.py CHANGED
@@ -49,12 +49,6 @@ def ChatOllama(
49
49
  (e.g. `ollama pull llama3.2`).
50
50
  :::
51
51
 
52
- ::: {.callout-note}
53
- ## Python requirements
54
-
55
- `ChatOllama` requires the `openai` package: `pip install "chatlas[ollama]"`.
56
- :::
57
-
58
52
 
59
53
  Examples
60
54
  --------
chatlas/_openai.py CHANGED
@@ -78,12 +78,6 @@ def ChatOpenAI(
78
78
  account that will give you an API key that you can use with this package.
79
79
  :::
80
80
 
81
- ::: {.callout-note}
82
- ## Python requirements
83
-
84
- `ChatOpenAI` requires the `openai` package: `pip install "chatlas[openai]"`.
85
- :::
86
-
87
81
  Examples
88
82
  --------
89
83
  ```python
@@ -194,13 +188,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
194
188
  seed: Optional[int] = None,
195
189
  kwargs: Optional["ChatClientArgs"] = None,
196
190
  ):
197
- try:
198
- from openai import AsyncOpenAI, OpenAI
199
- except ImportError:
200
- raise ImportError(
201
- "`ChatOpenAI()` requires the `openai` package. "
202
- "Install it with `pip install openai`."
203
- )
191
+ from openai import AsyncOpenAI, OpenAI
204
192
 
205
193
  self._model = model
206
194
  self._seed = seed
@@ -433,7 +421,9 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
433
421
  "id": x.id,
434
422
  "function": {
435
423
  "name": x.name,
436
- "arguments": orjson.dumps(x.arguments).decode("utf-8"),
424
+ "arguments": orjson.dumps(x.arguments).decode(
425
+ "utf-8"
426
+ ),
437
427
  },
438
428
  "type": "function",
439
429
  }
@@ -602,16 +592,6 @@ def ChatAzureOpenAI(
602
592
  hosts a number of open source models as well as proprietary models
603
593
  from OpenAI.
604
594
 
605
- Prerequisites
606
- -------------
607
-
608
- ::: {.callout-note}
609
- ## Python requirements
610
-
611
- `ChatAzureOpenAI` requires the `openai` package:
612
- `pip install "chatlas[azure-openai]"`.
613
- :::
614
-
615
595
  Examples
616
596
  --------
617
597
  ```python
@@ -693,13 +673,7 @@ class OpenAIAzureProvider(OpenAIProvider):
693
673
  seed: int | None = None,
694
674
  kwargs: Optional["ChatAzureClientArgs"] = None,
695
675
  ):
696
- try:
697
- from openai import AsyncAzureOpenAI, AzureOpenAI
698
- except ImportError:
699
- raise ImportError(
700
- "`ChatAzureOpenAI()` requires the `openai` package. "
701
- "Install it with `pip install openai`."
702
- )
676
+ from openai import AsyncAzureOpenAI, AzureOpenAI
703
677
 
704
678
  self._model = deployment_id
705
679
  self._seed = seed
chatlas/_perplexity.py CHANGED
@@ -40,12 +40,6 @@ def ChatPerplexity(
40
40
  Sign up at <https://www.perplexity.ai> to get an API key.
41
41
  :::
42
42
 
43
- ::: {.callout-note}
44
- ## Python requirements
45
-
46
- `ChatPerplexity` requires the `openai` package: `pip install "chatlas[perplexity]"`.
47
- :::
48
-
49
43
 
50
44
  Examples
51
45
  --------