pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/exceptions.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
|
|
5
|
-
__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
|
|
5
|
+
__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class ModelRetry(Exception):
|
|
@@ -30,7 +30,25 @@ class UserError(RuntimeError):
|
|
|
30
30
|
super().__init__(message)
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class
|
|
33
|
+
class AgentRunError(RuntimeError):
|
|
34
|
+
"""Base class for errors occurring during an agent run."""
|
|
35
|
+
|
|
36
|
+
message: str
|
|
37
|
+
"""The error message."""
|
|
38
|
+
|
|
39
|
+
def __init__(self, message: str):
|
|
40
|
+
self.message = message
|
|
41
|
+
super().__init__(message)
|
|
42
|
+
|
|
43
|
+
def __str__(self) -> str:
|
|
44
|
+
return self.message
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class UsageLimitExceeded(AgentRunError):
|
|
48
|
+
"""Error raised when a Model's usage exceeds the specified limits."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class UnexpectedModelBehavior(AgentRunError):
|
|
34
52
|
"""Error caused by unexpected Model behavior, e.g. an unexpected response code."""
|
|
35
53
|
|
|
36
54
|
message: str
|
pydantic_ai/messages.py
CHANGED
|
@@ -2,18 +2,17 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import Annotated, Any, Literal, Union
|
|
5
|
+
from typing import Annotated, Any, Literal, Union, cast
|
|
6
6
|
|
|
7
7
|
import pydantic
|
|
8
8
|
import pydantic_core
|
|
9
|
-
from
|
|
9
|
+
from typing_extensions import Self, assert_never
|
|
10
10
|
|
|
11
|
-
from . import _pydantic
|
|
12
11
|
from ._utils import now_utc as _now_utc
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
@dataclass
|
|
16
|
-
class
|
|
15
|
+
class SystemPromptPart:
|
|
17
16
|
"""A system prompt, generally written by the application developer.
|
|
18
17
|
|
|
19
18
|
This gives the model context and guidance on how to respond.
|
|
@@ -21,12 +20,13 @@ class SystemPrompt:
|
|
|
21
20
|
|
|
22
21
|
content: str
|
|
23
22
|
"""The content of the prompt."""
|
|
24
|
-
|
|
25
|
-
|
|
23
|
+
|
|
24
|
+
part_kind: Literal['system-prompt'] = 'system-prompt'
|
|
25
|
+
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
@dataclass
|
|
29
|
-
class
|
|
29
|
+
class UserPromptPart:
|
|
30
30
|
"""A user prompt, generally written by the end user.
|
|
31
31
|
|
|
32
32
|
Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
|
|
@@ -35,29 +35,35 @@ class UserPrompt:
|
|
|
35
35
|
|
|
36
36
|
content: str
|
|
37
37
|
"""The content of the prompt."""
|
|
38
|
+
|
|
38
39
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
39
40
|
"""The timestamp of the prompt."""
|
|
40
|
-
|
|
41
|
-
|
|
41
|
+
|
|
42
|
+
part_kind: Literal['user-prompt'] = 'user-prompt'
|
|
43
|
+
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
42
44
|
|
|
43
45
|
|
|
44
|
-
tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
|
|
46
|
+
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
|
|
45
47
|
|
|
46
48
|
|
|
47
49
|
@dataclass
|
|
48
|
-
class
|
|
50
|
+
class ToolReturnPart:
|
|
49
51
|
"""A tool return message, this encodes the result of running a tool."""
|
|
50
52
|
|
|
51
53
|
tool_name: str
|
|
52
54
|
"""The name of the "tool" was called."""
|
|
55
|
+
|
|
53
56
|
content: Any
|
|
54
57
|
"""The return value."""
|
|
55
|
-
|
|
56
|
-
|
|
58
|
+
|
|
59
|
+
tool_call_id: str | None = None
|
|
60
|
+
"""Optional tool call identifier, this is used by some models including OpenAI."""
|
|
61
|
+
|
|
57
62
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
58
63
|
"""The timestamp, when the tool returned."""
|
|
59
|
-
|
|
60
|
-
|
|
64
|
+
|
|
65
|
+
part_kind: Literal['tool-return'] = 'tool-return'
|
|
66
|
+
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
61
67
|
|
|
62
68
|
def model_response_str(self) -> str:
|
|
63
69
|
if isinstance(self.content, str):
|
|
@@ -73,11 +79,11 @@ class ToolReturn:
|
|
|
73
79
|
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
|
|
74
80
|
|
|
75
81
|
|
|
76
|
-
|
|
82
|
+
error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
|
|
77
83
|
|
|
78
84
|
|
|
79
85
|
@dataclass
|
|
80
|
-
class
|
|
86
|
+
class RetryPromptPart:
|
|
81
87
|
"""A message back to a model asking it to try again.
|
|
82
88
|
|
|
83
89
|
This can be sent for a number of reasons:
|
|
@@ -98,37 +104,54 @@ class RetryPrompt:
|
|
|
98
104
|
If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of
|
|
99
105
|
error details.
|
|
100
106
|
"""
|
|
107
|
+
|
|
101
108
|
tool_name: str | None = None
|
|
102
109
|
"""The name of the tool that was called, if any."""
|
|
103
|
-
|
|
104
|
-
|
|
110
|
+
|
|
111
|
+
tool_call_id: str | None = None
|
|
112
|
+
"""Optional tool call identifier, this is used by some models including OpenAI."""
|
|
113
|
+
|
|
105
114
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
106
115
|
"""The timestamp, when the retry was triggered."""
|
|
107
|
-
|
|
108
|
-
|
|
116
|
+
|
|
117
|
+
part_kind: Literal['retry-prompt'] = 'retry-prompt'
|
|
118
|
+
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
109
119
|
|
|
110
120
|
def model_response(self) -> str:
|
|
111
121
|
if isinstance(self.content, str):
|
|
112
122
|
description = self.content
|
|
113
123
|
else:
|
|
114
|
-
json_errors =
|
|
124
|
+
json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
|
|
115
125
|
description = f'{len(self.content)} validation errors: {json_errors.decode()}'
|
|
116
126
|
return f'{description}\n\nFix the errors and try again.'
|
|
117
127
|
|
|
118
128
|
|
|
129
|
+
ModelRequestPart = Annotated[
|
|
130
|
+
Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
|
|
131
|
+
]
|
|
132
|
+
"""A message part sent by PydanticAI to a model."""
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclass
|
|
136
|
+
class ModelRequest:
|
|
137
|
+
"""A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
|
|
138
|
+
|
|
139
|
+
parts: list[ModelRequestPart]
|
|
140
|
+
"""The parts of the user message."""
|
|
141
|
+
|
|
142
|
+
kind: Literal['request'] = 'request'
|
|
143
|
+
"""Message type identifier, this is available on all parts as a discriminator."""
|
|
144
|
+
|
|
145
|
+
|
|
119
146
|
@dataclass
|
|
120
|
-
class
|
|
147
|
+
class TextPart:
|
|
121
148
|
"""A plain text response from a model."""
|
|
122
149
|
|
|
123
150
|
content: str
|
|
124
151
|
"""The text content of the response."""
|
|
125
|
-
timestamp: datetime = field(default_factory=_now_utc)
|
|
126
|
-
"""The timestamp of the response.
|
|
127
152
|
|
|
128
|
-
|
|
129
|
-
"""
|
|
130
|
-
role: Literal['model-text-response'] = 'model-text-response'
|
|
131
|
-
"""Message type identifier, this type is available on all message as a discriminator."""
|
|
153
|
+
part_kind: Literal['text'] = 'text'
|
|
154
|
+
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
132
155
|
|
|
133
156
|
|
|
134
157
|
@dataclass
|
|
@@ -148,26 +171,53 @@ class ArgsDict:
|
|
|
148
171
|
|
|
149
172
|
|
|
150
173
|
@dataclass
|
|
151
|
-
class
|
|
152
|
-
"""
|
|
174
|
+
class ToolCallPart:
|
|
175
|
+
"""A tool call from a model."""
|
|
153
176
|
|
|
154
177
|
tool_name: str
|
|
155
178
|
"""The name of the tool to call."""
|
|
179
|
+
|
|
156
180
|
args: ArgsJson | ArgsDict
|
|
157
181
|
"""The arguments to pass to the tool.
|
|
158
182
|
|
|
159
183
|
Either as JSON or a Python dictionary depending on how data was returned.
|
|
160
184
|
"""
|
|
161
|
-
tool_id: str | None = None
|
|
162
|
-
"""Optional tool identifier, this is used by some models including OpenAI."""
|
|
163
185
|
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
186
|
+
tool_call_id: str | None = None
|
|
187
|
+
"""Optional tool call identifier, this is used by some models including OpenAI."""
|
|
188
|
+
|
|
189
|
+
part_kind: Literal['tool-call'] = 'tool-call'
|
|
190
|
+
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
167
191
|
|
|
168
192
|
@classmethod
|
|
169
|
-
def
|
|
170
|
-
|
|
193
|
+
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
|
|
194
|
+
"""Create a `ToolCallPart` from raw arguments."""
|
|
195
|
+
if isinstance(args, str):
|
|
196
|
+
return cls(tool_name, ArgsJson(args), tool_call_id)
|
|
197
|
+
elif isinstance(args, dict):
|
|
198
|
+
return cls(tool_name, ArgsDict(args), tool_call_id)
|
|
199
|
+
else:
|
|
200
|
+
assert_never(args)
|
|
201
|
+
|
|
202
|
+
def args_as_dict(self) -> dict[str, Any]:
|
|
203
|
+
"""Return the arguments as a Python dictionary.
|
|
204
|
+
|
|
205
|
+
This is just for convenience with models that require dicts as input.
|
|
206
|
+
"""
|
|
207
|
+
if isinstance(self.args, ArgsDict):
|
|
208
|
+
return self.args.args_dict
|
|
209
|
+
args = pydantic_core.from_json(self.args.args_json)
|
|
210
|
+
assert isinstance(args, dict), 'args should be a dict'
|
|
211
|
+
return cast(dict[str, Any], args)
|
|
212
|
+
|
|
213
|
+
def args_as_json_str(self) -> str:
|
|
214
|
+
"""Return the arguments as a JSON string.
|
|
215
|
+
|
|
216
|
+
This is just for convenience with models that require JSON strings as input.
|
|
217
|
+
"""
|
|
218
|
+
if isinstance(self.args, ArgsJson):
|
|
219
|
+
return self.args.args_json
|
|
220
|
+
return pydantic_core.to_json(self.args.args_dict).decode()
|
|
171
221
|
|
|
172
222
|
def has_content(self) -> bool:
|
|
173
223
|
if isinstance(self.args, ArgsDict):
|
|
@@ -176,28 +226,39 @@ class ToolCall:
|
|
|
176
226
|
return bool(self.args.args_json)
|
|
177
227
|
|
|
178
228
|
|
|
229
|
+
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
|
|
230
|
+
"""A message part returned by a model."""
|
|
231
|
+
|
|
232
|
+
|
|
179
233
|
@dataclass
|
|
180
|
-
class
|
|
181
|
-
"""A
|
|
234
|
+
class ModelResponse:
|
|
235
|
+
"""A response from a model, e.g. a message from the model to the PydanticAI app."""
|
|
182
236
|
|
|
183
|
-
|
|
184
|
-
"""
|
|
237
|
+
parts: list[ModelResponsePart]
|
|
238
|
+
"""The parts of the model message."""
|
|
185
239
|
|
|
186
|
-
calls: list[ToolCall]
|
|
187
|
-
"""The tool calls being made."""
|
|
188
240
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
189
241
|
"""The timestamp of the response.
|
|
190
242
|
|
|
191
243
|
If the model provides a timestamp in the response (as OpenAI does) that will be used.
|
|
192
244
|
"""
|
|
193
|
-
|
|
194
|
-
|
|
245
|
+
|
|
246
|
+
kind: Literal['response'] = 'response'
|
|
247
|
+
"""Message type identifier, this is available on all parts as a discriminator."""
|
|
248
|
+
|
|
249
|
+
@classmethod
|
|
250
|
+
def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
|
|
251
|
+
return cls([TextPart(content)], timestamp=timestamp or _now_utc())
|
|
252
|
+
|
|
253
|
+
@classmethod
|
|
254
|
+
def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
|
|
255
|
+
return cls([tool_call])
|
|
195
256
|
|
|
196
257
|
|
|
197
|
-
|
|
198
|
-
"""Any response from a model."""
|
|
199
|
-
Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelTextResponse, ModelStructuredResponse]
|
|
258
|
+
ModelMessage = Union[ModelRequest, ModelResponse]
|
|
200
259
|
"""Any message send to or returned by a model."""
|
|
201
260
|
|
|
202
|
-
|
|
261
|
+
ModelMessagesTypeAdapter = pydantic.TypeAdapter(
|
|
262
|
+
list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
|
|
263
|
+
)
|
|
203
264
|
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -16,10 +16,11 @@ from typing import TYPE_CHECKING, Literal, Union
|
|
|
16
16
|
import httpx
|
|
17
17
|
|
|
18
18
|
from ..exceptions import UserError
|
|
19
|
-
from ..messages import
|
|
19
|
+
from ..messages import ModelMessage, ModelResponse
|
|
20
|
+
from ..settings import ModelSettings
|
|
20
21
|
|
|
21
22
|
if TYPE_CHECKING:
|
|
22
|
-
from ..result import
|
|
23
|
+
from ..result import Usage
|
|
23
24
|
from ..tools import ToolDefinition
|
|
24
25
|
|
|
25
26
|
|
|
@@ -30,7 +31,9 @@ KnownModelName = Literal[
|
|
|
30
31
|
'openai:gpt-4',
|
|
31
32
|
'openai:o1-preview',
|
|
32
33
|
'openai:o1-mini',
|
|
34
|
+
'openai:o1',
|
|
33
35
|
'openai:gpt-3.5-turbo',
|
|
36
|
+
'groq:llama-3.3-70b-versatile',
|
|
34
37
|
'groq:llama-3.1-70b-versatile',
|
|
35
38
|
'groq:llama3-groq-70b-8192-tool-use-preview',
|
|
36
39
|
'groq:llama3-groq-8b-8192-tool-use-preview',
|
|
@@ -47,8 +50,15 @@ KnownModelName = Literal[
|
|
|
47
50
|
'groq:gemma-7b-it',
|
|
48
51
|
'gemini-1.5-flash',
|
|
49
52
|
'gemini-1.5-pro',
|
|
53
|
+
'gemini-2.0-flash-exp',
|
|
50
54
|
'vertexai:gemini-1.5-flash',
|
|
51
55
|
'vertexai:gemini-1.5-pro',
|
|
56
|
+
# since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
|
|
57
|
+
# don't start with "mistral", we add the "mistral:" prefix to all to be explicit
|
|
58
|
+
'mistral:mistral-small-latest',
|
|
59
|
+
'mistral:mistral-large-latest',
|
|
60
|
+
'mistral:codestral-latest',
|
|
61
|
+
'mistral:mistral-moderation-latest',
|
|
52
62
|
'ollama:codellama',
|
|
53
63
|
'ollama:gemma',
|
|
54
64
|
'ollama:gemma2',
|
|
@@ -66,6 +76,9 @@ KnownModelName = Literal[
|
|
|
66
76
|
'ollama:qwen2',
|
|
67
77
|
'ollama:qwen2.5',
|
|
68
78
|
'ollama:starcoder2',
|
|
79
|
+
'claude-3-5-haiku-latest',
|
|
80
|
+
'claude-3-5-sonnet-latest',
|
|
81
|
+
'claude-3-opus-latest',
|
|
69
82
|
'test',
|
|
70
83
|
]
|
|
71
84
|
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
|
|
@@ -108,12 +121,16 @@ class AgentModel(ABC):
|
|
|
108
121
|
"""Model configured for each step of an Agent run."""
|
|
109
122
|
|
|
110
123
|
@abstractmethod
|
|
111
|
-
async def request(
|
|
124
|
+
async def request(
|
|
125
|
+
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
126
|
+
) -> tuple[ModelResponse, Usage]:
|
|
112
127
|
"""Make a request to the model."""
|
|
113
128
|
raise NotImplementedError()
|
|
114
129
|
|
|
115
130
|
@asynccontextmanager
|
|
116
|
-
async def request_stream(
|
|
131
|
+
async def request_stream(
|
|
132
|
+
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
133
|
+
) -> AsyncIterator[EitherStreamedResponse]:
|
|
117
134
|
"""Make a request to the model and return a streaming response."""
|
|
118
135
|
raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
|
|
119
136
|
# yield is required to make this a generator for type checking
|
|
@@ -148,10 +165,10 @@ class StreamTextResponse(ABC):
|
|
|
148
165
|
raise NotImplementedError()
|
|
149
166
|
|
|
150
167
|
@abstractmethod
|
|
151
|
-
def
|
|
152
|
-
"""Return the
|
|
168
|
+
def usage(self) -> Usage:
|
|
169
|
+
"""Return the usage of the request.
|
|
153
170
|
|
|
154
|
-
NOTE: this won't return the
|
|
171
|
+
NOTE: this won't return the full usage until the stream is finished.
|
|
155
172
|
"""
|
|
156
173
|
raise NotImplementedError()
|
|
157
174
|
|
|
@@ -178,10 +195,10 @@ class StreamStructuredResponse(ABC):
|
|
|
178
195
|
raise NotImplementedError()
|
|
179
196
|
|
|
180
197
|
@abstractmethod
|
|
181
|
-
def get(self, *, final: bool = False) ->
|
|
182
|
-
"""Get the `
|
|
198
|
+
def get(self, *, final: bool = False) -> ModelResponse:
|
|
199
|
+
"""Get the `ModelResponse` at this point.
|
|
183
200
|
|
|
184
|
-
The `
|
|
201
|
+
The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
|
|
185
202
|
|
|
186
203
|
Args:
|
|
187
204
|
final: If True, this is the final call, after iteration is complete, the response should be fully validated.
|
|
@@ -189,10 +206,10 @@ class StreamStructuredResponse(ABC):
|
|
|
189
206
|
raise NotImplementedError()
|
|
190
207
|
|
|
191
208
|
@abstractmethod
|
|
192
|
-
def
|
|
193
|
-
"""Get the
|
|
209
|
+
def usage(self) -> Usage:
|
|
210
|
+
"""Get the usage of the request.
|
|
194
211
|
|
|
195
|
-
NOTE: this won't return the full
|
|
212
|
+
NOTE: this won't return the full usage until the stream is finished.
|
|
196
213
|
"""
|
|
197
214
|
raise NotImplementedError()
|
|
198
215
|
|
|
@@ -219,7 +236,7 @@ The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
|
|
|
219
236
|
def check_allow_model_requests() -> None:
|
|
220
237
|
"""Check if model requests are allowed.
|
|
221
238
|
|
|
222
|
-
If you're defining your own models that have
|
|
239
|
+
If you're defining your own models that have costs or latency associated with their use, you should call this in
|
|
223
240
|
[`Model.agent_model`][pydantic_ai.models.Model.agent_model].
|
|
224
241
|
|
|
225
242
|
Raises:
|
|
@@ -270,10 +287,18 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
270
287
|
from .vertexai import VertexAIModel
|
|
271
288
|
|
|
272
289
|
return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
|
|
290
|
+
elif model.startswith('mistral:'):
|
|
291
|
+
from .mistral import MistralModel
|
|
292
|
+
|
|
293
|
+
return MistralModel(model[8:])
|
|
273
294
|
elif model.startswith('ollama:'):
|
|
274
295
|
from .ollama import OllamaModel
|
|
275
296
|
|
|
276
297
|
return OllamaModel(model[7:])
|
|
298
|
+
elif model.startswith('claude'):
|
|
299
|
+
from .anthropic import AnthropicModel
|
|
300
|
+
|
|
301
|
+
return AnthropicModel(model)
|
|
277
302
|
else:
|
|
278
303
|
raise UserError(f'Unknown model: {model}')
|
|
279
304
|
|