chatlas 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
chatlas/_interpolate.py CHANGED
@@ -23,7 +23,7 @@ def interpolate(
     This is a light-weight wrapper around the Jinja2 templating engine, making
     it easier to interpolate dynamic data into a prompt template. Compared to
     f-strings, which expects you to wrap dynamic values in `{ }`, this function
-    expects `{{ }}` instead, making it easier to include Python code and JSON in
+    expects `{{{ }}}` instead, making it easier to include Python code and JSON in
     your prompt.
 
     Parameters
@@ -80,7 +80,7 @@ def interpolate_file(
     This is a light-weight wrapper around the Jinja2 templating engine, making
     it easier to interpolate dynamic data into a static prompt. Compared to
     f-strings, which expects you to wrap dynamic values in `{ }`, this function
-    expects `{{ }}` instead, making it easier to include Python code and JSON in
+    expects `{{{ }}}` instead, making it easier to include Python code and JSON in
     your prompt.
 
     Parameters
@@ -102,8 +102,7 @@ def interpolate_file(
 
     See Also
     --------
-    interpolate
-        Interpolating data into a system prompt
+    * :func:`~chatlas.interpolate` : Interpolating data into a prompt
     """
     if variables is None:
        frame = inspect.currentframe()
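
In practice the double-brace delimiters documented above work like this. A minimal sketch, assuming the `interpolate(prompt, *, variables=...)` signature implied by the surrounding context (the `record` variable is illustrative):

    from chatlas import interpolate

    # {{ record }} is substituted by Jinja2; literal braces like {"name": ...}
    # pass through untouched, which is the point of the {{ }} delimiters.
    prompt = interpolate(
        'Return JSON shaped like {"name": ...} for this record: {{ record }}',
        variables={"record": "Ada Lovelace, engineer"},
    )
    print(prompt)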
chatlas/_mcp_manager.py ADDED
@@ -0,0 +1,306 @@
+from __future__ import annotations
+
+import asyncio
+import warnings
+from abc import ABC, abstractmethod
+from contextlib import AsyncExitStack
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Optional, Sequence
+
+from ._tools import Tool
+
+if TYPE_CHECKING:
+    from mcp import ClientSession
+
+
+@dataclass
+class SessionInfo(ABC):
+    # Input parameters
+    name: str
+    include_tools: Sequence[str] = field(default_factory=list)
+    exclude_tools: Sequence[str] = field(default_factory=list)
+    namespace: str | None = None
+
+    # Primary derived attributes
+    session: ClientSession | None = None
+    tools: dict[str, Tool] = field(default_factory=dict)
+
+    # Background task management
+    ready_event: asyncio.Event = field(default_factory=asyncio.Event)
+    shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
+    exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack)
+    task: asyncio.Task | None = None
+    error: asyncio.CancelledError | Exception | None = None
+
+    @abstractmethod
+    async def open_session(self) -> None: ...
+
+    async def close_session(self) -> None:
+        await self.exit_stack.aclose()
+
+    async def request_tools(self) -> None:
+        if self.session is None:
+            raise ValueError("Session must be opened before requesting tools.")
+
+        if self.include_tools and self.exclude_tools:
+            raise ValueError("Cannot specify both include_tools and exclude_tools.")
+
+        # Request the MCP tools available
+        response = await self.session.list_tools()
+        tool_names = set(x.name for x in response.tools)
+
+        # Warn if tools are mis-specified
+        include = set(self.include_tools or [])
+        missing_include = include.difference(tool_names)
+        if missing_include:
+            warnings.warn(
+                f"Specified include_tools {missing_include} did not match any tools from the MCP server. "
+                f"The tools available are: {tool_names}",
+                stacklevel=2,
+            )
+        exclude = set(self.exclude_tools or [])
+        missing_exclude = exclude.difference(tool_names)
+        if missing_exclude:
+            warnings.warn(
+                f"Specified exclude_tools {missing_exclude} did not match any tools from the MCP server. "
+                f"The tools available are: {tool_names}",
+                stacklevel=2,
+            )
+
+        # Filter the tool names
+        if include:
+            tool_names = include.intersection(tool_names)
+        if exclude:
+            tool_names = tool_names.difference(exclude)
+
+        # Apply namespace and convert to chatlas.Tool instances
+        self_tools: dict[str, Tool] = {}
+        for tool in response.tools:
+            if tool.name not in tool_names:
+                continue
+            if self.namespace:
+                tool.name = f"{self.namespace}.{tool.name}"
+            self_tools[tool.name] = Tool.from_mcp(
+                session=self.session,
+                mcp_tool=tool,
+            )
+
+        # Store the tools
+        self.tools = self_tools
+
+
+@dataclass
+class HTTPSessionInfo(SessionInfo):
+    url: str = ""
+    transport_kwargs: dict[str, Any] = field(default_factory=dict)
+
+    async def open_session(self):
+        mcp = try_import_mcp()
+        from mcp.client.streamable_http import streamablehttp_client
+
+        read, write, _ = await self.exit_stack.enter_async_context(
+            streamablehttp_client(
+                self.url,
+                **self.transport_kwargs,
+            )
+        )
+        session = await self.exit_stack.enter_async_context(
+            mcp.ClientSession(read, write)
+        )
+        server = await session.initialize()
+        self.session = session
+        if not self.name:
+            self.name = server.serverInfo.name or "mcp"
+
+
+@dataclass
+class STDIOSessionInfo(SessionInfo):
+    command: str = ""
+    args: list[str] = field(default_factory=list)
+    transport_kwargs: dict[str, Any] = field(default_factory=dict)
+
+    async def open_session(self):
+        mcp = try_import_mcp()
+        from mcp.client.stdio import stdio_client
+
+        server_params = mcp.StdioServerParameters(
+            command=self.command,
+            args=self.args,
+            **self.transport_kwargs,
+        )
+
+        transport = await self.exit_stack.enter_async_context(
+            stdio_client(server_params)
+        )
+        session = await self.exit_stack.enter_async_context(
+            mcp.ClientSession(*transport)
+        )
+        server = await session.initialize()
+        self.session = session
+        if not self.name:
+            self.name = server.serverInfo.name or "mcp"
+
+
+class MCPSessionManager:
+    """Manages MCP (Model Context Protocol) server connections and tools."""
+
+    def __init__(self):
+        self._mcp_sessions: dict[str, SessionInfo] = {}
+
+    async def register_http_stream_tools(
+        self,
+        *,
+        url: str,
+        name: str | None,
+        include_tools: Sequence[str],
+        exclude_tools: Sequence[str],
+        namespace: str | None,
+        transport_kwargs: dict[str, Any],
+    ):
+        session_info = HTTPSessionInfo(
+            name=name or "",
+            url=url,
+            include_tools=include_tools,
+            exclude_tools=exclude_tools,
+            namespace=namespace,
+            transport_kwargs=transport_kwargs or {},
+        )
+
+        # Launch background task that runs until MCP session is *shutdown*
+        # N.B. this is needed since mcp sessions must be opened and closed in the same task
+        asyncio.create_task(self.open_session(session_info))
+
+        # Wait for a ready event from the task (signals that tools are registered)
+        await session_info.ready_event.wait()
+
+        # An error might have been caught in the background task
+        if session_info.error:
+            raise RuntimeError(
+                f"Failed to register tools from MCP server '{name}' at URL '{url}'"
+            ) from session_info.error
+
+        return session_info
+
+    async def register_stdio_tools(
+        self,
+        *,
+        command: str,
+        args: list[str],
+        name: str | None,
+        include_tools: Sequence[str],
+        exclude_tools: Sequence[str],
+        namespace: str | None,
+        transport_kwargs: dict[str, Any],
+    ):
+        session_info = STDIOSessionInfo(
+            name=name or "",
+            command=command,
+            args=args,
+            include_tools=include_tools,
+            exclude_tools=exclude_tools,
+            namespace=namespace,
+            transport_kwargs=transport_kwargs or {},
+        )
+
+        # Launch a background task to initialize the MCP server
+        # N.B. this is needed since mcp sessions must be opened and closed in the same task
+        asyncio.create_task(self.open_session(session_info))
+
+        # Wait for a ready event from the task (signals that tools are registered)
+        await session_info.ready_event.wait()
+
+        # An error might have been caught in the background task
+        if session_info.error:
+            raise RuntimeError(
+                f"Failed to register tools from MCP server '{name}' with command '{command} {args}'"
+            ) from session_info.error
+
+        return session_info
+
+    async def open_session(self, session_info: "SessionInfo"):
+        session_info.task = asyncio.current_task()
+
+        try:
+            # Open the MCP session
+            await session_info.open_session()
+            # Request the tools
+            await session_info.request_tools()
+            # Make sure session can be added to the manager
+            self.add_session(session_info)
+        except (asyncio.CancelledError, Exception) as err:
+            # Keep the error so we can handle in the main task
+            session_info.error = err
+            # Make sure the session is closed
+            try:
+                await session_info.close_session()
+            except Exception:
+                pass
+            return
+        finally:
+            # Whether successful or not, set ready state to prevent deadlock
+            session_info.ready_event.set()
+
+        # If successful, wait for shutdown signal
+        await session_info.shutdown_event.wait()
+
+        # On shutdown close connection to MCP server
+        # This is why we're using a background task in the 1st place...
+        # we must close in the same task that opened the session
+        await session_info.close_session()
+
+    async def close_sessions(self, names: Optional[Sequence[str]] = None):
+        if names is None:
+            names = list(self._mcp_sessions.keys())
+
+        if isinstance(names, str):
+            names = [names]
+
+        closed_sessions: list[SessionInfo] = []
+        for x in names:
+            session = await self.close_background_session(x)
+            if session is None:
+                continue
+            closed_sessions.append(session)
+
+        return closed_sessions
+
+    async def close_background_session(self, name: str) -> SessionInfo | None:
+        session = self.remove_session(name)
+        if session is None:
+            return None
+
+        # Signal shutdown and wait for the task to finish
+        session.shutdown_event.set()
+        if session.task is not None:
+            await session.task
+
+        return session
+
+    def add_session(self, session_info: SessionInfo) -> None:
+        name = session_info.name
+        if name in self._mcp_sessions:
+            raise ValueError(f"Already connected to an MCP server named: '{name}'.")
+        self._mcp_sessions[name] = session_info
+
+    def remove_session(self, name: str) -> SessionInfo | None:
+        if name not in self._mcp_sessions:
+            warnings.warn(
+                f"Cannot close MCP session named '{name}' since it was not found.",
+                stacklevel=2,
+            )
+            return None
+        session = self._mcp_sessions[name]
+        del self._mcp_sessions[name]
+        return session
+
+
+def try_import_mcp():
+    try:
+        import mcp
+
+        return mcp
+    except ImportError:
+        raise ImportError(
+            "The `mcp` package is required to connect to MCP servers. "
+            "Install it with `pip install mcp`."
+        )
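
The design constraint driving this new module is that an MCP session must be opened and closed from the same asyncio task, so each session lives in a background task coordinated through `ready_event` and `shutdown_event`. A sketch of driving the manager directly; in chatlas it is presumably reached through `Chat`'s MCP registration methods rather than used standalone, the import path assumes the inferred filename above, and the server command is hypothetical:

    import asyncio

    from chatlas._mcp_manager import MCPSessionManager  # filename inferred above

    async def main():
        manager = MCPSessionManager()

        # Spawns the background task, then waits on ready_event until the
        # server's tools have been listed, filtered, and namespaced.
        info = await manager.register_stdio_tools(
            command="python",              # hypothetical MCP server command
            args=["-m", "my_mcp_server"],  # hypothetical server module
            name=None,                     # fall back to the server-reported name
            include_tools=[],
            exclude_tools=[],
            namespace=None,
            transport_kwargs={},
        )
        print(sorted(info.tools))  # names of the registered chatlas.Tool objects

        # Sets shutdown_event and awaits the task, so the session is closed
        # in the same task that opened it.
        await manager.close_sessions([info.name])

    asyncio.run(main())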
chatlas/_ollama.py CHANGED
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Optional
 import orjson
 
 from ._chat import Chat
-from ._openai import ChatOpenAI
-from ._turn import Turn
+from ._openai import OpenAIProvider
+from ._utils import MISSING_TYPE, is_testing
 
 if TYPE_CHECKING:
     from ._openai import ChatCompletion
@@ -19,7 +19,6 @@ def ChatOllama(
     model: Optional[str] = None,
     *,
     system_prompt: Optional[str] = None,
-    turns: Optional[list[Turn]] = None,
     base_url: str = "http://localhost:11434",
     seed: Optional[int] = None,
     kwargs: Optional["ChatClientArgs"] = None,
@@ -67,13 +66,6 @@ def ChatOllama(
         models will be printed.
     system_prompt
         A system prompt to set the behavior of the assistant.
-    turns
-        A list of turns to start the chat with (i.e., continuing a previous
-        conversation). If not provided, the conversation begins from scratch. Do
-        not provide non-`None` values for both `turns` and `system_prompt`. Each
-        message in the list should be a dictionary with at least `role` (usually
-        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
-        there is also a `content` field, which is a string.
     base_url
         The base URL to the endpoint; the default uses ollama's API.
     seed
@@ -102,15 +94,19 @@ def ChatOllama(
         raise ValueError(
             f"Must specify model. Locally installed models: {', '.join(models)}"
         )
-
-    return ChatOpenAI(
+    if isinstance(seed, MISSING_TYPE):
+        seed = 1014 if is_testing() else None
+
+    return Chat(
+        provider=OpenAIProvider(
+            api_key="ollama",  # ignored
+            model=model,
+            base_url=f"{base_url}/v1",
+            seed=seed,
+            name="Ollama",
+            kwargs=kwargs,
+        ),
         system_prompt=system_prompt,
-        api_key="ollama",  # ignored
-        turns=turns,
-        base_url=f"{base_url}/v1",
-        model=model,
-        seed=seed,
-        kwargs=kwargs,
     )
 
 
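The rewritten `ChatOllama()` now builds a `Chat` around a named `OpenAIProvider` pointed at Ollama's OpenAI-compatible endpoint (`{base_url}/v1`) instead of delegating to `ChatOpenAI()`. Usage stays the same; a sketch, assuming a local Ollama server with the model below already pulled:

    from chatlas import ChatOllama

    # Talks to http://localhost:11434/v1; the "ollama" api_key is a
    # placeholder because Ollama ignores authentication.
    chat = ChatOllama(model="llama3.1", system_prompt="Be terse.")
    chat.chat("What is 2 + 2?")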
chatlas/_openai.py CHANGED
@@ -4,6 +4,7 @@ import base64
 from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
 
 import orjson
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 from pydantic import BaseModel
 
 from ._chat import Chat
@@ -17,14 +18,16 @@ from ._content import (
     ContentText,
     ContentToolRequest,
     ContentToolResult,
+    ContentToolResultImage,
+    ContentToolResultResource,
 )
 from ._logging import log_model_default
 from ._merge import merge_dicts
-from ._provider import Provider
+from ._provider import Provider, StandardModelParamNames, StandardModelParams
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns, user_turn
-from ._utils import MISSING, MISSING_TYPE, is_testing
+from ._turn import Turn, user_turn
+from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs
 
 if TYPE_CHECKING:
     from openai.types.chat import (
@@ -53,7 +56,6 @@ ChatCompletionDict = dict[str, Any]
 def ChatOpenAI(
     *,
     system_prompt: Optional[str] = None,
-    turns: Optional[list[Turn]] = None,
     model: "Optional[ChatModel | str]" = None,
     api_key: Optional[str] = None,
     base_url: str = "https://api.openai.com/v1",
@@ -92,13 +94,6 @@ def ChatOpenAI(
     ----------
     system_prompt
         A system prompt to set the behavior of the assistant.
-    turns
-        A list of turns to start the chat with (i.e., continuing a previous
-        conversation). If not provided, the conversation begins from scratch. Do
-        not provide non-`None` values for both `turns` and `system_prompt`. Each
-        message in the list should be a dictionary with at least `role` (usually
-        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
-        there is also a `content` field, which is a string.
     model
         The model to use for the chat. The default, None, will pick a reasonable
         default, and warn you about it. We strongly recommend explicitly
@@ -161,7 +156,7 @@ def ChatOpenAI(
         seed = 1014 if is_testing() else None
 
     if model is None:
-        model = log_model_default("gpt-4o")
+        model = log_model_default("gpt-4.1")
 
     return Chat(
         provider=OpenAIProvider(
@@ -171,14 +166,13 @@
             seed=seed,
             kwargs=kwargs,
         ),
-        turns=normalize_turns(
-            turns or [],
-            system_prompt,
-        ),
+        system_prompt=system_prompt,
     )
 
 
-class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict]):
+class OpenAIProvider(
+    Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict, "SubmitInputArgs"]
+):
     def __init__(
         self,
         *,
@@ -186,11 +180,11 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
         model: str,
         base_url: str = "https://api.openai.com/v1",
         seed: Optional[int] = None,
+        name: str = "OpenAI",
         kwargs: Optional["ChatClientArgs"] = None,
     ):
-        from openai import AsyncOpenAI, OpenAI
+        super().__init__(name=name, model=model)
 
-        self._model = model
         self._seed = seed
 
         kwargs_full: "ChatClientArgs" = {
@@ -199,9 +193,12 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
             **(kwargs or {}),
         }
 
+        # Avoid passing the wrong sync/async client to the OpenAI constructor.
+        sync_kwargs, async_kwargs = split_http_client_kwargs(kwargs_full)
+
         # TODO: worth bringing in AsyncOpenAI types?
-        self._client = OpenAI(**kwargs_full)  # type: ignore
-        self._async_client = AsyncOpenAI(**kwargs_full)
+        self._client = OpenAI(**sync_kwargs)  # type: ignore
+        self._async_client = AsyncOpenAI(**async_kwargs)
 
     @overload
     def chat_perform(
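
`split_http_client_kwargs()` comes from `._utils` (see the import hunk above); its body is not part of this diff. The comment suggests it routes a user-supplied `http_client` to whichever of the sync/async OpenAI clients can accept it. A hypothetical sketch of that idea; the real helper may differ:

    import httpx

    def split_http_client_kwargs(kwargs: dict) -> tuple[dict, dict]:
        """Hypothetical sketch: pass http_client only to the client that can use it."""
        sync_kwargs, async_kwargs = dict(kwargs), dict(kwargs)
        client = kwargs.get("http_client")
        if isinstance(client, httpx.AsyncClient):
            del sync_kwargs["http_client"]   # OpenAI() requires a sync httpx.Client
        elif isinstance(client, httpx.Client):
            del async_kwargs["http_client"]  # AsyncOpenAI() requires httpx.AsyncClient
        return sync_kwargs, async_kwargs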
@@ -284,7 +281,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
         kwargs_full: "SubmitInputArgs" = {
             "stream": stream,
             "messages": self._as_message_param(turns),
-            "model": self._model,
+            "model": self.model,
             **(kwargs or {}),
         }
 
@@ -487,6 +484,12 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
                     }
                 )
             elif isinstance(x, ContentToolResult):
+                if isinstance(
+                    x, (ContentToolResultImage, ContentToolResultResource)
+                ):
+                    raise NotImplementedError(
+                        "OpenAI does not support tool results with images or resources."
+                    )
                 tool_results.append(
                     ChatCompletionToolMessageParam(
                         # Currently, OpenAI only allows for text content in tool results
@@ -573,6 +576,46 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
             completion=completion,
         )
 
+    def translate_model_params(self, params: StandardModelParams) -> "SubmitInputArgs":
+        res: "SubmitInputArgs" = {}
+        if "temperature" in params:
+            res["temperature"] = params["temperature"]
+
+        if "top_p" in params:
+            res["top_p"] = params["top_p"]
+
+        if "frequency_penalty" in params:
+            res["frequency_penalty"] = params["frequency_penalty"]
+
+        if "presence_penalty" in params:
+            res["presence_penalty"] = params["presence_penalty"]
+
+        if "seed" in params:
+            res["seed"] = params["seed"]
+
+        if "max_tokens" in params:
+            res["max_tokens"] = params["max_tokens"]
+
+        if "log_probs" in params:
+            res["logprobs"] = params["log_probs"]
+
+        if "stop_sequences" in params:
+            res["stop"] = params["stop_sequences"]
+
+        return res
+
+    def supported_model_params(self) -> set[StandardModelParamNames]:
+        return {
+            "temperature",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "seed",
+            "max_tokens",
+            "log_probs",
+            "stop_sequences",
+        }
+
 
 def ChatAzureOpenAI(
     *,
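
`translate_model_params()` and `supported_model_params()` give `OpenAIProvider` a portable vocabulary of sampling settings: standard names are remapped to OpenAI request keys (`stop_sequences` becomes `stop`, `log_probs` becomes `logprobs`). Assuming these back a `Chat.set_model_params()` entry point in this release (an assumption; only the provider half appears in this diff), usage would look roughly like:

    from chatlas import ChatOpenAI

    chat = ChatOpenAI(model="gpt-4.1")

    # Standard names; the provider translates them per translate_model_params().
    chat.set_model_params(
        temperature=0.2,
        max_tokens=256,
        stop_sequences=["\n\n"],  # sent to OpenAI as "stop"
    )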
@@ -581,7 +624,6 @@ def ChatAzureOpenAI(
     api_version: str,
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
-    turns: Optional[list[Turn]] = None,
     seed: int | None | MISSING_TYPE = MISSING,
     kwargs: Optional["ChatAzureClientArgs"] = None,
 ) -> Chat["SubmitInputArgs", ChatCompletion]:
@@ -624,13 +666,6 @@ def ChatAzureOpenAI(
         variable.
     system_prompt
         A system prompt to set the behavior of the assistant.
-    turns
-        A list of turns to start the chat with (i.e., continuing a previous
-        conversation). If not provided, the conversation begins from scratch.
-        Do not provide non-None values for both `turns` and `system_prompt`.
-        Each message in the list should be a dictionary with at least `role`
-        (usually `system`, `user`, or `assistant`, but `tool` is also possible).
-        Normally there is also a `content` field, which is a string.
     seed
         Optional integer seed that ChatGPT uses to try and make output more
         reproducible.
@@ -655,10 +690,7 @@
             seed=seed,
             kwargs=kwargs,
         ),
-        turns=normalize_turns(
-            turns or [],
-            system_prompt,
-        ),
+        system_prompt=system_prompt,
     )
 
 
@@ -667,15 +699,16 @@ class OpenAIAzureProvider(OpenAIProvider):
         self,
         *,
         endpoint: Optional[str] = None,
-        deployment_id: Optional[str] = None,
+        deployment_id: str,
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         seed: int | None = None,
+        name: str = "OpenAIAzure",
+        model: Optional[str] = "UnusedValue",
         kwargs: Optional["ChatAzureClientArgs"] = None,
     ):
-        from openai import AsyncAzureOpenAI, AzureOpenAI
+        super().__init__(name=name, model=deployment_id)
 
-        self._model = deployment_id
         self._seed = seed
 
         kwargs_full: "ChatAzureClientArgs" = {
@@ -686,8 +719,10 @@
             **(kwargs or {}),
         }
 
-        self._client = AzureOpenAI(**kwargs_full)  # type: ignore
-        self._async_client = AsyncAzureOpenAI(**kwargs_full)  # type: ignore
+        sync_kwargs, async_kwargs = split_http_client_kwargs(kwargs_full)
+
+        self._client = AzureOpenAI(**sync_kwargs)  # type: ignore
+        self._async_client = AsyncAzureOpenAI(**async_kwargs)  # type: ignore
 
 
 class InvalidJSONParameterWarning(RuntimeWarning):
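
`OpenAIAzureProvider` now requires `deployment_id` and records it as the provider's `model` via `super().__init__()`; the inherited `model` parameter survives only for signature compatibility (hence the `"UnusedValue"` default). Constructing a chat through the public wrapper looks like this sketch (endpoint, deployment, and API version are placeholders):

    from chatlas import ChatAzureOpenAI

    chat = ChatAzureOpenAI(
        endpoint="https://my-endpoint.openai.azure.com",  # placeholder
        deployment_id="my-gpt4-deployment",               # placeholder
        api_version="2024-06-01",                         # placeholder
        # api_key is typically supplied via an environment variable
    )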
chatlas/_perplexity.py CHANGED
@@ -5,9 +5,8 @@ from typing import TYPE_CHECKING, Optional
 
 from ._chat import Chat
 from ._logging import log_model_default
-from ._openai import ChatOpenAI
-from ._turn import Turn
-from ._utils import MISSING, MISSING_TYPE
+from ._openai import OpenAIProvider
+from ._utils import MISSING, MISSING_TYPE, is_testing
 
 if TYPE_CHECKING:
     from ._openai import ChatCompletion
@@ -17,7 +16,6 @@ if TYPE_CHECKING:
 def ChatPerplexity(
     *,
     system_prompt: Optional[str] = None,
-    turns: Optional[list[Turn]] = None,
     model: Optional[str] = None,
     api_key: Optional[str] = None,
     base_url: str = "https://api.perplexity.ai/",
@@ -56,13 +54,6 @@ def ChatPerplexity(
     ----------
     system_prompt
         A system prompt to set the behavior of the assistant.
-    turns
-        A list of turns to start the chat with (i.e., continuing a previous
-        conversation). If not provided, the conversation begins from scratch. Do
-        not provide non-`None` values for both `turns` and `system_prompt`. Each
-        message in the list should be a dictionary with at least `role` (usually
-        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
-        there is also a `content` field, which is a string.
     model
         The model to use for the chat. The default, None, will pick a reasonable
         default, and warn you about it. We strongly recommend explicitly
@@ -131,12 +122,17 @@ def ChatPerplexity(
     if api_key is None:
         api_key = os.getenv("PERPLEXITY_API_KEY")
 
-    return ChatOpenAI(
+    if isinstance(seed, MISSING_TYPE):
+        seed = 1014 if is_testing() else None
+
+    return Chat(
+        provider=OpenAIProvider(
+            api_key=api_key,
+            model=model,
+            base_url=base_url,
+            seed=seed,
+            name="Perplexity",
+            kwargs=kwargs,
+        ),
         system_prompt=system_prompt,
-        turns=turns,
-        model=model,
-        api_key=api_key,
-        base_url=base_url,
-        seed=seed,
-        kwargs=kwargs,
     )
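
As with `ChatOllama()`, `ChatPerplexity()` now assembles the `Chat` itself with a named `OpenAIProvider` rather than calling `ChatOpenAI()`. A usage sketch; the model name is illustrative, and the key falls back to `PERPLEXITY_API_KEY` when `api_key` is omitted (per the hunk above):

    from chatlas import ChatPerplexity

    # Reads PERPLEXITY_API_KEY from the environment when api_key is omitted.
    chat = ChatPerplexity(model="sonar")  # illustrative model name
    chat.chat("What is the capital of France?")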