chatlas 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chatlas might be problematic. Click here for more details.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +79 -45
- chatlas/_auto.py +3 -12
- chatlas/_chat.py +774 -148
- chatlas/_content.py +149 -29
- chatlas/_databricks.py +4 -14
- chatlas/_github.py +21 -25
- chatlas/_google.py +71 -32
- chatlas/_groq.py +15 -18
- chatlas/_interpolate.py +3 -4
- chatlas/_mcp_manager.py +306 -0
- chatlas/_ollama.py +14 -18
- chatlas/_openai.py +74 -39
- chatlas/_perplexity.py +14 -18
- chatlas/_provider.py +78 -8
- chatlas/_snowflake.py +29 -18
- chatlas/_tokens.py +85 -5
- chatlas/_tools.py +181 -22
- chatlas/_turn.py +2 -18
- chatlas/_utils.py +27 -1
- chatlas/_version.py +2 -2
- chatlas/data/prices.json +264 -0
- chatlas/types/anthropic/_submit.py +2 -0
- chatlas/types/openai/_client.py +1 -0
- chatlas/types/openai/_client_azure.py +1 -0
- chatlas/types/openai/_submit.py +4 -1
- chatlas-0.9.1.dist-info/METADATA +141 -0
- chatlas-0.9.1.dist-info/RECORD +48 -0
- chatlas-0.8.1.dist-info/METADATA +0 -383
- chatlas-0.8.1.dist-info/RECORD +0 -46
- {chatlas-0.8.1.dist-info → chatlas-0.9.1.dist-info}/WHEEL +0 -0
- {chatlas-0.8.1.dist-info → chatlas-0.9.1.dist-info}/licenses/LICENSE +0 -0
chatlas/_interpolate.py
CHANGED
|
@@ -23,7 +23,7 @@ def interpolate(
|
|
|
23
23
|
This is a light-weight wrapper around the Jinja2 templating engine, making
|
|
24
24
|
it easier to interpolate dynamic data into a prompt template. Compared to
|
|
25
25
|
f-strings, which expects you to wrap dynamic values in `{ }`, this function
|
|
26
|
-
expects `{{ }}` instead, making it easier to include Python code and JSON in
|
|
26
|
+
expects `{{{ }}}` instead, making it easier to include Python code and JSON in
|
|
27
27
|
your prompt.
|
|
28
28
|
|
|
29
29
|
Parameters
|
|
@@ -80,7 +80,7 @@ def interpolate_file(
|
|
|
80
80
|
This is a light-weight wrapper around the Jinja2 templating engine, making
|
|
81
81
|
it easier to interpolate dynamic data into a static prompt. Compared to
|
|
82
82
|
f-strings, which expects you to wrap dynamic values in `{ }`, this function
|
|
83
|
-
expects `{{ }}` instead, making it easier to include Python code and JSON in
|
|
83
|
+
expects `{{{ }}}` instead, making it easier to include Python code and JSON in
|
|
84
84
|
your prompt.
|
|
85
85
|
|
|
86
86
|
Parameters
|
|
@@ -102,8 +102,7 @@ def interpolate_file(
|
|
|
102
102
|
|
|
103
103
|
See Also
|
|
104
104
|
--------
|
|
105
|
-
interpolate
|
|
106
|
-
Interpolating data into a system prompt
|
|
105
|
+
* :func:`~chatlas.interpolate` : Interpolating data into a prompt
|
|
107
106
|
"""
|
|
108
107
|
if variables is None:
|
|
109
108
|
frame = inspect.currentframe()
|
chatlas/_mcp_manager.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import warnings
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from contextlib import AsyncExitStack
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
|
9
|
+
|
|
10
|
+
from ._tools import Tool
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from mcp import ClientSession
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class SessionInfo(ABC):
|
|
18
|
+
# Input parameters
|
|
19
|
+
name: str
|
|
20
|
+
include_tools: Sequence[str] = field(default_factory=list)
|
|
21
|
+
exclude_tools: Sequence[str] = field(default_factory=list)
|
|
22
|
+
namespace: str | None = None
|
|
23
|
+
|
|
24
|
+
# Primary derived attributes
|
|
25
|
+
session: ClientSession | None = None
|
|
26
|
+
tools: dict[str, Tool] = field(default_factory=dict)
|
|
27
|
+
|
|
28
|
+
# Background task management
|
|
29
|
+
ready_event: asyncio.Event = field(default_factory=asyncio.Event)
|
|
30
|
+
shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
|
|
31
|
+
exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack)
|
|
32
|
+
task: asyncio.Task | None = None
|
|
33
|
+
error: asyncio.CancelledError | Exception | None = None
|
|
34
|
+
|
|
35
|
+
@abstractmethod
|
|
36
|
+
async def open_session(self) -> None: ...
|
|
37
|
+
|
|
38
|
+
async def close_session(self) -> None:
|
|
39
|
+
await self.exit_stack.aclose()
|
|
40
|
+
|
|
41
|
+
async def request_tools(self) -> None:
|
|
42
|
+
if self.session is None:
|
|
43
|
+
raise ValueError("Session must be opened before requesting tools.")
|
|
44
|
+
|
|
45
|
+
if self.include_tools and self.exclude_tools:
|
|
46
|
+
raise ValueError("Cannot specify both include_tools and exclude_tools.")
|
|
47
|
+
|
|
48
|
+
# Request the MCP tools available
|
|
49
|
+
response = await self.session.list_tools()
|
|
50
|
+
tool_names = set(x.name for x in response.tools)
|
|
51
|
+
|
|
52
|
+
# Warn if tools are mis-specified
|
|
53
|
+
include = set(self.include_tools or [])
|
|
54
|
+
missing_include = include.difference(tool_names)
|
|
55
|
+
if missing_include:
|
|
56
|
+
warnings.warn(
|
|
57
|
+
f"Specified include_tools {missing_include} did not match any tools from the MCP server. "
|
|
58
|
+
f"The tools available are: {tool_names}",
|
|
59
|
+
stacklevel=2,
|
|
60
|
+
)
|
|
61
|
+
exclude = set(self.exclude_tools or [])
|
|
62
|
+
missing_exclude = exclude.difference(tool_names)
|
|
63
|
+
if missing_exclude:
|
|
64
|
+
warnings.warn(
|
|
65
|
+
f"Specified exclude_tools {missing_exclude} did not match any tools from the MCP server. "
|
|
66
|
+
f"The tools available are: {tool_names}",
|
|
67
|
+
stacklevel=2,
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# Filter the tool names
|
|
71
|
+
if include:
|
|
72
|
+
tool_names = include.intersection(tool_names)
|
|
73
|
+
if exclude:
|
|
74
|
+
tool_names = tool_names.difference(exclude)
|
|
75
|
+
|
|
76
|
+
# Apply namespace and convert to chatlas.Tool instances
|
|
77
|
+
self_tools: dict[str, Tool] = {}
|
|
78
|
+
for tool in response.tools:
|
|
79
|
+
if tool.name not in tool_names:
|
|
80
|
+
continue
|
|
81
|
+
if self.namespace:
|
|
82
|
+
tool.name = f"{self.namespace}.{tool.name}"
|
|
83
|
+
self_tools[tool.name] = Tool.from_mcp(
|
|
84
|
+
session=self.session,
|
|
85
|
+
mcp_tool=tool,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
# Store the tools
|
|
89
|
+
self.tools = self_tools
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
|
|
93
|
+
class HTTPSessionInfo(SessionInfo):
|
|
94
|
+
url: str = ""
|
|
95
|
+
transport_kwargs: dict[str, Any] = field(default_factory=dict)
|
|
96
|
+
|
|
97
|
+
async def open_session(self):
|
|
98
|
+
mcp = try_import_mcp()
|
|
99
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
100
|
+
|
|
101
|
+
read, write, _ = await self.exit_stack.enter_async_context(
|
|
102
|
+
streamablehttp_client(
|
|
103
|
+
self.url,
|
|
104
|
+
**self.transport_kwargs,
|
|
105
|
+
)
|
|
106
|
+
)
|
|
107
|
+
session = await self.exit_stack.enter_async_context(
|
|
108
|
+
mcp.ClientSession(read, write)
|
|
109
|
+
)
|
|
110
|
+
server = await session.initialize()
|
|
111
|
+
self.session = session
|
|
112
|
+
if not self.name:
|
|
113
|
+
self.name = server.serverInfo.name or "mcp"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass
|
|
117
|
+
class STDIOSessionInfo(SessionInfo):
|
|
118
|
+
command: str = ""
|
|
119
|
+
args: list[str] = field(default_factory=list)
|
|
120
|
+
transport_kwargs: dict[str, Any] = field(default_factory=dict)
|
|
121
|
+
|
|
122
|
+
async def open_session(self):
|
|
123
|
+
mcp = try_import_mcp()
|
|
124
|
+
from mcp.client.stdio import stdio_client
|
|
125
|
+
|
|
126
|
+
server_params = mcp.StdioServerParameters(
|
|
127
|
+
command=self.command,
|
|
128
|
+
args=self.args,
|
|
129
|
+
**self.transport_kwargs,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
transport = await self.exit_stack.enter_async_context(
|
|
133
|
+
stdio_client(server_params)
|
|
134
|
+
)
|
|
135
|
+
session = await self.exit_stack.enter_async_context(
|
|
136
|
+
mcp.ClientSession(*transport)
|
|
137
|
+
)
|
|
138
|
+
server = await session.initialize()
|
|
139
|
+
self.session = session
|
|
140
|
+
if not self.name:
|
|
141
|
+
self.name = server.serverInfo.name or "mcp"
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class MCPSessionManager:
|
|
145
|
+
"""Manages MCP (Model Context Protocol) server connections and tools."""
|
|
146
|
+
|
|
147
|
+
def __init__(self):
|
|
148
|
+
self._mcp_sessions: dict[str, SessionInfo] = {}
|
|
149
|
+
|
|
150
|
+
async def register_http_stream_tools(
|
|
151
|
+
self,
|
|
152
|
+
*,
|
|
153
|
+
url: str,
|
|
154
|
+
name: str | None,
|
|
155
|
+
include_tools: Sequence[str],
|
|
156
|
+
exclude_tools: Sequence[str],
|
|
157
|
+
namespace: str | None,
|
|
158
|
+
transport_kwargs: dict[str, Any],
|
|
159
|
+
):
|
|
160
|
+
session_info = HTTPSessionInfo(
|
|
161
|
+
name=name or "",
|
|
162
|
+
url=url,
|
|
163
|
+
include_tools=include_tools,
|
|
164
|
+
exclude_tools=exclude_tools,
|
|
165
|
+
namespace=namespace,
|
|
166
|
+
transport_kwargs=transport_kwargs or {},
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
# Launch background task that runs until MCP session is *shutdown*
|
|
170
|
+
# N.B. this is needed since mcp sessions must be opened and closed in the same task
|
|
171
|
+
asyncio.create_task(self.open_session(session_info))
|
|
172
|
+
|
|
173
|
+
# Wait for a ready event from the task (signals that tools are registered)
|
|
174
|
+
await session_info.ready_event.wait()
|
|
175
|
+
|
|
176
|
+
# An error might have been caught in the background task
|
|
177
|
+
if session_info.error:
|
|
178
|
+
raise RuntimeError(
|
|
179
|
+
f"Failed to register tools from MCP server '{name}' at URL '{url}'"
|
|
180
|
+
) from session_info.error
|
|
181
|
+
|
|
182
|
+
return session_info
|
|
183
|
+
|
|
184
|
+
async def register_stdio_tools(
|
|
185
|
+
self,
|
|
186
|
+
*,
|
|
187
|
+
command: str,
|
|
188
|
+
args: list[str],
|
|
189
|
+
name: str | None,
|
|
190
|
+
include_tools: Sequence[str],
|
|
191
|
+
exclude_tools: Sequence[str],
|
|
192
|
+
namespace: str | None,
|
|
193
|
+
transport_kwargs: dict[str, Any],
|
|
194
|
+
):
|
|
195
|
+
session_info = STDIOSessionInfo(
|
|
196
|
+
name=name or "",
|
|
197
|
+
command=command,
|
|
198
|
+
args=args,
|
|
199
|
+
include_tools=include_tools,
|
|
200
|
+
exclude_tools=exclude_tools,
|
|
201
|
+
namespace=namespace,
|
|
202
|
+
transport_kwargs=transport_kwargs or {},
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
# Launch a background task to initialize the MCP server
|
|
206
|
+
# N.B. this is needed since mcp sessions must be opened and closed in the same task
|
|
207
|
+
asyncio.create_task(self.open_session(session_info))
|
|
208
|
+
|
|
209
|
+
# Wait for a ready event from the task (signals that tools are registered)
|
|
210
|
+
await session_info.ready_event.wait()
|
|
211
|
+
|
|
212
|
+
# An error might have been caught in the background task
|
|
213
|
+
if session_info.error:
|
|
214
|
+
raise RuntimeError(
|
|
215
|
+
f"Failed to register tools from MCP server '{name}' with command '{command} {args}'"
|
|
216
|
+
) from session_info.error
|
|
217
|
+
|
|
218
|
+
return session_info
|
|
219
|
+
|
|
220
|
+
async def open_session(self, session_info: "SessionInfo"):
|
|
221
|
+
session_info.task = asyncio.current_task()
|
|
222
|
+
|
|
223
|
+
try:
|
|
224
|
+
# Open the MCP session
|
|
225
|
+
await session_info.open_session()
|
|
226
|
+
# Request the tools
|
|
227
|
+
await session_info.request_tools()
|
|
228
|
+
# Make sure session can be added to the manager
|
|
229
|
+
self.add_session(session_info)
|
|
230
|
+
except (asyncio.CancelledError, Exception) as err:
|
|
231
|
+
# Keep the error so we can handle in the main task
|
|
232
|
+
session_info.error = err
|
|
233
|
+
# Make sure the session is closed
|
|
234
|
+
try:
|
|
235
|
+
await session_info.close_session()
|
|
236
|
+
except Exception:
|
|
237
|
+
pass
|
|
238
|
+
return
|
|
239
|
+
finally:
|
|
240
|
+
# Whether successful or not, set ready state to prevent deadlock
|
|
241
|
+
session_info.ready_event.set()
|
|
242
|
+
|
|
243
|
+
# If successful, wait for shutdown signal
|
|
244
|
+
await session_info.shutdown_event.wait()
|
|
245
|
+
|
|
246
|
+
# On shutdown close connection to MCP server
|
|
247
|
+
# This is why we're using a background task in the 1st place...
|
|
248
|
+
# we must close in the same task that opened the session
|
|
249
|
+
await session_info.close_session()
|
|
250
|
+
|
|
251
|
+
async def close_sessions(self, names: Optional[Sequence[str]] = None):
|
|
252
|
+
if names is None:
|
|
253
|
+
names = list(self._mcp_sessions.keys())
|
|
254
|
+
|
|
255
|
+
if isinstance(names, str):
|
|
256
|
+
names = [names]
|
|
257
|
+
|
|
258
|
+
closed_sessions: list[SessionInfo] = []
|
|
259
|
+
for x in names:
|
|
260
|
+
session = await self.close_background_session(x)
|
|
261
|
+
if session is None:
|
|
262
|
+
continue
|
|
263
|
+
closed_sessions.append(session)
|
|
264
|
+
|
|
265
|
+
return closed_sessions
|
|
266
|
+
|
|
267
|
+
async def close_background_session(self, name: str) -> SessionInfo | None:
|
|
268
|
+
session = self.remove_session(name)
|
|
269
|
+
if session is None:
|
|
270
|
+
return None
|
|
271
|
+
|
|
272
|
+
# Signal shutdown and wait for the task to finish
|
|
273
|
+
session.shutdown_event.set()
|
|
274
|
+
if session.task is not None:
|
|
275
|
+
await session.task
|
|
276
|
+
|
|
277
|
+
return session
|
|
278
|
+
|
|
279
|
+
def add_session(self, session_info: SessionInfo) -> None:
|
|
280
|
+
name = session_info.name
|
|
281
|
+
if name in self._mcp_sessions:
|
|
282
|
+
raise ValueError(f"Already connected to an MCP server named: '{name}'.")
|
|
283
|
+
self._mcp_sessions[name] = session_info
|
|
284
|
+
|
|
285
|
+
def remove_session(self, name: str) -> SessionInfo | None:
|
|
286
|
+
if name not in self._mcp_sessions:
|
|
287
|
+
warnings.warn(
|
|
288
|
+
f"Cannot close MCP session named '{name}' since it was not found.",
|
|
289
|
+
stacklevel=2,
|
|
290
|
+
)
|
|
291
|
+
return None
|
|
292
|
+
session = self._mcp_sessions[name]
|
|
293
|
+
del self._mcp_sessions[name]
|
|
294
|
+
return session
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def try_import_mcp():
|
|
298
|
+
try:
|
|
299
|
+
import mcp
|
|
300
|
+
|
|
301
|
+
return mcp
|
|
302
|
+
except ImportError:
|
|
303
|
+
raise ImportError(
|
|
304
|
+
"The `mcp` package is required to connect to MCP servers. "
|
|
305
|
+
"Install it with `pip install mcp`."
|
|
306
|
+
)
|
chatlas/_ollama.py
CHANGED
|
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Optional
|
|
|
7
7
|
import orjson
|
|
8
8
|
|
|
9
9
|
from ._chat import Chat
|
|
10
|
-
from ._openai import ChatOpenAI
|
|
11
|
-
from ._turn import Turn
|
|
10
|
+
from ._openai import OpenAIProvider
|
|
11
|
+
from ._utils import MISSING_TYPE, is_testing
|
|
12
12
|
|
|
13
13
|
if TYPE_CHECKING:
|
|
14
14
|
from ._openai import ChatCompletion
|
|
@@ -19,7 +19,6 @@ def ChatOllama(
|
|
|
19
19
|
model: Optional[str] = None,
|
|
20
20
|
*,
|
|
21
21
|
system_prompt: Optional[str] = None,
|
|
22
|
-
turns: Optional[list[Turn]] = None,
|
|
23
22
|
base_url: str = "http://localhost:11434",
|
|
24
23
|
seed: Optional[int] = None,
|
|
25
24
|
kwargs: Optional["ChatClientArgs"] = None,
|
|
@@ -67,13 +66,6 @@ def ChatOllama(
|
|
|
67
66
|
models will be printed.
|
|
68
67
|
system_prompt
|
|
69
68
|
A system prompt to set the behavior of the assistant.
|
|
70
|
-
turns
|
|
71
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
72
|
-
conversation). If not provided, the conversation begins from scratch. Do
|
|
73
|
-
not provide non-`None` values for both `turns` and `system_prompt`. Each
|
|
74
|
-
message in the list should be a dictionary with at least `role` (usually
|
|
75
|
-
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
|
|
76
|
-
there is also a `content` field, which is a string.
|
|
77
69
|
base_url
|
|
78
70
|
The base URL to the endpoint; the default uses ollama's API.
|
|
79
71
|
seed
|
|
@@ -102,15 +94,19 @@ def ChatOllama(
|
|
|
102
94
|
raise ValueError(
|
|
103
95
|
f"Must specify model. Locally installed models: {', '.join(models)}"
|
|
104
96
|
)
|
|
105
|
-
|
|
106
|
-
|
|
97
|
+
if isinstance(seed, MISSING_TYPE):
|
|
98
|
+
seed = 1014 if is_testing() else None
|
|
99
|
+
|
|
100
|
+
return Chat(
|
|
101
|
+
provider=OpenAIProvider(
|
|
102
|
+
api_key="ollama", # ignored
|
|
103
|
+
model=model,
|
|
104
|
+
base_url=f"{base_url}/v1",
|
|
105
|
+
seed=seed,
|
|
106
|
+
name="Ollama",
|
|
107
|
+
kwargs=kwargs,
|
|
108
|
+
),
|
|
107
109
|
system_prompt=system_prompt,
|
|
108
|
-
api_key="ollama", # ignored
|
|
109
|
-
turns=turns,
|
|
110
|
-
base_url=f"{base_url}/v1",
|
|
111
|
-
model=model,
|
|
112
|
-
seed=seed,
|
|
113
|
-
kwargs=kwargs,
|
|
114
110
|
)
|
|
115
111
|
|
|
116
112
|
|
chatlas/_openai.py
CHANGED
|
@@ -4,6 +4,7 @@ import base64
|
|
|
4
4
|
from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
|
|
5
5
|
|
|
6
6
|
import orjson
|
|
7
|
+
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
|
7
8
|
from pydantic import BaseModel
|
|
8
9
|
|
|
9
10
|
from ._chat import Chat
|
|
@@ -17,14 +18,16 @@ from ._content import (
|
|
|
17
18
|
ContentText,
|
|
18
19
|
ContentToolRequest,
|
|
19
20
|
ContentToolResult,
|
|
21
|
+
ContentToolResultImage,
|
|
22
|
+
ContentToolResultResource,
|
|
20
23
|
)
|
|
21
24
|
from ._logging import log_model_default
|
|
22
25
|
from ._merge import merge_dicts
|
|
23
|
-
from ._provider import Provider
|
|
26
|
+
from ._provider import Provider, StandardModelParamNames, StandardModelParams
|
|
24
27
|
from ._tokens import tokens_log
|
|
25
28
|
from ._tools import Tool, basemodel_to_param_schema
|
|
26
|
-
from ._turn import Turn,
|
|
27
|
-
from ._utils import MISSING, MISSING_TYPE, is_testing
|
|
29
|
+
from ._turn import Turn, user_turn
|
|
30
|
+
from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs
|
|
28
31
|
|
|
29
32
|
if TYPE_CHECKING:
|
|
30
33
|
from openai.types.chat import (
|
|
@@ -53,7 +56,6 @@ ChatCompletionDict = dict[str, Any]
|
|
|
53
56
|
def ChatOpenAI(
|
|
54
57
|
*,
|
|
55
58
|
system_prompt: Optional[str] = None,
|
|
56
|
-
turns: Optional[list[Turn]] = None,
|
|
57
59
|
model: "Optional[ChatModel | str]" = None,
|
|
58
60
|
api_key: Optional[str] = None,
|
|
59
61
|
base_url: str = "https://api.openai.com/v1",
|
|
@@ -92,13 +94,6 @@ def ChatOpenAI(
|
|
|
92
94
|
----------
|
|
93
95
|
system_prompt
|
|
94
96
|
A system prompt to set the behavior of the assistant.
|
|
95
|
-
turns
|
|
96
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
97
|
-
conversation). If not provided, the conversation begins from scratch. Do
|
|
98
|
-
not provide non-`None` values for both `turns` and `system_prompt`. Each
|
|
99
|
-
message in the list should be a dictionary with at least `role` (usually
|
|
100
|
-
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
|
|
101
|
-
there is also a `content` field, which is a string.
|
|
102
97
|
model
|
|
103
98
|
The model to use for the chat. The default, None, will pick a reasonable
|
|
104
99
|
default, and warn you about it. We strongly recommend explicitly
|
|
@@ -161,7 +156,7 @@ def ChatOpenAI(
|
|
|
161
156
|
seed = 1014 if is_testing() else None
|
|
162
157
|
|
|
163
158
|
if model is None:
|
|
164
|
-
model = log_model_default("gpt-
|
|
159
|
+
model = log_model_default("gpt-4.1")
|
|
165
160
|
|
|
166
161
|
return Chat(
|
|
167
162
|
provider=OpenAIProvider(
|
|
@@ -171,14 +166,13 @@ def ChatOpenAI(
|
|
|
171
166
|
seed=seed,
|
|
172
167
|
kwargs=kwargs,
|
|
173
168
|
),
|
|
174
|
-
|
|
175
|
-
turns or [],
|
|
176
|
-
system_prompt,
|
|
177
|
-
),
|
|
169
|
+
system_prompt=system_prompt,
|
|
178
170
|
)
|
|
179
171
|
|
|
180
172
|
|
|
181
|
-
class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict]):
|
|
173
|
+
class OpenAIProvider(
|
|
174
|
+
Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict, "SubmitInputArgs"]
|
|
175
|
+
):
|
|
182
176
|
def __init__(
|
|
183
177
|
self,
|
|
184
178
|
*,
|
|
@@ -186,11 +180,11 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
186
180
|
model: str,
|
|
187
181
|
base_url: str = "https://api.openai.com/v1",
|
|
188
182
|
seed: Optional[int] = None,
|
|
183
|
+
name: str = "OpenAI",
|
|
189
184
|
kwargs: Optional["ChatClientArgs"] = None,
|
|
190
185
|
):
|
|
191
|
-
|
|
186
|
+
super().__init__(name=name, model=model)
|
|
192
187
|
|
|
193
|
-
self._model = model
|
|
194
188
|
self._seed = seed
|
|
195
189
|
|
|
196
190
|
kwargs_full: "ChatClientArgs" = {
|
|
@@ -199,9 +193,12 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
199
193
|
**(kwargs or {}),
|
|
200
194
|
}
|
|
201
195
|
|
|
196
|
+
# Avoid passing the wrong sync/async client to the OpenAI constructor.
|
|
197
|
+
sync_kwargs, async_kwargs = split_http_client_kwargs(kwargs_full)
|
|
198
|
+
|
|
202
199
|
# TODO: worth bringing in AsyncOpenAI types?
|
|
203
|
-
self._client = OpenAI(**kwargs_full) # type: ignore
|
|
204
|
-
self._async_client = AsyncOpenAI(**kwargs_full)
|
|
200
|
+
self._client = OpenAI(**sync_kwargs) # type: ignore
|
|
201
|
+
self._async_client = AsyncOpenAI(**async_kwargs)
|
|
205
202
|
|
|
206
203
|
@overload
|
|
207
204
|
def chat_perform(
|
|
@@ -284,7 +281,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
284
281
|
kwargs_full: "SubmitInputArgs" = {
|
|
285
282
|
"stream": stream,
|
|
286
283
|
"messages": self._as_message_param(turns),
|
|
287
|
-
"model": self._model,
|
|
284
|
+
"model": self.model,
|
|
288
285
|
**(kwargs or {}),
|
|
289
286
|
}
|
|
290
287
|
|
|
@@ -487,6 +484,12 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
487
484
|
}
|
|
488
485
|
)
|
|
489
486
|
elif isinstance(x, ContentToolResult):
|
|
487
|
+
if isinstance(
|
|
488
|
+
x, (ContentToolResultImage, ContentToolResultResource)
|
|
489
|
+
):
|
|
490
|
+
raise NotImplementedError(
|
|
491
|
+
"OpenAI does not support tool results with images or resources."
|
|
492
|
+
)
|
|
490
493
|
tool_results.append(
|
|
491
494
|
ChatCompletionToolMessageParam(
|
|
492
495
|
# Currently, OpenAI only allows for text content in tool results
|
|
@@ -573,6 +576,46 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
|
|
|
573
576
|
completion=completion,
|
|
574
577
|
)
|
|
575
578
|
|
|
579
|
+
def translate_model_params(self, params: StandardModelParams) -> "SubmitInputArgs":
|
|
580
|
+
res: "SubmitInputArgs" = {}
|
|
581
|
+
if "temperature" in params:
|
|
582
|
+
res["temperature"] = params["temperature"]
|
|
583
|
+
|
|
584
|
+
if "top_p" in params:
|
|
585
|
+
res["top_p"] = params["top_p"]
|
|
586
|
+
|
|
587
|
+
if "frequency_penalty" in params:
|
|
588
|
+
res["frequency_penalty"] = params["frequency_penalty"]
|
|
589
|
+
|
|
590
|
+
if "presence_penalty" in params:
|
|
591
|
+
res["presence_penalty"] = params["presence_penalty"]
|
|
592
|
+
|
|
593
|
+
if "seed" in params:
|
|
594
|
+
res["seed"] = params["seed"]
|
|
595
|
+
|
|
596
|
+
if "max_tokens" in params:
|
|
597
|
+
res["max_tokens"] = params["max_tokens"]
|
|
598
|
+
|
|
599
|
+
if "log_probs" in params:
|
|
600
|
+
res["logprobs"] = params["log_probs"]
|
|
601
|
+
|
|
602
|
+
if "stop_sequences" in params:
|
|
603
|
+
res["stop"] = params["stop_sequences"]
|
|
604
|
+
|
|
605
|
+
return res
|
|
606
|
+
|
|
607
|
+
def supported_model_params(self) -> set[StandardModelParamNames]:
|
|
608
|
+
return {
|
|
609
|
+
"temperature",
|
|
610
|
+
"top_p",
|
|
611
|
+
"frequency_penalty",
|
|
612
|
+
"presence_penalty",
|
|
613
|
+
"seed",
|
|
614
|
+
"max_tokens",
|
|
615
|
+
"log_probs",
|
|
616
|
+
"stop_sequences",
|
|
617
|
+
}
|
|
618
|
+
|
|
576
619
|
|
|
577
620
|
def ChatAzureOpenAI(
|
|
578
621
|
*,
|
|
@@ -581,7 +624,6 @@ def ChatAzureOpenAI(
|
|
|
581
624
|
api_version: str,
|
|
582
625
|
api_key: Optional[str] = None,
|
|
583
626
|
system_prompt: Optional[str] = None,
|
|
584
|
-
turns: Optional[list[Turn]] = None,
|
|
585
627
|
seed: int | None | MISSING_TYPE = MISSING,
|
|
586
628
|
kwargs: Optional["ChatAzureClientArgs"] = None,
|
|
587
629
|
) -> Chat["SubmitInputArgs", ChatCompletion]:
|
|
@@ -624,13 +666,6 @@ def ChatAzureOpenAI(
|
|
|
624
666
|
variable.
|
|
625
667
|
system_prompt
|
|
626
668
|
A system prompt to set the behavior of the assistant.
|
|
627
|
-
turns
|
|
628
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
629
|
-
conversation). If not provided, the conversation begins from scratch.
|
|
630
|
-
Do not provide non-None values for both `turns` and `system_prompt`.
|
|
631
|
-
Each message in the list should be a dictionary with at least `role`
|
|
632
|
-
(usually `system`, `user`, or `assistant`, but `tool` is also possible).
|
|
633
|
-
Normally there is also a `content` field, which is a string.
|
|
634
669
|
seed
|
|
635
670
|
Optional integer seed that ChatGPT uses to try and make output more
|
|
636
671
|
reproducible.
|
|
@@ -655,10 +690,7 @@ def ChatAzureOpenAI(
|
|
|
655
690
|
seed=seed,
|
|
656
691
|
kwargs=kwargs,
|
|
657
692
|
),
|
|
658
|
-
|
|
659
|
-
turns or [],
|
|
660
|
-
system_prompt,
|
|
661
|
-
),
|
|
693
|
+
system_prompt=system_prompt,
|
|
662
694
|
)
|
|
663
695
|
|
|
664
696
|
|
|
@@ -667,15 +699,16 @@ class OpenAIAzureProvider(OpenAIProvider):
|
|
|
667
699
|
self,
|
|
668
700
|
*,
|
|
669
701
|
endpoint: Optional[str] = None,
|
|
670
|
-
deployment_id:
|
|
702
|
+
deployment_id: str,
|
|
671
703
|
api_version: Optional[str] = None,
|
|
672
704
|
api_key: Optional[str] = None,
|
|
673
705
|
seed: int | None = None,
|
|
706
|
+
name: str = "OpenAIAzure",
|
|
707
|
+
model: Optional[str] = "UnusedValue",
|
|
674
708
|
kwargs: Optional["ChatAzureClientArgs"] = None,
|
|
675
709
|
):
|
|
676
|
-
|
|
710
|
+
super().__init__(name=name, model=deployment_id)
|
|
677
711
|
|
|
678
|
-
self._model = deployment_id
|
|
679
712
|
self._seed = seed
|
|
680
713
|
|
|
681
714
|
kwargs_full: "ChatAzureClientArgs" = {
|
|
@@ -686,8 +719,10 @@ class OpenAIAzureProvider(OpenAIProvider):
|
|
|
686
719
|
**(kwargs or {}),
|
|
687
720
|
}
|
|
688
721
|
|
|
689
|
-
|
|
690
|
-
|
|
722
|
+
sync_kwargs, async_kwargs = split_http_client_kwargs(kwargs_full)
|
|
723
|
+
|
|
724
|
+
self._client = AzureOpenAI(**sync_kwargs) # type: ignore
|
|
725
|
+
self._async_client = AsyncAzureOpenAI(**async_kwargs) # type: ignore
|
|
691
726
|
|
|
692
727
|
|
|
693
728
|
class InvalidJSONParameterWarning(RuntimeWarning):
|
chatlas/_perplexity.py
CHANGED
|
@@ -5,9 +5,8 @@ from typing import TYPE_CHECKING, Optional
|
|
|
5
5
|
|
|
6
6
|
from ._chat import Chat
|
|
7
7
|
from ._logging import log_model_default
|
|
8
|
-
from ._openai import ChatOpenAI
|
|
9
|
-
from ._turn import Turn
|
|
10
|
-
from ._utils import MISSING, MISSING_TYPE
|
|
8
|
+
from ._openai import OpenAIProvider
|
|
9
|
+
from ._utils import MISSING, MISSING_TYPE, is_testing
|
|
11
10
|
|
|
12
11
|
if TYPE_CHECKING:
|
|
13
12
|
from ._openai import ChatCompletion
|
|
@@ -17,7 +16,6 @@ if TYPE_CHECKING:
|
|
|
17
16
|
def ChatPerplexity(
|
|
18
17
|
*,
|
|
19
18
|
system_prompt: Optional[str] = None,
|
|
20
|
-
turns: Optional[list[Turn]] = None,
|
|
21
19
|
model: Optional[str] = None,
|
|
22
20
|
api_key: Optional[str] = None,
|
|
23
21
|
base_url: str = "https://api.perplexity.ai/",
|
|
@@ -56,13 +54,6 @@ def ChatPerplexity(
|
|
|
56
54
|
----------
|
|
57
55
|
system_prompt
|
|
58
56
|
A system prompt to set the behavior of the assistant.
|
|
59
|
-
turns
|
|
60
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
61
|
-
conversation). If not provided, the conversation begins from scratch. Do
|
|
62
|
-
not provide non-`None` values for both `turns` and `system_prompt`. Each
|
|
63
|
-
message in the list should be a dictionary with at least `role` (usually
|
|
64
|
-
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
|
|
65
|
-
there is also a `content` field, which is a string.
|
|
66
57
|
model
|
|
67
58
|
The model to use for the chat. The default, None, will pick a reasonable
|
|
68
59
|
default, and warn you about it. We strongly recommend explicitly
|
|
@@ -131,12 +122,17 @@ def ChatPerplexity(
|
|
|
131
122
|
if api_key is None:
|
|
132
123
|
api_key = os.getenv("PERPLEXITY_API_KEY")
|
|
133
124
|
|
|
134
|
-
|
|
125
|
+
if isinstance(seed, MISSING_TYPE):
|
|
126
|
+
seed = 1014 if is_testing() else None
|
|
127
|
+
|
|
128
|
+
return Chat(
|
|
129
|
+
provider=OpenAIProvider(
|
|
130
|
+
api_key=api_key,
|
|
131
|
+
model=model,
|
|
132
|
+
base_url=base_url,
|
|
133
|
+
seed=seed,
|
|
134
|
+
name="Perplexity",
|
|
135
|
+
kwargs=kwargs,
|
|
136
|
+
),
|
|
135
137
|
system_prompt=system_prompt,
|
|
136
|
-
turns=turns,
|
|
137
|
-
model=model,
|
|
138
|
-
api_key=api_key,
|
|
139
|
-
base_url=base_url,
|
|
140
|
-
seed=seed,
|
|
141
|
-
kwargs=kwargs,
|
|
142
138
|
)
|