pydantic-ai-slim 0.0.22__tar.gz → 0.0.24__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_agent_graph.py +12 -8
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/agent.py +5 -5
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/__init__.py +52 -45
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/anthropic.py +87 -66
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/cohere.py +65 -67
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/function.py +76 -60
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/gemini.py +153 -99
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/groq.py +97 -72
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/mistral.py +90 -71
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/openai.py +110 -71
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/test.py +99 -94
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/vertexai.py +48 -44
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/result.py +2 -2
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pyproject.toml +3 -3
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/README.md +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/usage.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.22
+Version: 0.0.24
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,7 +28,7 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.22
+Requires-Dist: pydantic-graph==0.0.24
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
@@ -41,7 +41,7 @@ Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.
+Requires-Dist: openai>=1.61.0; extra == 'openai'
 Provides-Extra: vertexai
 Requires-Dist: google-auth>=2.36.0; extra == 'vertexai'
 Requires-Dist: requests>=2.32.3; extra == 'vertexai'
pydantic_ai/_agent_graph.py

@@ -204,9 +204,9 @@ class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
         return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))


-async def _prepare_model(
+async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> models.
+) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []

@@ -220,7 +220,7 @@ async def _prepare_model(
     await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))

     result_schema = ctx.deps.result_schema
-    return
+    return models.ModelRequestParameters(
         function_tools=function_tool_defs,
         allow_text_result=_allow_text_result(result_schema),
         result_tools=result_schema.tool_defs() if result_schema is not None else [],
@@ -245,13 +245,15 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
         # Increment run_step
         ctx.state.run_step += 1

-        with _logfire.span('preparing model
-
+        with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
+            model_request_parameters = await _prepare_request_parameters(ctx)

         # Actually make the model request
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('model request') as span:
-            model_response, request_usage = await
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
             span.set_attribute('response', model_response)
             span.set_attribute('usage', request_usage)

@@ -405,12 +407,14 @@ class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any
         ctx.state.run_step += 1

         with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-
+            model_request_parameters = await _prepare_request_parameters(ctx)

         # Actually make the model request
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
                 ctx.state.usage.requests += 1
                 model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
                 # We want to end the "model request" span here, but we can't exit the context manager
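The new flow builds a `models.ModelRequestParameters` value once per step and passes it straight into `Model.request` / `Model.request_stream`. A minimal sketch of that call shape, with placeholder field values (the real values are produced by `_prepare_request_parameters(ctx)` above):

```python
from pydantic_ai import models

# Placeholder parameters; in the graph they come from _prepare_request_parameters(ctx).
params = models.ModelRequestParameters(
    function_tools=[],       # ToolDefinition entries for the agent's function tools
    allow_text_result=True,  # whether a plain-text final result is acceptable
    result_tools=[],         # ToolDefinition entries for result/structured-output tools
)

# Inside ModelRequestNode the request then becomes (model, history and settings come from the graph deps):
# model_response, request_usage = await ctx.deps.model.request(
#     ctx.state.message_history, model_settings, params
# )
```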
pydantic_ai/agent.py

@@ -275,7 +275,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used =
+        model_used = self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -309,7 +309,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            model_name=model_used.
+            model_name=model_used.model_name if model_used else 'no-model',
             agent_name=self.name or 'agent',
         ) as run_span:
             # Build the deps object for the graph
@@ -520,7 +520,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used =
+        model_used = self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -554,7 +554,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            model_name=model_used.
+            model_name=model_used.model_name if model_used else 'no-model',
             agent_name=self.name or 'agent',
         ) as run_span:
             # Build the deps object for the graph
@@ -971,7 +971,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

         self._function_tools[tool.name] = tool

-
+    def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
         """Create a model configured for this agent.

         Args:
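`_get_model` resolves the optional per-run `model` argument against the agent's default, and the resulting `model_used.model_name` property is what gets logged for the run. A small sketch of the per-run override, using the bundled test model so no provider key is needed (the prompt string is arbitrary):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent()                                     # no default model configured
result = agent.run_sync('ping', model=TestModel())  # per-run model, resolved by _get_model
print(result.data)
```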
pydantic_ai/models/__init__.py

@@ -54,6 +54,8 @@ KnownModelName = Literal[
     'google-gla:gemini-2.0-flash-exp',
     'google-gla:gemini-2.0-flash-thinking-exp-01-21',
     'google-gla:gemini-exp-1206',
+    'google-gla:gemini-2.0-flash',
+    'google-gla:gemini-2.0-flash-lite-preview-02-05',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -61,6 +63,8 @@ KnownModelName = Literal[
     'google-vertex:gemini-2.0-flash-exp',
     'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
     'google-vertex:gemini-exp-1206',
+    'google-vertex:gemini-2.0-flash',
+    'google-vertex:gemini-2.0-flash-lite-preview-02-05',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
@@ -112,6 +116,8 @@ KnownModelName = Literal[
     'o1-mini-2024-09-12',
     'o1-preview',
     'o1-preview-2024-09-12',
+    'o3-mini',
+    'o3-mini-2025-01-31',
     'openai:chatgpt-4o-latest',
     'openai:gpt-3.5-turbo',
     'openai:gpt-3.5-turbo-0125',
@@ -149,6 +155,8 @@ KnownModelName = Literal[
     'openai:o1-mini-2024-09-12',
     'openai:o1-preview',
     'openai:o1-preview-2024-09-12',
+    'openai:o3-mini',
+    'openai:o3-mini-2025-01-31',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
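The newly listed names can be passed straight to `Agent`; a minimal sketch (assumes the `openai` extra is installed and `OPENAI_API_KEY` is set):

```python
from pydantic_ai import Agent

agent = Agent('openai:o3-mini')  # one of the names added above
result = agent.run_sync('What is 2 + 2?')
print(result.data)
```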
@@ -157,49 +165,34 @@ KnownModelName = Literal[
 """


-
-
-
-    @abstractmethod
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run.
-
-        This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
-
-        Args:
-            function_tools: The tools available to the agent.
-            allow_text_result: Whether a plain text final response/result is permitted.
-            result_tools: Tool definitions for the final result tool(s), if any.
-
-        Returns:
-            An agent model.
-        """
-        raise NotImplementedError()
+@dataclass
+class ModelRequestParameters:
+    """Configuration for an agent's request to a model, specifically related to tools and result handling."""

-
-
-
+    function_tools: list[ToolDefinition]
+    allow_text_result: bool
+    result_tools: list[ToolDefinition]


-class
-"""
+class Model(ABC):
+    """Abstract class for a model."""

     @abstractmethod
     async def request(
-        self,
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()

     @asynccontextmanager
     async def request_stream(
-        self,
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a request to the model and return a streaming response."""
         # This method is not required, but you need to implement it if you want to support streamed responses
@@ -208,15 +201,26 @@ class AgentModel(ABC):
         # noinspection PyUnreachableCode
         yield  # pragma: no cover

+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        """The model name."""
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def system(self) -> str | None:
+        """The system / model provider, ex: openai."""
+        raise NotImplementedError()
+

 @dataclass
 class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""

-    _model_name: str
-    _usage: Usage = field(default_factory=Usage, init=False)
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
+    _usage: Usage = field(default_factory=Usage, init=False)

     def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
         """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
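In short, the old two-stage `Model.agent_model(...) -> AgentModel` step is gone: a model now implements a single `request` method that receives the `ModelRequestParameters` object, plus `model_name` and `system` properties. A rough sketch of a custom model under those assumptions (the `EchoModel` name and its canned reply are made up for illustration):

```python
from __future__ import annotations

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage


class EchoModel(Model):
    """Hypothetical model that always returns a fixed reply; no streaming support."""

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        # A real implementation would call its backend here and honour the tool
        # definitions carried in model_request_parameters.
        return ModelResponse(parts=[TextPart(content='echo')], model_name=self.model_name), Usage()

    @property
    def model_name(self) -> str:
        return 'echo'

    @property
    def system(self) -> str | None:
        return None
```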
@@ -238,17 +242,20 @@ class StreamedResponse(ABC):
     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
         return ModelResponse(
-            parts=self._parts_manager.get_parts(), model_name=self.
+            parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp
         )

-    def model_name(self) -> str:
-        """Get the model name of the response."""
-        return self._model_name
-
     def usage(self) -> Usage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage

+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        raise NotImplementedError()
+
+    @property
     @abstractmethod
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -270,7 +277,7 @@ def check_allow_model_requests() -> None:
     """Check if model requests are allowed.

     If you're defining your own models that have costs or latency associated with their use, you should call this in
-    [`Model.
+    [`Model.request`][pydantic_ai.models.Model.request] and [`Model.request_stream`][pydantic_ai.models.Model.request_stream].

     Raises:
         RuntimeError: If model requests are not allowed.
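`check_allow_model_requests`, now documented against `Model.request` and `Model.request_stream`, raises `RuntimeError` when real requests are disabled. A sketch of the typical test-time toggle, assuming the module-level `ALLOW_MODEL_REQUESTS` flag that the function checks:

```python
from pydantic_ai import models

models.ALLOW_MODEL_REQUESTS = False  # assumed flag read by check_allow_model_requests
try:
    models.check_allow_model_requests()
except RuntimeError as exc:
    print(exc)  # real model requests are blocked, e.g. in a test suite
```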
@@ -311,33 +318,33 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .openai import OpenAIModel

         return OpenAIModel(model[7:])
-    elif model.startswith(('gpt', 'o1')):
+    elif model.startswith(('gpt', 'o1', 'o3')):
         from .openai import OpenAIModel

         return OpenAIModel(model)
     elif model.startswith('google-gla'):
         from .gemini import GeminiModel

-        return GeminiModel(model[11:])
+        return GeminiModel(model[11:])
     # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
     elif model.startswith('gemini'):
         from .gemini import GeminiModel

         # noinspection PyTypeChecker
-        return GeminiModel(model)
+        return GeminiModel(model)
     elif model.startswith('groq:'):
         from .groq import GroqModel

-        return GroqModel(model[5:])
+        return GroqModel(model[5:])
     elif model.startswith('google-vertex'):
         from .vertexai import VertexAIModel

-        return VertexAIModel(model[14:])
+        return VertexAIModel(model[14:])
     # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
     elif model.startswith('vertexai:'):
         from .vertexai import VertexAIModel

-        return VertexAIModel(model[9:])
+        return VertexAIModel(model[9:])
     elif model.startswith('mistral:'):
         from .mistral import MistralModel

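With `'o3'` added to the prefix check, bare `o3-mini` names now route to `OpenAIModel`; a small sketch (assumes the `openai` extra and `OPENAI_API_KEY`, since constructing the model creates an OpenAI client):

```python
from pydantic_ai.models import infer_model

model = infer_model('o3-mini')
print(type(model).__name__, model.model_name)  # expected: OpenAIModel o3-mini
```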
pydantic_ai/models/anthropic.py

@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -68,14 +68,14 @@ LatestAnthropicModelNames = Literal[
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
 ]
-"""Latest
+"""Latest Anthropic models."""

 AnthropicModelName = Union[str, LatestAnthropicModelNames]
 """Possible Anthropic model names.

 Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
 allow any name in the type hints.
-
+See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
 """


@@ -101,9 +101,11 @@ class AnthropicModel(Model):
     We anticipate adding support for streaming responses in a near-term future release.
     """

-    model_name: AnthropicModelName
     client: AsyncAnthropic = field(repr=False)

+    _model_name: AnthropicModelName = field(repr=False)
+    _system: str | None = field(default='anthropic', repr=False)
+
     def __init__(
         self,
         model_name: AnthropicModelName,
@@ -124,7 +126,7 @@ class AnthropicModel(Model):
                 client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.
+        self._model_name = model_name
         if anthropic_client is not None:
             assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
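The new private fields back `model_name` and `system` properties on `AnthropicModel` (shown in the next hunk); a minimal construction sketch (assumes the `anthropic` extra is installed and `ANTHROPIC_API_KEY` is set, since the constructor creates an Anthropic client):

```python
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel

model = AnthropicModel('claude-3-5-sonnet-latest')
print(model.model_name, model.system)  # expected: claude-3-5-sonnet-latest anthropic

agent = Agent(model)  # the agent now calls model.request(...) directly, per the hunks below
```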
@@ -134,81 +136,77 @@ class AnthropicModel(Model):
         else:
             self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

-    async def
+    async def request(
         self,
-
-
-
-
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-
-
-        tools += [self._map_tool_definition(r) for r in result_tools]
-        return AnthropicAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._messages_create(
+            messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'anthropic:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
-        return {
-            'name': f.name,
-            'description': f.description,
-            'input_schema': f.parameters_json_schema,
-        }
-
-
-@dataclass
-class AnthropicAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Anthropic models."""
-
-    client: AsyncAnthropic
-    model_name: AnthropicModelName
-    allow_text_result: bool
-    tools: list[ToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self,
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-
+        check_allow_model_requests()
+        response = await self._messages_create(
+            messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)

+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _messages_create(
-        self,
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass

     @overload
     async def _messages_create(
-        self,
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AnthropicMessage:
         pass

     async def _messages_create(
-        self,
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        tools = self._get_tools(model_request_parameters)
         tool_choice: ToolChoiceParam | None

-        if not
+        if not tools:
             tool_choice = None
         else:
-            if not
+            if not model_request_parameters.allow_text_result:
                 tool_choice = {'type': 'any'}
             else:
                 tool_choice = {'type': 'auto'}
@@ -222,8 +220,8 @@ class AnthropicAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', 1024),
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
-            model=self.
-            tools=
+            model=self._model_name,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             temperature=model_settings.get('temperature', NOT_GIVEN),
@@ -248,7 +246,7 @@ class AnthropicAgentModel(AgentModel):
                 )
             )

-        return ModelResponse(items, model_name=
+        return ModelResponse(items, model_name=response.model)

     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -258,10 +256,17 @@ class AnthropicAgentModel(AgentModel):

         # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
         timestamp = datetime.now(tz=timezone.utc)
-        return AnthropicStreamedResponse(
+        return AnthropicStreamedResponse(
+            _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
+        )

-
-
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
         anthropic_messages: list[MessageParam] = []
@@ -310,20 +315,28 @@ class AnthropicAgentModel(AgentModel):
                     content.append(TextBlockParam(text=item.content, type='text'))
                 else:
                     assert isinstance(item, ToolCallPart)
-                    content.append(_map_tool_call(item))
+                    content.append(self._map_tool_call(item))
                 anthropic_messages.append(MessageParam(role='assistant', content=content))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages

+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+        return ToolUseBlockParam(
+            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+            type='tool_use',
+            name=t.tool_name,
+            input=t.args_as_dict(),
+        )

-
-
-
-
-
-
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }


 def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
@@ -359,6 +372,7 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
 class AnthropicStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Anthropic models."""

+    _model_name: AnthropicModelName
     _response: AsyncIterable[RawMessageStreamEvent]
     _timestamp: datetime

@@ -411,5 +425,12 @@ class AnthropicStreamedResponse(StreamedResponse):
         elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
             current_block = None

+    @property
+    def model_name(self) -> AnthropicModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp