chatlas 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chatlas might be problematic. Click here for more details.
- chatlas/__init__.py +11 -1
- chatlas/_anthropic.py +8 -10
- chatlas/_auto.py +183 -0
- chatlas/_chat.py +50 -19
- chatlas/_content.py +23 -7
- chatlas/_display.py +12 -2
- chatlas/_github.py +1 -1
- chatlas/_google.py +263 -166
- chatlas/_groq.py +1 -1
- chatlas/_live_render.py +116 -0
- chatlas/_merge.py +1 -1
- chatlas/_ollama.py +1 -1
- chatlas/_openai.py +4 -6
- chatlas/_perplexity.py +1 -1
- chatlas/_provider.py +0 -9
- chatlas/_snowflake.py +321 -0
- chatlas/_utils.py +7 -0
- chatlas/_version.py +21 -0
- chatlas/py.typed +0 -0
- chatlas/types/__init__.py +5 -1
- chatlas/types/anthropic/_submit.py +24 -2
- chatlas/types/google/_client.py +12 -91
- chatlas/types/google/_submit.py +40 -87
- chatlas/types/openai/_submit.py +9 -2
- chatlas/types/snowflake/__init__.py +8 -0
- chatlas/types/snowflake/_submit.py +24 -0
- {chatlas-0.3.0.dist-info → chatlas-0.5.0.dist-info}/METADATA +35 -7
- chatlas-0.5.0.dist-info/RECORD +44 -0
- chatlas-0.3.0.dist-info/RECORD +0 -37
- {chatlas-0.3.0.dist-info → chatlas-0.5.0.dist-info}/WHEEL +0 -0
chatlas/_live_render.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# A 'patched' version of LiveRender that adds the 'crop_above' vertical overflow method.
|
|
2
|
+
# Derives from https://github.com/Textualize/rich/pull/3637
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
if sys.version_info >= (3, 8):
|
|
7
|
+
from typing import Literal
|
|
8
|
+
else:
|
|
9
|
+
from typing_extensions import Literal # pragma: no cover
|
|
10
|
+
|
|
11
|
+
from rich._loop import loop_last
|
|
12
|
+
from rich.console import Console, ConsoleOptions, RenderableType, RenderResult
|
|
13
|
+
from rich.control import Control
|
|
14
|
+
from rich.segment import ControlType, Segment
|
|
15
|
+
from rich.style import StyleType
|
|
16
|
+
from rich.text import Text
|
|
17
|
+
|
|
18
|
+
# Accepted values for `LiveRender.vertical_overflow`. Besides rich's own
# "crop", "ellipsis" and "visible", this adds "crop_above", which keeps the
# *last* lines that fit (dropping output from the top) when the render is
# taller than the console.
VerticalOverflowMethod = Literal["crop", "crop_above", "ellipsis", "visible"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LiveRender:
    """Creates a renderable that may be updated.

    A copy of rich's ``LiveRender`` extended with an extra vertical overflow
    method, ``"crop_above"``, which keeps the bottom of the render visible by
    dropping lines from the top when it is taller than the console.

    Args:
        renderable (RenderableType): Any renderable object.
        style (StyleType, optional): An optional style to apply to the renderable. Defaults to "".
        vertical_overflow (VerticalOverflowMethod, optional): What to do when the
            render is taller than the console: "crop" keeps the top lines,
            "crop_above" keeps the bottom lines, "ellipsis" keeps the top lines
            and replaces the last visible line with "...", and "visible" does
            no cropping. Defaults to "ellipsis".
    """

    def __init__(
        self,
        renderable: RenderableType,
        style: StyleType = "",
        vertical_overflow: VerticalOverflowMethod = "ellipsis",
    ) -> None:
        self.renderable = renderable
        self.style = style
        self.vertical_overflow = vertical_overflow
        # Shape of the most recent render, as returned by Segment.get_shape;
        # None until the first render. Used to move/clear the cursor region.
        self._shape: Optional[Tuple[int, int]] = None

    def set_renderable(self, renderable: RenderableType) -> None:
        """Set a new renderable.

        Args:
            renderable (RenderableType): Any renderable object, including str.
        """
        self.renderable = renderable

    def position_cursor(self) -> Control:
        """Get control codes to move cursor to beginning of live render.

        Erases the current line and then, for each remaining line of the last
        render, moves up one line and erases it.

        Returns:
            Control: A control instance that may be printed.
        """
        if self._shape is not None:
            _, height = self._shape
            return Control(
                ControlType.CARRIAGE_RETURN,
                (ControlType.ERASE_IN_LINE, 2),
                *(
                    (
                        (ControlType.CURSOR_UP, 1),
                        (ControlType.ERASE_IN_LINE, 2),
                    )
                    * (height - 1)
                ),
            )
        return Control()

    def restore_cursor(self) -> Control:
        """Get control codes to clear the render and restore the cursor to its previous position.

        Returns:
            Control: A Control instance that may be printed.
        """
        if self._shape is not None:
            _, height = self._shape
            return Control(
                ControlType.CARRIAGE_RETURN,
                *((ControlType.CURSOR_UP, 1), (ControlType.ERASE_IN_LINE, 2)) * height,
            )
        return Control()

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        renderable = self.renderable
        style = console.get_style(self.style)
        lines = console.render_lines(renderable, options, style=style, pad=False)
        shape = Segment.get_shape(lines)

        _, height = shape
        max_height = options.size.height
        if height > max_height:
            if self.vertical_overflow == "crop":
                lines = lines[:max_height]
                shape = Segment.get_shape(lines)
            elif self.vertical_overflow == "crop_above":
                # Keep only the last `max_height` lines. Use an explicit start
                # index: `lines[-max_height:]` would keep *every* line when
                # max_height is 0 (since -0 == 0).
                lines = lines[len(lines) - max_height :]
                shape = Segment.get_shape(lines)
            elif self.vertical_overflow == "ellipsis":
                # Reserve the last visible line for the "..." marker; clamp so
                # a zero height does not turn into `lines[:-1]`.
                lines = lines[: max(max_height - 1, 0)]
                overflow_text = Text(
                    "...",
                    overflow="crop",
                    justify="center",
                    end="",
                    style="live.ellipsis",
                )
                lines.append(list(console.render(overflow_text)))
                shape = Segment.get_shape(lines)
            # "visible": no cropping, all lines are emitted.
        self._shape = shape

        new_line = Segment.line()
        for last, line in loop_last(lines):
            yield from line
            if not last:
                yield new_line
|
chatlas/_merge.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# Adapted from https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/utils/_merge.py
|
|
2
|
-
# Also tweaked to more closely match https://github.com/hadley/
|
|
2
|
+
# Also tweaked to more closely match https://github.com/hadley/ellmer/blob/main/R/utils-merge.R
|
|
3
3
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
chatlas/_ollama.py
CHANGED
chatlas/_openai.py
CHANGED
|
@@ -79,7 +79,7 @@ def ChatOpenAI(
|
|
|
79
79
|
::: {.callout-note}
|
|
80
80
|
## Python requirements
|
|
81
81
|
|
|
82
|
-
`ChatOpenAI` requires the `openai` package
|
|
82
|
+
`ChatOpenAI` requires the `openai` package: `pip install "chatlas[openai]"`.
|
|
83
83
|
:::
|
|
84
84
|
|
|
85
85
|
Examples
|
|
@@ -338,7 +338,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
338
338
|
return chunkd
|
|
339
339
|
return merge_dicts(completion, chunkd)
|
|
340
340
|
|
|
341
|
-
def stream_turn(self, completion, has_data_model
|
|
341
|
+
def stream_turn(self, completion, has_data_model) -> Turn:
|
|
342
342
|
from openai.types.chat import ChatCompletion
|
|
343
343
|
|
|
344
344
|
delta = completion["choices"][0].pop("delta") # type: ignore
|
|
@@ -346,9 +346,6 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
346
346
|
completion = ChatCompletion.construct(**completion)
|
|
347
347
|
return self._as_turn(completion, has_data_model)
|
|
348
348
|
|
|
349
|
-
async def stream_turn_async(self, completion, has_data_model, stream):
|
|
350
|
-
return self.stream_turn(completion, has_data_model, stream)
|
|
351
|
-
|
|
352
349
|
def value_turn(self, completion, has_data_model) -> Turn:
|
|
353
350
|
return self._as_turn(completion, has_data_model)
|
|
354
351
|
|
|
@@ -595,7 +592,8 @@ def ChatAzureOpenAI(
|
|
|
595
592
|
::: {.callout-note}
|
|
596
593
|
## Python requirements
|
|
597
594
|
|
|
598
|
-
`ChatAzureOpenAI` requires the `openai` package
|
|
595
|
+
`ChatAzureOpenAI` requires the `openai` package:
|
|
596
|
+
`pip install "chatlas[azure-openai]"`.
|
|
599
597
|
:::
|
|
600
598
|
|
|
601
599
|
Examples
|
chatlas/_perplexity.py
CHANGED
chatlas/_provider.py
CHANGED
|
@@ -125,15 +125,6 @@ class Provider(
|
|
|
125
125
|
self,
|
|
126
126
|
completion: ChatCompletionDictT,
|
|
127
127
|
has_data_model: bool,
|
|
128
|
-
stream: Any,
|
|
129
|
-
) -> Turn: ...
|
|
130
|
-
|
|
131
|
-
@abstractmethod
|
|
132
|
-
async def stream_turn_async(
|
|
133
|
-
self,
|
|
134
|
-
completion: ChatCompletionDictT,
|
|
135
|
-
has_data_model: bool,
|
|
136
|
-
stream: Any,
|
|
137
128
|
) -> Turn: ...
|
|
138
129
|
|
|
139
130
|
@abstractmethod
|
chatlas/_snowflake.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, overload
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from ._chat import Chat
|
|
6
|
+
from ._content import Content
|
|
7
|
+
from ._logging import log_model_default
|
|
8
|
+
from ._provider import Provider
|
|
9
|
+
from ._tools import Tool
|
|
10
|
+
from ._turn import Turn, normalize_turns
|
|
11
|
+
from ._utils import drop_none
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from snowflake.snowpark import Column
|
|
15
|
+
|
|
16
|
+
# Types inferred from the return type of the `snowflake.cortex.complete` function
|
|
17
|
+
Completion = str | Column
|
|
18
|
+
CompletionChunk = str
|
|
19
|
+
|
|
20
|
+
from .types.snowflake import SubmitInputArgs
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# The main prompt input type for Snowflake
|
|
24
|
+
# This was copy-pasted from `snowflake.cortex._complete.ConversationMessage`
|
|
25
|
+
# Functional TypedDict form: a dict with required `role` and `content` string
# keys (same fields as Snowflake's own ConversationMessage).
ConversationMessage = TypedDict(
    "ConversationMessage",
    {"role": str, "content": str},
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def ChatSnowflake(
    *,
    system_prompt: Optional[str] = None,
    model: Optional[str] = None,
    turns: Optional[list[Turn]] = None,
    connection_name: Optional[str] = None,
    account: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    private_key_file: Optional[str] = None,
    private_key_file_pwd: Optional[str] = None,
    kwargs: Optional[dict[str, "str | int"]] = None,
) -> Chat["SubmitInputArgs", "Completion"]:
    """
    Chat with a Snowflake Cortex LLM

    https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions

    Prerequisites
    -------------

    ::: {.callout-note}
    ## Python requirements

    `ChatSnowflake` requires the `snowflake-ml-python` package:
    `pip install "chatlas[snowflake]"`.
    :::

    ::: {.callout-note}
    ## Snowflake credentials

    Snowflake provides a handful of ways to authenticate, but it's recommended
    to use [key-pair
    auth](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#label-python-connection-toml)
    to generate a `private_key_file`. It's also recommended to place your
    credentials in a [`connections.toml`
    file](https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-session#connect-by-using-the-connections-toml-file).

    This way, once your credentials are in the `connections.toml` file, you can
    simply call `ChatSnowflake(connection_name="my_connection")` to
    authenticate. If you don't want to use a `connections.toml` file, you can
    specify the connection parameters directly (with `account`, `user`,
    `password`, etc.).
    :::


    Parameters
    ----------
    system_prompt
        A system prompt to set the behavior of the assistant.
    model
        The model to use for the chat. The default, None, will pick a reasonable
        default, and warn you about it. We strongly recommend explicitly
        choosing a model for all but the most casual use.
    turns
        A list of turns to start the chat with (i.e., continuing a previous
        conversation). If not provided, the conversation begins from scratch. Do
        not provide non-None values for both `turns` and `system_prompt`. Each
        element should be a `Turn` object.
    connection_name
        The name of the connection (i.e., section) within the connections.toml file.
        This is useful if you want to keep your credentials in a connections.toml file
        rather than specifying them directly in the arguments.
        https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-session#connect-by-using-the-connections-toml-file
    account
        Your Snowflake account identifier. Required if `connection_name` is not provided.
        https://docs.snowflake.com/en/user-guide/admin-account-identifier
    user
        Your Snowflake user name. Required if `connection_name` is not provided.
    password
        Your Snowflake password. Required if doing password authentication and
        `connection_name` is not provided.
    private_key_file
        The path to your private key file. Required if you are using key pair authentication.
        https://docs.snowflake.com/en/user-guide/key-pair-auth
    private_key_file_pwd
        The password for your private key file. Required if you are using key pair authentication.
        https://docs.snowflake.com/en/user-guide/key-pair-auth
    kwargs
        Additional keyword arguments passed along to the Snowflake connection builder. These can
        include any parameters supported by the `snowflake-ml-python` package.
        https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-session#connect-by-specifying-connection-parameters

    Returns
    -------
    Chat
        A `Chat` whose provider is a `SnowflakeProvider` built from the
        connection settings above.
    """

    if model is None:
        # Fall back to a default model; per the docstring the user is warned.
        model = log_model_default("llama3.1-70b")

    return Chat(
        provider=SnowflakeProvider(
            model=model,
            connection_name=connection_name,
            account=account,
            user=user,
            password=password,
            private_key_file=private_key_file,
            private_key_file_pwd=private_key_file_pwd,
            kwargs=kwargs,
        ),
        turns=normalize_turns(
            turns or [],
            system_prompt,
        ),
    )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class SnowflakeProvider(Provider["Completion", "CompletionChunk", "CompletionChunk"]):
    """
    Provider that sends chat requests to Snowflake Cortex through
    `snowflake.cortex.complete`.

    A Snowpark `Session` is created once, at construction time, and is passed
    explicitly to every `complete()` call.
    """

    def __init__(
        self,
        *,
        model: str,
        connection_name: Optional[str],
        account: Optional[str],
        user: Optional[str],
        password: Optional[str],
        private_key_file: Optional[str],
        private_key_file_pwd: Optional[str],
        kwargs: Optional[dict[str, "str | int"]],
    ):
        """
        Create the Snowpark session used by this provider.

        Parameters
        ----------
        model
            Name of the Cortex model passed to every `complete()` call.
        connection_name, account, user, password, private_key_file, private_key_file_pwd
            Connection settings; any left as `None` are omitted from the
            session configuration.
        kwargs
            Extra session configuration entries. They are merged after the
            named settings above, so they take precedence on key clashes.

        Raises
        ------
        ImportError
            If `snowflake.snowpark` cannot be imported (install
            `snowflake-ml-python`).
        """
        try:
            from snowflake.snowpark import Session
        except ImportError:
            raise ImportError(
                "`ChatSnowflake()` requires the `snowflake-ml-python` package. "
                "Please install it via `pip install snowflake-ml-python`."
            )

        # drop_none() removes unset (None) entries so only settings the caller
        # actually provided reach the session builder.
        configs: dict[str, str | int] = drop_none(
            {
                "connection_name": connection_name,
                "account": account,
                "user": user,
                "password": password,
                "private_key_file": private_key_file,
                "private_key_file_pwd": private_key_file_pwd,
                **(kwargs or {}),
            }
        )

        self._model = model
        self._session = Session.builder.configs(configs).create()

    @overload
    def chat_perform(
        self,
        *,
        stream: Literal[False],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    @overload
    def chat_perform(
        self,
        *,
        stream: Literal[True],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    def chat_perform(
        self,
        *,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ):
        """Run a (possibly streaming) completion via `snowflake.cortex.complete`."""
        from snowflake.cortex import complete

        args = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
        return complete(**args)

    @overload
    async def chat_perform_async(
        self,
        *,
        stream: Literal[False],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    @overload
    async def chat_perform_async(
        self,
        *,
        stream: Literal[True],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    async def chat_perform_async(
        self,
        *,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "Snowflake does not currently support async completions."
        )

    def _chat_perform_args(
        self,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ):
        """
        Build the keyword arguments for `snowflake.cortex.complete`.

        Raises
        ------
        ValueError
            If any tools are registered.
        NotImplementedError
            If structured output (`data_model`) is requested.
        """
        # Cortex doesn't seem to support tools
        if tools:
            raise ValueError("Snowflake does not currently support tools.")

        # TODO: implement data_model when this PR makes it into snowflake-ml-python
        # https://github.com/snowflakedb/snowflake-ml-python/pull/141
        # https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#structured-output-example
        if data_model:
            raise NotImplementedError(
                "The snowflake-ml-python package currently doesn't support structured output. "
                "Upvote this PR to help prioritize it: "
                "https://github.com/snowflakedb/snowflake-ml-python/pull/141"
            )

        kwargs_full: "SubmitInputArgs" = {
            "stream": stream,
            "prompt": self._as_prompt_input(turns),
            "model": self._model,
            # Pass the session created in __init__ explicitly. Without it,
            # complete() falls back to Snowpark's default/active session, which
            # may be a different connection (or ambiguous when several
            # sessions exist). User kwargs below may still override it.
            "session": self._session,
            **(kwargs or {}),
        }

        return kwargs_full

    def stream_text(self, chunk):
        # Streamed chunks are plain strings, so the text is the chunk itself.
        return chunk

    def stream_merge_chunks(self, completion, chunk):
        # Accumulate streamed text by string concatenation.
        if completion is None:
            return chunk
        return completion + chunk

    def stream_turn(self, completion, has_data_model) -> Turn:
        return self._as_turn(completion, has_data_model)

    def value_turn(self, completion, has_data_model) -> Turn:
        return self._as_turn(completion, has_data_model)

    def token_count(
        self,
        *args: "Content | str",
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "Snowflake does not currently support token counting."
        )

    async def token_count_async(
        self,
        *args: "Content | str",
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "Snowflake does not currently support token counting."
        )

    def _as_prompt_input(self, turns: list[Turn]) -> list["ConversationMessage"]:
        """Convert turns into Cortex `{"role", "content"}` messages."""
        # NOTE(review): only `turn.text` is sent; how non-text content is
        # represented there depends on Turn.text — verify if that matters.
        return [{"role": turn.role, "content": turn.text} for turn in turns]

    def _as_turn(self, completion, has_data_model) -> Turn:
        # The completion is the assistant's reply text.
        return Turn("assistant", completion)
|
chatlas/_utils.py
CHANGED
|
@@ -61,6 +61,13 @@ def is_async_callable(
|
|
|
61
61
|
return False
|
|
62
62
|
|
|
63
63
|
|
|
64
|
+
T = TypeVar("T")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def drop_none(x: dict[str, T | None]) -> dict[str, T]:
|
|
68
|
+
return {k: v for k, v in x.items() if v is not None}
|
|
69
|
+
|
|
70
|
+
|
|
64
71
|
# https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
|
|
65
72
|
def is_testing():
|
|
66
73
|
return os.environ.get("PYTEST_CURRENT_TEST", None) is not None
|
chatlas/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
|
|
5
|
+
|
|
6
|
+
TYPE_CHECKING = False
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from typing import Tuple
|
|
9
|
+
from typing import Union
|
|
10
|
+
|
|
11
|
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
12
|
+
else:
|
|
13
|
+
VERSION_TUPLE = object
|
|
14
|
+
|
|
15
|
+
version: str
|
|
16
|
+
__version__: str
|
|
17
|
+
__version_tuple__: VERSION_TUPLE
|
|
18
|
+
version_tuple: VERSION_TUPLE
|
|
19
|
+
|
|
20
|
+
__version__ = version = '0.5.0'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 5, 0)
|
chatlas/py.typed
ADDED
|
File without changes
|
chatlas/types/__init__.py
CHANGED
|
@@ -8,18 +8,24 @@ from typing import Iterable, Literal, Mapping, Optional, TypedDict, Union
|
|
|
8
8
|
import anthropic
|
|
9
9
|
import anthropic.types.message_param
|
|
10
10
|
import anthropic.types.text_block_param
|
|
11
|
+
import anthropic.types.thinking_config_disabled_param
|
|
12
|
+
import anthropic.types.thinking_config_enabled_param
|
|
13
|
+
import anthropic.types.tool_bash_20250124_param
|
|
11
14
|
import anthropic.types.tool_choice_any_param
|
|
12
15
|
import anthropic.types.tool_choice_auto_param
|
|
16
|
+
import anthropic.types.tool_choice_none_param
|
|
13
17
|
import anthropic.types.tool_choice_tool_param
|
|
14
18
|
import anthropic.types.tool_param
|
|
19
|
+
import anthropic.types.tool_text_editor_20250124_param
|
|
15
20
|
|
|
16
21
|
|
|
17
22
|
class SubmitInputArgs(TypedDict, total=False):
|
|
18
23
|
max_tokens: int
|
|
19
24
|
messages: Iterable[anthropic.types.message_param.MessageParam]
|
|
20
25
|
model: Union[
|
|
21
|
-
str,
|
|
22
26
|
Literal[
|
|
27
|
+
"claude-3-7-sonnet-latest",
|
|
28
|
+
"claude-3-7-sonnet-20250219",
|
|
23
29
|
"claude-3-5-haiku-latest",
|
|
24
30
|
"claude-3-5-haiku-20241022",
|
|
25
31
|
"claude-3-5-sonnet-latest",
|
|
@@ -32,6 +38,7 @@ class SubmitInputArgs(TypedDict, total=False):
|
|
|
32
38
|
"claude-2.1",
|
|
33
39
|
"claude-2.0",
|
|
34
40
|
],
|
|
41
|
+
str,
|
|
35
42
|
]
|
|
36
43
|
stop_sequences: Union[list[str], anthropic.NotGiven]
|
|
37
44
|
stream: Union[Literal[False], Literal[True], anthropic.NotGiven]
|
|
@@ -41,13 +48,28 @@ class SubmitInputArgs(TypedDict, total=False):
|
|
|
41
48
|
anthropic.NotGiven,
|
|
42
49
|
]
|
|
43
50
|
temperature: float | anthropic.NotGiven
|
|
51
|
+
thinking: Union[
|
|
52
|
+
anthropic.types.thinking_config_enabled_param.ThinkingConfigEnabledParam,
|
|
53
|
+
anthropic.types.thinking_config_disabled_param.ThinkingConfigDisabledParam,
|
|
54
|
+
anthropic.NotGiven,
|
|
55
|
+
]
|
|
44
56
|
tool_choice: Union[
|
|
45
57
|
anthropic.types.tool_choice_auto_param.ToolChoiceAutoParam,
|
|
46
58
|
anthropic.types.tool_choice_any_param.ToolChoiceAnyParam,
|
|
47
59
|
anthropic.types.tool_choice_tool_param.ToolChoiceToolParam,
|
|
60
|
+
anthropic.types.tool_choice_none_param.ToolChoiceNoneParam,
|
|
61
|
+
anthropic.NotGiven,
|
|
62
|
+
]
|
|
63
|
+
tools: Union[
|
|
64
|
+
Iterable[
|
|
65
|
+
Union[
|
|
66
|
+
anthropic.types.tool_param.ToolParam,
|
|
67
|
+
anthropic.types.tool_bash_20250124_param.ToolBash20250124Param,
|
|
68
|
+
anthropic.types.tool_text_editor_20250124_param.ToolTextEditor20250124Param,
|
|
69
|
+
]
|
|
70
|
+
],
|
|
48
71
|
anthropic.NotGiven,
|
|
49
72
|
]
|
|
50
|
-
tools: Union[Iterable[anthropic.types.tool_param.ToolParam], anthropic.NotGiven]
|
|
51
73
|
top_k: int | anthropic.NotGiven
|
|
52
74
|
top_p: float | anthropic.NotGiven
|
|
53
75
|
extra_headers: Optional[Mapping[str, Union[str, anthropic.Omit]]]
|