chatlas 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +1 -4
- chatlas/_callbacks.py +56 -0
- chatlas/_chat.py +131 -61
- chatlas/_content.py +6 -0
- chatlas/_databricks.py +1 -1
- chatlas/_logging.py +29 -5
- chatlas/_snowflake.py +398 -72
- chatlas/_tools.py +59 -1
- chatlas/_version.py +2 -2
- chatlas/types/anthropic/_submit.py +7 -0
- chatlas/types/openai/_submit.py +1 -0
- {chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/METADATA +2 -2
- {chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/RECORD +16 -17
- chatlas/types/snowflake/__init__.py +0 -8
- chatlas/types/snowflake/_submit.py +0 -24
- {chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/WHEEL +0 -0
- {chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/licenses/LICENSE +0 -0
chatlas/__init__.py
CHANGED

@@ -16,7 +16,7 @@ from ._perplexity import ChatPerplexity
 from ._provider import Provider
 from ._snowflake import ChatSnowflake
 from ._tokens import token_usage
-from ._tools import Tool
+from ._tools import Tool, ToolRejectError
 from ._turn import Turn
 
 try:
@@ -51,6 +51,7 @@ __all__ = (
     "Provider",
     "token_usage",
     "Tool",
+    "ToolRejectError",
     "Turn",
     "types",
 )
chatlas/_anthropic.py
CHANGED

@@ -451,10 +451,7 @@ class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]):
     @staticmethod
     def _as_content_block(content: Content) -> "ContentBlockParam":
         if isinstance(content, ContentText):
-            text = content.text
-            if text == "" or text.isspace():
-                text = "[empty string]"
-            return {"type": "text", "text": text}
+            return {"text": content.text, "type": "text"}
         elif isinstance(content, ContentJson):
             return {"text": "<structured data/>", "type": "text"}
         elif isinstance(content, ContentPDF):
chatlas/_callbacks.py
ADDED

@@ -0,0 +1,56 @@
+from collections import OrderedDict
+from typing import Any, Callable
+
+from ._utils import is_async_callable
+
+
+class CallbackManager:
+    def __init__(self) -> None:
+        self._callbacks: dict[str, Callable[..., Any]] = OrderedDict()
+        self._id: int = 1
+
+    def add(self, callback: Callable[..., Any]) -> Callable[[], None]:
+        callback_id = self._next_id()
+        self._callbacks[callback_id] = callback
+
+        def _rm_callback() -> None:
+            self._callbacks.pop(callback_id, None)
+
+        return _rm_callback
+
+    def invoke(self, *args: Any, **kwargs: Any) -> None:
+        if not self._callbacks:
+            return
+
+        # Invoke in reverse insertion order
+        for callback_id in reversed(list(self._callbacks.keys())):
+            callback = self._callbacks[callback_id]
+            if is_async_callable(callback):
+                raise RuntimeError(
+                    "Can't use async callbacks with `.chat()`/`.stream()`."
+                    "Async callbacks can only be used with `.chat_async()`/`.stream_async()`."
+                )
+            callback(*args, **kwargs)
+
+    async def invoke_async(self, *args: Any, **kwargs: Any) -> None:
+        if not self._callbacks:
+            return
+
+        # Invoke in reverse insertion order
+        for callback_id in reversed(list(self._callbacks.keys())):
+            callback = self._callbacks[callback_id]
+            if is_async_callable(callback):
+                await callback(*args, **kwargs)
+            else:
+                callback(*args, **kwargs)
+
+    def count(self) -> int:
+        return len(self._callbacks)
+
+    def get_callbacks(self) -> list[Callable[..., Any]]:
+        return list(self._callbacks.values())
+
+    def _next_id(self) -> str:
+        current_id = self._id
+        self._id += 1
+        return str(current_id)
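
The new `CallbackManager` is the internal registry `Chat` uses for tool-event handlers: `add()` returns a deregistration closure, and callbacks fire in reverse insertion order. A minimal sketch of that behavior (assuming the private module path stays importable):

    from chatlas._callbacks import CallbackManager

    cm = CallbackManager()

    # add() hands back a function that removes this callback
    rm = cm.add(lambda msg: print("callback saw:", msg))
    cm.invoke("hello")   # prints: callback saw: hello
    print(cm.count())    # 1

    rm()
    print(cm.count())    # 0

Note that `invoke()` deliberately raises when an async callback is registered; only `invoke_async()` accepts both sync and async callbacks.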
chatlas/_chat.py
CHANGED

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import inspect
 import os
 import sys
@@ -25,6 +26,7 @@ from typing import (
 
 from pydantic import BaseModel
 
+from ._callbacks import CallbackManager
 from ._content import (
     Content,
     ContentJson,
@@ -41,7 +43,7 @@ from ._display import (
 )
 from ._logging import log_tool_error
 from ._provider import Provider
-from ._tools import Tool
+from ._tools import Tool, ToolRejectError
 from ._turn import Turn, user_turn
 from ._typing_extensions import TypedDict
 from ._utils import html_escape, wrap_async
@@ -95,6 +97,8 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         self.provider = provider
         self._turns: list[Turn] = list(turns or [])
         self._tools: dict[str, Tool] = {}
+        self._on_tool_request_callbacks = CallbackManager()
+        self._on_tool_result_callbacks = CallbackManager()
         self._current_display: Optional[MarkdownDisplay] = None
         self._echo_options: EchoDisplayOptions = {
             "rich_markdown": {},
@@ -631,31 +635,18 @@
     def stream(
         self,
         *args: Content | str,
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def stream(
-        self,
-        *args: Content | str,
-        echo: EchoOptions,
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def stream(
-        self,
-        *args: Content | str,
-        echo: EchoOptions,
         content: Literal["text"],
+        echo: EchoOptions = "none",
+        kwargs: Optional[SubmitInputArgsT] = None,
     ) -> Generator[str, None, None]: ...
 
     @overload
     def stream(
         self,
         *args: Content | str,
-        echo: EchoOptions,
         content: Literal["all"],
+        echo: EchoOptions = "none",
+        kwargs: Optional[SubmitInputArgsT] = None,
     ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
 
     def stream(
@@ -712,31 +703,18 @@
     async def stream_async(
         self,
         *args: Content | str,
-    ) -> AsyncGenerator[str, None]: ...
-
-    @overload
-    async def stream_async(
-        self,
-        *args: Content | str,
-        echo: EchoOptions,
-    ) -> AsyncGenerator[str, None]: ...
-
-    @overload
-    async def stream_async(
-        self,
-        *args: Content | str,
-        echo: EchoOptions,
         content: Literal["text"],
+        echo: EchoOptions = "none",
+        kwargs: Optional[SubmitInputArgsT] = None,
     ) -> AsyncGenerator[str, None]: ...
 
     @overload
     async def stream_async(
         self,
         *args: Content | str,
-        echo: EchoOptions,
         content: Literal["all"],
+        echo: EchoOptions = "none",
+        kwargs: Optional[SubmitInputArgsT] = None,
    ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
 
     async def stream_async(
@@ -987,6 +965,53 @@
         tool = Tool(func, model=model)
         self._tools[tool.name] = tool
 
+    def on_tool_request(self, callback: Callable[[ContentToolRequest], None]):
+        """
+        Register a callback for a tool request event.
+
+        A tool request event occurs when the assistant requests a tool to be
+        called on its behalf. Before invoking the tool, `on_tool_request`
+        handlers are called with the relevant `ContentToolRequest` object. This
+        is useful if you want to handle tool requests in a custom way, such as
+        logging them or requiring user approval before invoking the tool.
+
+        Parameters
+        ----------
+        callback
+            A function to be called when a tool request event occurs.
+            This function must have a single argument, which will be the
+            tool request (i.e., a `ContentToolRequest` object).
+
+        Returns
+        -------
+        A callable that can be used to remove the callback later.
+        """
+        return self._on_tool_request_callbacks.add(callback)
+
+    def on_tool_result(self, callback: Callable[[ContentToolResult], None]):
+        """
+        Register a callback for a tool result event.
+
+        A tool result event occurs when a tool has been invoked and the
+        result is ready to be provided to the assistant. After the tool
+        has been invoked, `on_tool_result` handlers are called with the
+        relevant `ContentToolResult` object. This is useful if you want to
+        handle tool results in a custom way, such as logging them.
+
+        Parameters
+        ----------
+        callback
+            A function to be called when a tool result event occurs.
+            This function must have a single argument, which will be the
+            tool result (i.e., a `ContentToolResult` object).
+
+        Returns
+        -------
+        A callable that can be used to remove the callback later.
+        """
+        return self._on_tool_result_callbacks.add(callback)
+
     @property
     def current_display(self) -> Optional[MarkdownDisplay]:
         """
@@ -1417,28 +1442,43 @@
             e = RuntimeError(f"Unknown tool: {x.name}")
             return ContentToolResult(value=None, error=e, request=x)
 
+        # First, invoke the request callbacks. If a ToolRejectError is raised,
+        # treat it like a tool failure (i.e., gracefully handle it).
+        result: ContentToolResult | None = None
         try:
+            self._on_tool_request_callbacks.invoke(x)
+        except ToolRejectError as e:
+            result = ContentToolResult(value=None, error=e, request=x)
+
+        # Invoke the tool (if it hasn't been rejected).
+        if result is None:
+            try:
+                if isinstance(x.arguments, dict):
+                    res = func(**x.arguments)
+                else:
+                    res = func(x.arguments)
+
+                if isinstance(res, ContentToolResult):
+                    result = res
+                else:
+                    result = ContentToolResult(value=res)
 
-        except Exception as e:
+                result.request = x
+            except Exception as e:
+                result = ContentToolResult(value=None, error=e, request=x)
+
+        # If we've captured an error, notify and log it.
+        if result.error:
             warnings.warn(
                 f"Calling tool '{x.name}' led to an error.",
                 ToolFailureWarning,
                 stacklevel=2,
             )
             traceback.print_exc()
-            log_tool_error(x.name, str(x.arguments), e)
+            log_tool_error(x.name, str(x.arguments), result.error)
+
+        self._on_tool_result_callbacks.invoke(result)
+        return result
 
     async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
         tool_def = self._tools.get(x.name, None)
@@ -1453,28 +1493,43 @@
             e = RuntimeError(f"Unknown tool: {x.name}")
             return ContentToolResult(value=None, error=e, request=x)
 
+        # First, invoke the request callbacks. If a ToolRejectError is raised,
+        # treat it like a tool failure (i.e., gracefully handle it).
+        result: ContentToolResult | None = None
         try:
+            await self._on_tool_request_callbacks.invoke_async(x)
+        except ToolRejectError as e:
+            result = ContentToolResult(value=None, error=e, request=x)
+
+        # Invoke the tool (if it hasn't been rejected).
+        if result is None:
+            try:
+                if isinstance(x.arguments, dict):
+                    res = await func(**x.arguments)
+                else:
+                    res = await func(x.arguments)
+
+                if isinstance(res, ContentToolResult):
+                    result = res
+                else:
+                    result = ContentToolResult(value=res)
 
-        except Exception as e:
+                result.request = x
+            except Exception as e:
+                result = ContentToolResult(value=None, error=e, request=x)
+
+        # If we've captured an error, notify and log it.
+        if result.error:
             warnings.warn(
                 f"Calling tool '{x.name}' led to an error.",
                 ToolFailureWarning,
                 stacklevel=2,
             )
             traceback.print_exc()
-            log_tool_error(x.name, str(x.arguments), e)
+            log_tool_error(x.name, str(x.arguments), result.error)
+
+        await self._on_tool_result_callbacks.invoke_async(result)
+        return result
 
     def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
         """
@@ -1545,6 +1600,21 @@
             res += "\n" + turn.__repr__(indent=2)
         return res + "\n"
 
+    def __deepcopy__(self, memo):
+        result = self.__class__.__new__(self.__class__)
+
+        # Avoid recursive references
+        memo[id(self)] = result
+
+        # Copy all attributes except the problematic provider attribute
+        for key, value in self.__dict__.items():
+            if key != "provider":
+                setattr(result, key, copy.deepcopy(value, memo))
+            else:
+                setattr(result, key, value)
+
+        return result
+
 
 class ChatResponse:
     """
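
Together, these `_chat.py` changes let callers observe tool requests, veto them, and inspect results. A hedged sketch of the new hooks (the `get_weather` tool and its gating rule are illustrative, not part of the release):

    import chatlas as ctl

    chat = ctl.ChatOpenAI()

    def get_weather(city: str) -> str:
        "Return the current weather for a city (hypothetical tool)"
        return f"Sunny in {city}"

    chat.register_tool(get_weather)

    # Runs before the tool is invoked; raising ToolRejectError vetoes the call
    def gate(request):
        if request.arguments.get("city") != "Paris":
            raise ctl.ToolRejectError("Only Paris is allowed.")

    rm_gate = chat.on_tool_request(gate)

    # Runs after the tool (or rejection) resolves
    rm_log = chat.on_tool_result(lambda result: print("tool result:", result.value))

    chat.chat("What's the weather in Paris?")

    # Both registration methods return deregistration functions
    rm_gate()
    rm_log()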
chatlas/_content.py
CHANGED

@@ -60,6 +60,12 @@ class ContentText(Content):
     text: str
     content_type: ContentTypeEnum = "text"
 
+    def __init__(self, **data: Any):
+        super().__init__(**data)
+
+        if self.text == "" or self.text.isspace():
+            self.text = "[empty string]"
+
     def __str__(self):
         return self.text
 
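
This moves the empty-string guard out of the Anthropic provider (see `_anthropic.py` above) and into `ContentText` itself, so every provider now gets the substitution. The effect, in short (the private module path is an assumption):

    from chatlas._content import ContentText

    ContentText(text="hello").text  # "hello"
    ContentText(text="   ").text    # "[empty string]"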
chatlas/_databricks.py
CHANGED

@@ -85,7 +85,7 @@ def ChatDatabricks(
     A chat object that retains the state of the conversation.
     """
     if model is None:
-        model = log_model_default("databricks-
+        model = log_model_default("databricks-claude-3-7-sonnet")
 
     return Chat(
         provider=DatabricksProvider(
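
With this change, `ChatDatabricks()` defaults to a Databricks-served Claude model. A sketch (assumes Databricks credentials are already configured in the environment):

    import chatlas as ctl

    # Equivalent to ChatDatabricks(model="databricks-claude-3-7-sonnet")
    chat = ctl.ChatDatabricks()
    chat.chat("Hello!")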
chatlas/_logging.py
CHANGED

@@ -1,6 +1,7 @@
 import logging
 import os
 import warnings
+from typing import Literal
 
 from rich.logging import RichHandler
 
@@ -12,15 +13,38 @@ def _rich_handler() -> RichHandler:
     return handler
 
 
-if
+def setup_logger(x: str, level: Literal["debug", "info"]) -> logging.Logger:
+    logger = logging.getLogger(x)
+    if level == "debug":
+        logger.setLevel(logging.DEBUG)
+    elif level == "info":
+        logger.setLevel(logging.INFO)
     # By adding a RichHandler to chatlas' logger, we can guarantee that they
     # never get dropped, even if the root logger's handlers are not
     # RichHandlers.
-    logger.
+    if not any(isinstance(h, RichHandler) for h in logger.handlers):
+        logger.addHandler(_rich_handler())
     logger.propagate = False
+    return logger
+
+
+logger = logging.getLogger("chatlas")
+log_level = os.environ.get("CHATLAS_LOG")
+if log_level:
+    if log_level != "debug" and log_level != "info":
+        warnings.warn(
+            f"CHATLAS_LOG is set to '{log_level}', but the log level must "
+            "be one of 'debug' or 'info'. Defaulting to 'info'.",
+        )
+        log_level = "info"
+
+    # Manually set up the logger for each dependency we care about. This way, we
+    # can ensure that the logs won't get dropped when a rich display is active.
+    logger = setup_logger("chatlas", log_level)
+    openai_logger = setup_logger("openai", log_level)
+    anthropic_logger = setup_logger("anthropic", log_level)
+    google_logger = setup_logger("google_genai.models", log_level)
+    httpx_logger = setup_logger("httpx", log_level)
 
 # Add a RichHandler to the root logger if there are no handlers. Note that
 # if chatlas is imported before other libraries that set up logging, (like
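
The practical upshot: one environment variable now wires rich logging for chatlas plus the provider and HTTP libraries it wraps. It must be set before chatlas is imported, since the setup runs at import time:

    import os

    os.environ["CHATLAS_LOG"] = "debug"  # or "info"

    # chatlas, openai, anthropic, google_genai.models, and httpx loggers now emit
    import chatlas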
chatlas/_snowflake.py
CHANGED

@@ -1,32 +1,60 @@
-import
+from typing import (
+    TYPE_CHECKING,
+    Generator,
+    Literal,
+    Optional,
+    TypedDict,
+    Union,
+    overload,
+)
+
+import orjson
 from pydantic import BaseModel
 
 from ._chat import Chat
-from ._content import
+from ._content import (
+    Content,
+    ContentJson,
+    ContentText,
+    ContentToolRequest,
+    ContentToolResult,
+)
 from ._logging import log_model_default
 from ._provider import Provider
+from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
 from ._turn import Turn, normalize_turns
-from ._utils import drop_none
+from ._utils import drop_none
 
 if TYPE_CHECKING:
+    import snowflake.core.cortex.inference_service._generated.models as models
+    from snowflake.core.rest import Event, SSEClient
+
+    Completion = models.NonStreamingCompleteResponse
+    CompletionChunk = models.StreamingCompleteResponseDataEvent
+
+    # Manually constructed TypedDict equivalent of models.CompleteRequest
+    class CompleteRequest(TypedDict, total=False):
+        """
+        CompleteRequest parameters for Snowflake Cortex LLMs.
+
+        See `snowflake.core.cortex.inference_service.CompleteRequest` for more details.
+        """
+
+        temperature: Union[float, int]
+        """Temperature controls the amount of randomness used in response generation. A higher temperature corresponds to more randomness."""
 
-    CompletionChunk = str
+        top_p: Union[float, int]
+        """Threshold probability for nucleus sampling. A higher top-p value increases the diversity of tokens that the model considers, while a lower value results in more predictable output."""
 
+        max_tokens: int
+        """The maximum number of output tokens to produce. The default value is model-dependent."""
 
+        guardrails: models.GuardrailsConfig
+        """Controls whether guardrails are enabled."""
 
-    class ConversationMessage(TypedDict):
-        role: str
-        content: str
+        tool_choice: models.ToolChoice
+        """Determines how tools are selected."""
 
 
 def ChatSnowflake(
@@ -41,7 +69,7 @@ def ChatSnowflake(
     private_key_file: Optional[str] = None,
     private_key_file_pwd: Optional[str] = None,
     kwargs: Optional[dict[str, "str | int"]] = None,
-) -> Chat["
+) -> Chat["CompleteRequest", "Completion"]:
     """
     Chat with a Snowflake Cortex LLM
 
@@ -116,7 +144,7 @@ def ChatSnowflake(
     """
 
     if model is None:
-        model = log_model_default("
+        model = log_model_default("claude-3-7-sonnet")
 
     return Chat(
         provider=SnowflakeProvider(
@@ -150,6 +178,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         kwargs: Optional[dict[str, "str | int"]],
     ):
         try:
+            from snowflake.core import Root
             from snowflake.snowpark import Session
         except ImportError:
             raise ImportError(
@@ -170,7 +199,9 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         )
 
         self._model = model
+
+        session = Session.builder.configs(configs).create()
+        self._cortex_service = Root(session).cortex_inference_service
 
     @overload
     def chat_perform(
@@ -180,7 +211,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     @overload
@@ -191,7 +222,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     def chat_perform(
@@ -201,12 +232,25 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ):
+        req = self._complete_request(stream, turns, tools, data_model, kwargs)
+        client = self._cortex_service.complete(req)
+
+        try:
+            events = client.events()
+        except Exception as e:
+            data = parse_request_object(client)
+            if data is None:
+                raise e
+            return data
 
+        if stream:
+            return generate_event_data(events)
+
+        for evt in events:
+            if evt.data:
+                return parse_event_data(evt.data, stream=False)
 
     @overload
     async def chat_perform_async(
@@ -216,7 +260,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     @overload
@@ -227,7 +271,7 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ): ...
 
     async def chat_perform_async(
@@ -237,65 +281,164 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ):
+        req = self._complete_request(stream, turns, tools, data_model, kwargs)
+        res = self._cortex_service.complete_async(req)
+        # TODO: is there a way to get the SSEClient result without blocking?
+        client = res.result()
 
+        try:
+            events = client.events()
+        except Exception as e:
+            data = parse_request_object(client)
+            if data is None:
+                raise e
+            return data
 
-        # When streaming, res is an iterable of strings, but Chat() wants an async iterable
         if stream:
+            return generate_event_data_async(events)
 
+        for evt in events:
+            if evt.data:
+                return parse_event_data(evt.data, stream=False)
 
-    def
+    def _complete_request(
         self,
         stream: bool,
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
-        kwargs: Optional["
+        kwargs: Optional["CompleteRequest"] = None,
     ):
-        # TODO: get tools working
+        from snowflake.core.cortex.inference_service import CompleteRequest
+
+        req = CompleteRequest(
+            model=self._model,
+            messages=self._as_request_messages(turns),
+            stream=stream,
+        )
+
         if tools:
+            req.tools = req.tools or []
+            snow_tools = [self._as_snowflake_tool(tool) for tool in tools.values()]
+            req.tools.extend(snow_tools)
 
         if data_model is not None:
+            import snowflake.core.cortex.inference_service._generated.models as models
+
             params = basemodel_to_param_schema(data_model)
-                "schema": {
+            req.response_format = models.CompleteRequestResponseFormat(
+                type="json",
+                schema={
                     "type": "object",
                     "properties": params["properties"],
                     "required": params["required"],
                 },
-            kwargs_full["options"] = opts
+            )
 
+        if kwargs:
+            for k, v in kwargs.items():
+                if hasattr(req, k):
+                    setattr(req, k, v)
+                else:
+                    raise ValueError(
+                        f"Unknown parameter {k} for Snowflake CompleteRequest. "
+                        "Please check the Snowflake documentation for valid parameters."
+                    )
 
-        return chunk
+        return req
 
+    def stream_text(self, chunk):
+        if not chunk.choices:
+            return None
+        delta = chunk.choices[0].delta
+        if delta is None or "content" not in delta:
+            return None
+        return delta["content"]
+
+    # Snowflake sort-of follows OpenAI/Anthropic streaming formats except they
+    # don't have the critical "index" field in the delta that the merge logic
+    # depends on (i.e., OpenAI), or official start/stop events (i.e.,
+    # Anthropic). So we have to do some janky merging here.
+    #
+    # This was done in a panic to get working asap, so don't judge :) I wouldn't
+    # be surprised if Snowflake realizes how bad this streaming format is and
+    # changes it in the future (thus probably breaking this code :( ).
     def stream_merge_chunks(self, completion, chunk):
         if completion is None:
             return chunk
+
+        if completion.choices is None or chunk.choices is None:
+            raise ValueError(
+                "Unexpected None for completion.choices. Please report this issue."
+            )
+
+        if completion.choices[0].delta is None or chunk.choices[0].delta is None:
+            raise ValueError(
+                "Unexpected None for completion.choices[0].delta. Please report this issue."
+            )
+
+        delta = completion.choices[0].delta
+        new_delta = chunk.choices[0].delta
+        if "content_list" not in delta or "content_list" not in new_delta:
+            raise ValueError(
+                "Expected content_list to be in completion.choices[0].delta. Please report this issue."
+            )
+
+        content_list = delta["content_list"]
+        new_content_list = new_delta["content_list"]
+        if not isinstance(content_list, list) or not isinstance(new_content_list, list):
+            raise ValueError(
+                f"Expected content_list to be a list, got {type(new_content_list)}"
+            )
+
+        if new_delta["type"] == "tool_use":
+            # Presence of "tool_use_id" indicates a new tool request; otherwise, we're
+            # expecting input parameters
+            if "tool_use_id" in new_delta:
+                del new_delta["text"]  # why is this here :eye-roll:?
+                content_list.append(new_delta)
+            elif "input" in new_delta:
+                # find most recent content with type: "tool_use" and append to that
+                for i in range(len(content_list) - 1, -1, -1):
+                    if "tool_use_id" in content_list[i]:
+                        content_list[i]["input"] = content_list[i].get("input", "")
+                        content_list[i]["input"] += new_delta["input"]
+                        break
+            else:
+                raise ValueError(
+                    f"Unexpected tool_use delta: {new_delta}. Please report this issue."
+                )
+        elif new_delta["type"] == "text":
+            text = new_delta["text"]
+            # find most recent content with type: "text" and append to that
+            for i in range(len(content_list) - 1, -1, -1):
+                if content_list[i].get("type") == "text":
+                    content_list[i]["text"] += text
+                    break
+            else:
+                # if we don't find it, just append to the end
+                # this shouldn't happen, but just in case
+                content_list.append({"type": "text", "text": text})
+        else:
+            raise ValueError(
+                f"Unexpected streaming delta type: {new_delta['type']}. Please report this issue."
+            )
+
+        completion.choices[0].delta["content_list"] = content_list
+
+        return completion
 
     def stream_turn(self, completion, has_data_model) -> Turn:
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        completion_dict = completion.model_dump()
+        delta = completion_dict["choices"][0].pop("delta")
+        completion_dict["choices"][0]["message"] = delta
+        completion = models.NonStreamingCompleteResponse.model_construct(
+            **completion_dict
+        )
         return self._as_turn(completion, has_data_model)
 
     def value_turn(self, completion, has_data_model) -> Turn:
@@ -321,24 +464,207 @@ class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
             "Snowflake does not currently support token counting."
         )
 
-    def
+    def _as_request_messages(self, turns: list[Turn]):
+        from snowflake.core.cortex.inference_service import CompleteRequestMessagesInner
+
+        res: list[CompleteRequestMessagesInner] = []
         for turn in turns:
-                "content": str(turn),
-            }
+            req = CompleteRequestMessagesInner(
+                role=turn.role,
+                content=turn.text,
             )
+            for x in turn.contents:
+                if isinstance(x, ContentToolRequest):
+                    req.content_list = req.content_list or []
+                    req.content_list.append(
+                        {
+                            "type": "tool_use",
+                            "tool_use": {
+                                "tool_use_id": x.id,
+                                "name": x.name,
+                                "input": x.arguments,
+                            },
+                        }
+                    )
+                elif isinstance(x, ContentToolResult):
+                    # Snowflake doesn't like empty content
+                    req.content = req.content or "[tool_result]"
+                    req.content_list = req.content_list or []
+                    req.content_list.append(
+                        {
+                            "type": "tool_results",
+                            "tool_results": {
+                                "tool_use_id": x.id,
+                                "name": x.name,
+                                "content": [
+                                    {"type": "text", "text": x.get_model_value()}
+                                ],
+                            },
+                        }
+                    )
+                elif isinstance(x, ContentJson):
+                    req.content = req.content or "<structured data/>"
+
+            res.append(req)
         return res
 
-    def _as_turn(self, completion, has_data_model) -> Turn:
+    def _as_turn(self, completion: "Completion", has_data_model: bool) -> Turn:
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        if not completion.choices:
+            return Turn("assistant", [])
+
+        choice = completion.choices[0]
+        if isinstance(choice, dict):
+            choice = models.NonStreamingCompleteResponseChoicesInner.from_dict(choice)
+
+        message = choice.message
+        if message is None:
+            return Turn("assistant", [])
+
+        contents: list[Content] = []
+        content_list = message.content_list or []
+        for content in content_list:
+            if "text" in content:
+                if has_data_model:
+                    data = orjson.loads(content["text"])
+                    contents.append(ContentJson(value=data))
+                else:
+                    contents.append(ContentText(text=content["text"]))
+            elif "tool_use_id" in content:
+                params = content.get("input", "{}")
+                try:
+                    params = orjson.loads(params)
+                except orjson.JSONDecodeError:
+                    raise ValueError(
+                        f"Failed to parse tool_use input: {params}. Please report this issue."
+                    )
+                contents.append(
+                    ContentToolRequest(
+                        name=content["name"],
+                        id=content["tool_use_id"],
+                        arguments=params,
+                    )
+                )
+
+        usage = completion.usage
+        if usage is None:
+            tokens = (0, 0)
+        else:
+            tokens = (usage.prompt_tokens or 0, usage.completion_tokens or 0)
+
+        tokens_log(self, tokens)
+
+        return Turn(
+            "assistant",
+            contents,
+            tokens=tokens,
+            # TODO: no finish_reason in Snowflake?
+            # finish_reason=completion.choices[0].finish_reason,
+            completion=completion,
+        )
+
+    # N.B. this is currently the best documentation I can find for how tool calling works
+    # https://quickstarts.snowflake.com/guide/getting-started-with-tool-use-on-cortex-and-anthropic-claude/index.html#5
+    def _as_snowflake_tool(self, tool: Tool):
+        import snowflake.core.cortex.inference_service._generated.models as models
+
+        func = tool.schema["function"]
+        params = func.get("parameters", {})
+
+        props = params.get("properties", {})
+        if not isinstance(props, dict):
+            raise ValueError(
+                f"Tool function parameters must be a dictionary, got {type(props)}"
+            )
+
+        required = params.get("required", [])
+        if not isinstance(required, list):
+            raise ValueError(
+                f"Tool function required parameters must be a list, got {type(required)}"
+            )
+
+        input_schema = models.ToolToolSpecInputSchema(
+            type="object",
+            properties=props or None,
+            required=required or None,
+        )
+
+        spec = models.ToolToolSpec(
+            type="generic",
+            name=func["name"],
+            description=func.get("description", ""),
+            input_schema=input_schema,
+        )
 
+        return models.Tool(tool_spec=spec)
+
+
+# Yield parsed event data from the Snowflake SSEClient
+# (this is only needed for the streaming case).
+def generate_event_data(events: Generator["Event", None, None]):
+    for x in events:
+        if x.data:
+            yield parse_event_data(x.data, stream=True)
+
+
+# Same thing for the async case.
+async def generate_event_data_async(events: Generator["Event", None, None]):
+    for x in events:
+        if x.data:
+            yield parse_event_data(x.data, stream=True)
+
+
+@overload
+def parse_event_data(
+    data: str, stream: Literal[True]
+) -> "models.StreamingCompleteResponseDataEvent": ...
+
+
+@overload
+def parse_event_data(
+    data: str, stream: Literal[False]
+) -> "models.NonStreamingCompleteResponse": ...
+
+
+def parse_event_data(
+    data: str, stream: bool
+) -> "models.NonStreamingCompleteResponse | models.StreamingCompleteResponseDataEvent":
+    "Parse the (JSON) event data from Snowflake using the relevant pydantic model."
+    import snowflake.core.cortex.inference_service._generated.models as models
+
+    try:
+        if stream:
+            return models.StreamingCompleteResponseDataEvent.from_json(data)
         else:
+            return models.NonStreamingCompleteResponse.from_json(data)
+    except Exception:
+        raise ValueError(
+            f"Failed to parse Snowflake event data: {data}. "
+            "Please report this error here: https://github.com/posit-dev/chatlas/issues/new"
+        )
+
+
+# At the time of writing, .events() flat out errors in the stream=False case since
+# the Content-Type is set to application/json;charset=utf-8, and SSEClient
+# doesn't know how to handle that.
+# https://github.com/snowflakedb/snowflake-ml-python/blob/6910e96/snowflake/cortex/_sse_client.py#L69
+#
+# So, do some janky stuff here to get the data out of the response.
+#
+# If and when snowflake fixes this, we can remove the try/except block.
+def parse_request_object(
    client: "SSEClient",
) -> "Optional[models.NonStreamingCompleteResponse]":
+    try:
+        import urllib3
+
+        if isinstance(client._event_source, urllib3.response.HTTPResponse):
+            return parse_event_data(
+                client._event_source.data.decode("utf-8"),
+                stream=False,
+            )
+    except Exception:
+        pass
 
+    return None
chatlas/_tools.py
CHANGED

@@ -8,7 +8,10 @@ from pydantic import BaseModel, Field, create_model
 
 from . import _utils
 
-__all__ = ("Tool",)
+__all__ = (
+    "Tool",
+    "ToolRejectError",
+)
 
 if TYPE_CHECKING:
     from openai.types.chat import ChatCompletionToolParam
@@ -47,6 +50,61 @@ class Tool:
         self.name = self.schema["function"]["name"]
 
 
+class ToolRejectError(Exception):
+    """
+    Error to represent a tool call being rejected.
+
+    This error is meant to be raised when an end user has chosen to deny a tool
+    call. It can be raised in a tool function or in a `.on_tool_request()`
+    callback registered via a :class:`~chatlas.Chat`. When used in the callback,
+    the tool call is rejected before the tool function is invoked.
+
+    Parameters
+    ----------
+    reason
+        A string describing the reason for rejecting the tool call. This will be
+        included in the error message passed to the LLM. In addition to the
+        reason, the error message will also include "Tool call rejected." to
+        indicate that the tool call was not processed.
+
+    Raises
+    ------
+    ToolRejectError
+        An error with a message informing the LLM that the tool call was
+        rejected (and the reason why).
+
+    Examples
+    --------
+    >>> import os
+    >>> import chatlas as ctl
+    >>>
+    >>> chat = ctl.ChatOpenAI()
+    >>>
+    >>> def list_files():
+    ...     "List files in the user's current directory"
+    ...     while True:
+    ...         allow = input(
+    ...             "Would you like to allow access to your current directory? (yes/no): "
+    ...         )
+    ...         if allow.lower() == "yes":
+    ...             return os.listdir(".")
+    ...         elif allow.lower() == "no":
+    ...             raise ctl.ToolRejectError(
+    ...                 "The user has chosen to disallow the tool call."
+    ...             )
+    ...         else:
+    ...             print("Please answer with 'yes' or 'no'.")
+    >>>
+    >>> chat.register_tool(list_files)
+    >>> chat.chat("What files are available in my current directory?")
+    """
+
+    def __init__(self, reason: str = "The user has chosen to disallow the tool call."):
+        message = f"Tool call rejected. {reason}"
+        super().__init__(message)
+        self.message = message
+
+
 def func_to_schema(
     func: Callable[..., Any] | Callable[..., Awaitable[Any]],
     model: Optional[type[BaseModel]] = None,
chatlas/_version.py
CHANGED

chatlas/types/anthropic/_submit.py
CHANGED

@@ -29,9 +29,15 @@ class SubmitInputArgs(TypedDict, total=False):
             "claude-3-7-sonnet-20250219",
             "claude-3-5-haiku-latest",
             "claude-3-5-haiku-20241022",
+            "claude-sonnet-4-20250514",
+            "claude-sonnet-4-0",
+            "claude-4-sonnet-20250514",
             "claude-3-5-sonnet-latest",
             "claude-3-5-sonnet-20241022",
             "claude-3-5-sonnet-20240620",
+            "claude-opus-4-0",
+            "claude-opus-4-20250514",
+            "claude-4-opus-20250514",
             "claude-3-opus-latest",
             "claude-3-opus-20240229",
             "claude-3-sonnet-20240229",
@@ -41,6 +47,7 @@ class SubmitInputArgs(TypedDict, total=False):
         ],
         str,
     ]
+    service_tier: Union[Literal["auto", "standard_only"], anthropic.NotGiven]
     stop_sequences: Union[list[str], anthropic.NotGiven]
     stream: Union[Literal[False], Literal[True], anthropic.NotGiven]
     system: Union[
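
These regenerated type hints expose Anthropic's Claude 4 model IDs and the `service_tier` request option. Since `SubmitInputArgs` flows through the `kwargs` parameter of chat/stream calls, usage looks roughly like:

    import chatlas as ctl

    chat = ctl.ChatAnthropic(model="claude-sonnet-4-20250514")
    chat.chat("Hello!", kwargs={"service_tier": "auto"})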
chatlas/types/openai/_submit.py
CHANGED

{chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: chatlas
-Version: 0.7.1
+Version: 0.8.0
 Summary: A simple and consistent interface for chatting with LLMs
 Project-URL: Homepage, https://posit-dev.github.io/chatlas
 Project-URL: Documentation, https://posit-dev.github.io/chatlas
@@ -44,7 +44,7 @@ Requires-Dist: pillow; extra == 'dev'
 Requires-Dist: python-dotenv; extra == 'dev'
 Requires-Dist: ruff>=0.6.5; extra == 'dev'
 Requires-Dist: shiny; extra == 'dev'
-Requires-Dist: snowflake-ml-python; extra == 'dev'
+Requires-Dist: snowflake-ml-python>=1.8.4; extra == 'dev'
 Requires-Dist: tenacity; extra == 'dev'
 Requires-Dist: tiktoken; extra == 'dev'
 Requires-Dist: torch; (python_version <= '3.11') and extra == 'dev'
{chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/RECORD
CHANGED

@@ -1,47 +1,46 @@
-chatlas/__init__.py,sha256=
-chatlas/_anthropic.py,sha256=
+chatlas/__init__.py,sha256=vZpEIGNqb8pJ5QCEYjHcKuoc0HDTkuhR40sJfKdhN7s,1571
+chatlas/_anthropic.py,sha256=iF2aFS07T4I2Zlw5hI6Y-PjIBH4cG6ibOgVy12dI_6I,24885
 chatlas/_auto.py,sha256=HsAvVwpSOkI9fdC35YX8beaE2IBnWLWTzOzu0ny951o,6129
-chatlas/
-chatlas/
+chatlas/_callbacks.py,sha256=3RpPaOQonTqScjXbaShgKJ1Rc-YxzWerxKRBjVssFnc,1838
+chatlas/_chat.py,sha256=uo8ucU0l59KkpOD93bbeJC6jqrM7DrzIKOl9g6PvhhY,56171
+chatlas/_content.py,sha256=KtQfJ3CATk_TGzGzihc_wWRuHdlLWp9Dr-OgxZQbOQE,15931
 chatlas/_content_image.py,sha256=EUK6wAint-JatLsiwvaPDu4D3W-NcIsDCkzABkXgfDg,8304
 chatlas/_content_pdf.py,sha256=cffeuJxzhUDukQ-Srkmpy62M8X12skYpU_FVq-Wvya4,2420
-chatlas/_databricks.py,sha256=
+chatlas/_databricks.py,sha256=56dInk3UjQ5HqGiysn9TJ9hat745MhaFEn-IPfnR4vE,4696
 chatlas/_display.py,sha256=wyQzSc6z1VqrJfkTLkw1wQcti9s1Pr4qT8UxFJESn4U,4664
 chatlas/_github.py,sha256=xdGsvWvlbGMc1lgEM_oRL5p-wuxaZG-lu6a_4zxV-4Y,4235
 chatlas/_google.py,sha256=pwnomFcDEzPXQQfFifoH2FrbDvOgRIIyjcNoJ8Hvv-I,19397
 chatlas/_groq.py,sha256=0ou8iiAeI8EqjNKsLZhljQffSVOj8eCQEtW7-stZToY,4022
 chatlas/_interpolate.py,sha256=ykwLP3x-ya9Q33U4knSU75dtk6pzJAeythEEIW-43Pc,3631
 chatlas/_live_render.py,sha256=UMZltE35LxziDKPMEeDwQ9meZ95SeqwhJi7j-y9pcro,4004
-chatlas/_logging.py,sha256=
+chatlas/_logging.py,sha256=weKvXZDIZ88X7X61ruXM_S0AAhQ5mgiW9dR-km8x7Mg,3324
 chatlas/_merge.py,sha256=SGj_BetgA7gaOqSBKOhYmW3CYeQKTEehFrXvx3y4OYE,3924
 chatlas/_ollama.py,sha256=_5Na11A7Hrs83CK8x542UHDvpLjQPyTJ9WM03eFDqKA,3627
 chatlas/_openai.py,sha256=qkU2DWOTsEVfjiAk4LDRszErVGLswiLwwim1zZnHn8o,23968
 chatlas/_perplexity.py,sha256=POp_Lc8RlSEsHIHhFlHOE8IVsOb2nf7mHQZvyQp49_U,4304
 chatlas/_provider.py,sha256=YmdBbz_u5aP_kBxl6s26OPiSnWG_vZ_fvf9L2qvBmyI,3809
-chatlas/_snowflake.py,sha256=
+chatlas/_snowflake.py,sha256=y9Lu5SFGsJ1QqQD0Xx7vsAwcBtw2_MWT1Vu7e0HQis0,24098
 chatlas/_tokens.py,sha256=3W3EPUp9eWXUiwuzJwEPBv43AUznbK46pm59Htti7z4,2392
 chatlas/_tokens_old.py,sha256=L9d9oafrXvEx2u4nIn_Jjn7adnQyLBnYBuPwJUE8Pl8,5005
-chatlas/_tools.py,sha256
+chatlas/_tools.py,sha256=Qucyx_Pq-DX66r5xsZ7VRNxrIMS_se-FS7HH7jbB1A0,6085
 chatlas/_turn.py,sha256=7pve6YmD-L4c7Oxd6_ZAPkDudJ8AMpa6pP-pSroA1dM,5067
 chatlas/_typing_extensions.py,sha256=YdzmlyPSBpIEcsOkoz12e6jETT1XEMV2Q72haE4cfwY,1036
 chatlas/_utils.py,sha256=lli8ChbPUwEPebW8AoOoNoqiA95SVtoW2gb6ymj9gw4,4028
-chatlas/_version.py,sha256=
+chatlas/_version.py,sha256=fSm5pLlwHxfTD7vBTVEqChJUua9ilUsdQYNN_V3u3iE,511
 chatlas/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 chatlas/types/__init__.py,sha256=P_EDL4eqsigKwB-u2qRmKlYQS5Y65m7oWjGC3cYmxO4,719
 chatlas/types/anthropic/__init__.py,sha256=OwubA-DPHYpYo0XyRyAFwftOI0mOxtHzAyhUSLcDx54,417
 chatlas/types/anthropic/_client.py,sha256=G0LRhoFBcsSOMr5qhP-0rAScsVXaVlHCpggfVp54bnQ,690
 chatlas/types/anthropic/_client_bedrock.py,sha256=mNazQlu0pQt8JdzrYn3LKNgE4n732GjhQUJdQQK9QkY,785
-chatlas/types/anthropic/_submit.py,sha256=
+chatlas/types/anthropic/_submit.py,sha256=FJPtswPf_NTV3kAXlFODin3bFOqYNdamopjIaVSSr24,3417
 chatlas/types/google/__init__.py,sha256=ZJhi8Kwvio2zp8T1TQqmvdHqkS-Khb6BGESPjREADgo,337
 chatlas/types/google/_client.py,sha256=t7aKbxYq_xOA1Z3RnWcjewifdQFSHi7vKEj6MyKMCJk,729
 chatlas/types/google/_submit.py,sha256=b-ZqMvI551Ia7pFlWdqUQJjov3neHmVwLFw-P2bgU8w,1883
 chatlas/types/openai/__init__.py,sha256=Q2RAr1bSH1nHsxICK05nAmKmxdhKmhbBkWD_XHiVSrI,411
 chatlas/types/openai/_client.py,sha256=YGm_EHtRSSHeeOZe-CV7oNvMJpEblEta3UTuU7lSRO8,754
 chatlas/types/openai/_client_azure.py,sha256=jx8D_p46CLDGzTP-k-TtGzj-f3junj6or-86m8DD_0w,858
-chatlas/types/openai/_submit.py,sha256=
-chatlas/
-chatlas/
-chatlas-0.
-chatlas-0.
-chatlas-0.7.1.dist-info/licenses/LICENSE,sha256=zyuGzPOC7CcbOaBHsQ3UEyKYRO56KDUkor0OA4LqqDg,1081
-chatlas-0.7.1.dist-info/RECORD,,
+chatlas/types/openai/_submit.py,sha256=FmDBq8Wg8R-GB3mFBCvAPHyniCyIkgrzwd_iOiAOLM8,6607
+chatlas-0.8.0.dist-info/METADATA,sha256=2qCNO5-g1k5DvVwJEkTwk8ve8TE16fWditQowVEhdH0,15255
+chatlas-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+chatlas-0.8.0.dist-info/licenses/LICENSE,sha256=zyuGzPOC7CcbOaBHsQ3UEyKYRO56KDUkor0OA4LqqDg,1081
+chatlas-0.8.0.dist-info/RECORD,,
chatlas/types/snowflake/__init__.py
REMOVED

@@ -1,8 +0,0 @@
-# ---------------------------------------------------------
-# Do not modify this file. It was generated by `scripts/generate_typed_dicts.py`.
-# ---------------------------------------------------------
-
-
-from ._submit import SubmitInputArgs
-
-__all__ = ("SubmitInputArgs",)
chatlas/types/snowflake/_submit.py
REMOVED

@@ -1,24 +0,0 @@
-# ---------------------------------------------------------
-# Do not modify this file. It was generated by `scripts/generate_typed_dicts.py`.
-# ---------------------------------------------------------
-
-
-from typing import Optional, TypedDict, Union
-
-import snowflake.cortex._complete
-import snowflake.snowpark.column
-import snowflake.snowpark.session
-
-
-class SubmitInputArgs(TypedDict, total=False):
-    model: Union[str, snowflake.snowpark.column.Column]
-    prompt: Union[
-        str,
-        list[snowflake.cortex._complete.ConversationMessage],
-        snowflake.snowpark.column.Column,
-    ]
-    options: Optional[snowflake.cortex._complete.CompleteOptions]
-    session: Optional[snowflake.snowpark.session.Session]
-    stream: bool
-    timeout: Optional[float]
-    deadline: Optional[float]
{chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/WHEEL
File without changes

{chatlas-0.7.1.dist-info → chatlas-0.8.0.dist-info}/licenses/LICENSE
File without changes