pydantic-ai-slim 0.0.22__tar.gz → 0.0.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (31)
  1. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/PKG-INFO +3 -3
  2. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_agent_graph.py +12 -8
  3. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/agent.py +5 -5
  4. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/__init__.py +52 -45
  5. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/anthropic.py +87 -66
  6. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/cohere.py +65 -67
  7. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/function.py +76 -60
  8. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/gemini.py +153 -99
  9. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/groq.py +97 -72
  10. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/mistral.py +90 -71
  11. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/openai.py +110 -71
  12. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/test.py +99 -94
  13. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/vertexai.py +48 -44
  14. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/result.py +2 -2
  15. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pyproject.toml +3 -3
  16. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/.gitignore +0 -0
  17. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/README.md +0 -0
  18. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/__init__.py +0 -0
  19. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_griffe.py +0 -0
  20. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_parts_manager.py +0 -0
  21. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_pydantic.py +0 -0
  22. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_result.py +0 -0
  23. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_system_prompt.py +0 -0
  24. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_utils.py +0 -0
  25. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/exceptions.py +0 -0
  26. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/format_as_xml.py +0 -0
  27. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/messages.py +0 -0
  28. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/py.typed +0 -0
  29. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/settings.py +0 -0
  30. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/tools.py +0 -0
  31. {pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/usage.py +0 -0

{pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.22
+Version: 0.0.24
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -28,7 +28,7 @@ Requires-Dist: eval-type-backport>=0.2.0
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.22
+Requires-Dist: pydantic-graph==0.0.24
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
@@ -41,7 +41,7 @@ Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.59.0; extra == 'openai'
+Requires-Dist: openai>=1.61.0; extra == 'openai'
 Provides-Extra: vertexai
 Requires-Dist: google-auth>=2.36.0; extra == 'vertexai'
 Requires-Dist: requests>=2.32.3; extra == 'vertexai'

{pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/_agent_graph.py

@@ -204,9 +204,9 @@ class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
         return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
 
 
-async def _prepare_model(
+async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> models.AgentModel:
+) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []
 
@@ -220,7 +220,7 @@ async def _prepare_model(
     await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
 
     result_schema = ctx.deps.result_schema
-    return await run_context.model.agent_model(
+    return models.ModelRequestParameters(
        function_tools=function_tool_defs,
        allow_text_result=_allow_text_result(result_schema),
        result_tools=result_schema.tool_defs() if result_schema is not None else [],
@@ -245,13 +245,15 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
         # Increment run_step
         ctx.state.run_step += 1
 
-        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            agent_model = await _prepare_model(ctx)
+        with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
+            model_request_parameters = await _prepare_request_parameters(ctx)
 
         # Actually make the model request
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('model request') as span:
-            model_response, request_usage = await agent_model.request(ctx.state.message_history, model_settings)
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
             span.set_attribute('response', model_response)
             span.set_attribute('usage', request_usage)
 
@@ -405,12 +407,14 @@ class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any
         ctx.state.run_step += 1
 
         with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            agent_model = await _prepare_model(ctx)
+            model_request_parameters = await _prepare_request_parameters(ctx)
 
         # Actually make the model request
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with agent_model.request_stream(ctx.state.message_history, model_settings) as streamed_response:
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
                 ctx.state.usage.requests += 1
                 model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
                 # We want to end the "model request" span here, but we can't exit the context manager
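
Taken together, these hunks replace the old two-phase flow (build an `AgentModel`, then call it) with a single direct call: the graph node assembles a `ModelRequestParameters` value and passes it to `Model.request` alongside the message history and settings. A minimal sketch of the new call shape, using the built-in `TestModel` so nothing hits the network (the exact constructor defaults are assumptions, not taken from this diff):

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.test import TestModel


async def main():
    model = TestModel()
    # The same three fields the graph code above assembles in _prepare_request_parameters:
    params = ModelRequestParameters(
        function_tools=[],       # ToolDefinitions for the agent's function tools
        allow_text_result=True,  # whether a plain-text final result is permitted
        result_tools=[],         # tool definitions for structured results
    )
    messages = [ModelRequest(parts=[UserPromptPart(content='hello')])]
    # No intermediate AgentModel any more; the Model is called directly.
    response, usage = await model.request(messages, None, params)
    print(response, usage)


asyncio.run(main())
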

{pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/agent.py

@@ -275,7 +275,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used = await self._get_model(model)
+        model_used = self._get_model(model)
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -309,7 +309,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            model_name=model_used.name() if model_used else 'no-model',
+            model_name=model_used.model_name if model_used else 'no-model',
             agent_name=self.name or 'agent',
         ) as run_span:
             # Build the deps object for the graph
@@ -520,7 +520,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used = await self._get_model(model)
+        model_used = self._get_model(model)
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -554,7 +554,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            model_name=model_used.name(),
+            model_name=model_used.model_name if model_used else 'no-model',
             agent_name=self.name or 'agent',
         ) as run_span:
             # Build the deps object for the graph
@@ -971,7 +971,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
         self._function_tools[tool.name] = tool
 
-    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
+    def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
         """Create a model configured for this agent.
 
         Args:
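
Note the knock-on change for callers: `name()` was a method on the old `Model`, while `model_name` is a property on the new one, so call sites drop the parentheses. A small illustrative sketch (assuming `TestModel` implements the new property, as the `models/test.py` changes in this diff suggest):

from pydantic_ai.models.test import TestModel

model = TestModel()
# 0.0.22: model.name()       -- method call
# 0.0.24: model.model_name   -- property access, no parentheses
print(model.model_name)
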

{pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/__init__.py

@@ -54,6 +54,8 @@ KnownModelName = Literal[
     'google-gla:gemini-2.0-flash-exp',
     'google-gla:gemini-2.0-flash-thinking-exp-01-21',
     'google-gla:gemini-exp-1206',
+    'google-gla:gemini-2.0-flash',
+    'google-gla:gemini-2.0-flash-lite-preview-02-05',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
@@ -61,6 +63,8 @@ KnownModelName = Literal[
     'google-vertex:gemini-2.0-flash-exp',
     'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
     'google-vertex:gemini-exp-1206',
+    'google-vertex:gemini-2.0-flash',
+    'google-vertex:gemini-2.0-flash-lite-preview-02-05',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
@@ -112,6 +116,8 @@ KnownModelName = Literal[
     'o1-mini-2024-09-12',
     'o1-preview',
     'o1-preview-2024-09-12',
+    'o3-mini',
+    'o3-mini-2025-01-31',
     'openai:chatgpt-4o-latest',
     'openai:gpt-3.5-turbo',
     'openai:gpt-3.5-turbo-0125',
@@ -149,6 +155,8 @@ KnownModelName = Literal[
     'openai:o1-mini-2024-09-12',
     'openai:o1-preview',
     'openai:o1-preview-2024-09-12',
+    'openai:o3-mini',
+    'openai:o3-mini-2025-01-31',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -157,49 +165,34 @@ KnownModelName = Literal[
 """
 
 
-class Model(ABC):
-    """Abstract class for a model."""
-
-    @abstractmethod
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run.
-
-        This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
-
-        Args:
-            function_tools: The tools available to the agent.
-            allow_text_result: Whether a plain text final response/result is permitted.
-            result_tools: Tool definitions for the final result tool(s), if any.
-
-        Returns:
-            An agent model.
-        """
-        raise NotImplementedError()
+@dataclass
+class ModelRequestParameters:
+    """Configuration for an agent's request to a model, specifically related to tools and result handling."""
 
-    @abstractmethod
-    def name(self) -> str:
-        raise NotImplementedError()
+    function_tools: list[ToolDefinition]
+    allow_text_result: bool
+    result_tools: list[ToolDefinition]
 
 
-class AgentModel(ABC):
-    """Model configured for each step of an Agent run."""
+class Model(ABC):
+    """Abstract class for a model."""
 
     @abstractmethod
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a request to the model and return a streaming response."""
         # This method is not required, but you need to implement it if you want to support streamed responses
@@ -208,15 +201,26 @@ class AgentModel(ABC):
         # noinspection PyUnreachableCode
         yield  # pragma: no cover
 
+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        """The model name."""
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def system(self) -> str | None:
+        """The system / model provider, ex: openai."""
+        raise NotImplementedError()
+
 
 @dataclass
 class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""
 
-    _model_name: str
-    _usage: Usage = field(default_factory=Usage, init=False)
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
+    _usage: Usage = field(default_factory=Usage, init=False)
 
     def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
         """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
@@ -238,17 +242,20 @@ class StreamedResponse(ABC):
     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
         return ModelResponse(
-            parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
+            parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp
         )
 
-    def model_name(self) -> str:
-        """Get the model name of the response."""
-        return self._model_name
-
     def usage(self) -> Usage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage
 
+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        raise NotImplementedError()
+
+    @property
     @abstractmethod
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -270,7 +277,7 @@ def check_allow_model_requests() -> None:
     """Check if model requests are allowed.
 
     If you're defining your own models that have costs or latency associated with their use, you should call this in
-    [`Model.agent_model`][pydantic_ai.models.Model.agent_model].
+    [`Model.request`][pydantic_ai.models.Model.request] and [`Model.request_stream`][pydantic_ai.models.Model.request_stream].
 
     Raises:
         RuntimeError: If model requests are not allowed.
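
Putting the new surface together: a third-party backend now implements a single `Model` subclass instead of a `Model`/`AgentModel` pair, calls `check_allow_model_requests()` inside `request`, and exposes `model_name`/`system` as properties. A minimal sketch against the 0.0.24 ABC; the echo behaviour is invented for illustration:

from __future__ import annotations

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import Model, ModelRequestParameters, check_allow_model_requests
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage


class EchoModel(Model):
    """Toy model: replies with the number of messages it was sent."""

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        check_allow_model_requests()  # per the updated docstring above
        part = TextPart(content=f'saw {len(messages)} messages')
        return ModelResponse(parts=[part], model_name=self.model_name), Usage()

    @property
    def model_name(self) -> str:
        return 'echo'

    @property
    def system(self) -> str | None:
        return None
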
@@ -311,33 +318,33 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .openai import OpenAIModel
 
         return OpenAIModel(model[7:])
-    elif model.startswith(('gpt', 'o1')):
+    elif model.startswith(('gpt', 'o1', 'o3')):
         from .openai import OpenAIModel
 
         return OpenAIModel(model)
     elif model.startswith('google-gla'):
         from .gemini import GeminiModel
 
-        return GeminiModel(model[11:])  # pyright: ignore[reportArgumentType]
+        return GeminiModel(model[11:])
     # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
     elif model.startswith('gemini'):
         from .gemini import GeminiModel
 
         # noinspection PyTypeChecker
-        return GeminiModel(model)  # pyright: ignore[reportArgumentType]
+        return GeminiModel(model)
     elif model.startswith('groq:'):
         from .groq import GroqModel
 
-        return GroqModel(model[5:])  # pyright: ignore[reportArgumentType]
+        return GroqModel(model[5:])
     elif model.startswith('google-vertex'):
         from .vertexai import VertexAIModel
 
-        return VertexAIModel(model[14:])  # pyright: ignore[reportArgumentType]
+        return VertexAIModel(model[14:])
     # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
     elif model.startswith('vertexai:'):
         from .vertexai import VertexAIModel
 
-        return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+        return VertexAIModel(model[9:])
     elif model.startswith('mistral:'):
         from .mistral import MistralModel
 
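With the widened prefix check, bare `o3-*` names now route to `OpenAIModel` exactly like `gpt-*` and `o1-*`; the dropped `pyright: ignore` comments reflect that the backend constructors now accept plain strings rather than closed literal types. For example (assuming the `openai` extra is installed):

from pydantic_ai.models import infer_model

model = infer_model('o3-mini')        # now routed to OpenAIModel
same = infer_model('openai:o3-mini')  # prefixed form, same backend
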
{pydantic_ai_slim-0.0.22 → pydantic_ai_slim-0.0.24}/pydantic_ai/models/anthropic.py

@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
@@ -68,14 +68,14 @@ LatestAnthropicModelNames = Literal[
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
 ]
-"""Latest named Anthropic models."""
+"""Latest Anthropic models."""
 
 AnthropicModelName = Union[str, LatestAnthropicModelNames]
 """Possible Anthropic model names.
 
 Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
 allow any name in the type hints.
-Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
 """
 
 
@@ -101,9 +101,11 @@ class AnthropicModel(Model):
     We anticipate adding support for streaming responses in a near-term future release.
     """
 
-    model_name: AnthropicModelName
     client: AsyncAnthropic = field(repr=False)
 
+    _model_name: AnthropicModelName = field(repr=False)
+    _system: str | None = field(default='anthropic', repr=False)
+
     def __init__(
         self,
         model_name: AnthropicModelName,
@@ -124,7 +126,7 @@ class AnthropicModel(Model):
             client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if anthropic_client is not None:
             assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
@@ -134,81 +136,77 @@ class AnthropicModel(Model):
         else:
             self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
 
-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return AnthropicAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._messages_create(
+            messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'anthropic:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
-        return {
-            'name': f.name,
-            'description': f.description,
-            'input_schema': f.parameters_json_schema,
-        }
-
-
-@dataclass
-class AnthropicAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Anthropic models."""
-
-    client: AsyncAnthropic
-    model_name: AnthropicModelName
-    allow_text_result: bool
-    tools: list[ToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._messages_create(
+            messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)
 
+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider."""
+        return self._system
+
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AnthropicMessage:
         pass
 
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        tools = self._get_tools(model_request_parameters)
         tool_choice: ToolChoiceParam | None
 
-        if not self.tools:
+        if not tools:
             tool_choice = None
         else:
-            if not self.allow_text_result:
+            if not model_request_parameters.allow_text_result:
                 tool_choice = {'type': 'any'}
             else:
                 tool_choice = {'type': 'auto'}
@@ -222,8 +220,8 @@ class AnthropicAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', 1024),
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
-            model=self.model_name,
-            tools=self.tools or NOT_GIVEN,
+            model=self._model_name,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             temperature=model_settings.get('temperature', NOT_GIVEN),
@@ -248,7 +246,7 @@ class AnthropicAgentModel(AgentModel):
             )
         )
 
-        return ModelResponse(items, model_name=self.model_name)
+        return ModelResponse(items, model_name=response.model)
 
     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -258,10 +256,17 @@ class AnthropicAgentModel(AgentModel):
 
         # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
         timestamp = datetime.now(tz=timezone.utc)
-        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
+        return AnthropicStreamedResponse(
+            _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
+        )
 
-    @staticmethod
-    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
         anthropic_messages: list[MessageParam] = []
@@ -310,20 +315,28 @@ class AnthropicAgentModel(AgentModel):
                     content.append(TextBlockParam(text=item.content, type='text'))
                 else:
                     assert isinstance(item, ToolCallPart)
-                    content.append(_map_tool_call(item))
+                    content.append(self._map_tool_call(item))
                 anthropic_messages.append(MessageParam(role='assistant', content=content))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages
 
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+        return ToolUseBlockParam(
+            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+            type='tool_use',
+            name=t.tool_name,
+            input=t.args_as_dict(),
+        )
 
-def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    return ToolUseBlockParam(
-        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-        type='tool_use',
-        name=t.tool_name,
-        input=t.args_as_dict(),
-    )
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
 
 
 def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
@@ -359,6 +372,7 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
 class AnthropicStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for Anthropic models."""
 
+    _model_name: AnthropicModelName
     _response: AsyncIterable[RawMessageStreamEvent]
     _timestamp: datetime
 
@@ -411,5 +425,12 @@ class AnthropicStreamedResponse(StreamedResponse):
         elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
             current_block = None
 
+    @property
+    def model_name(self) -> AnthropicModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
     def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
         return self._timestamp
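
End to end, the Anthropic backend is now driven like every other backend in this diff: construct `AnthropicModel`, then pass `ModelRequestParameters` straight into `request`. A hedged sketch (the model name, key handling, and prompt are placeholders):

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.anthropic import AnthropicModel


async def main():
    model = AnthropicModel('claude-3-5-sonnet-latest', api_key='...')  # placeholder key
    params = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello!')])]
    response, usage = await model.request(messages, None, params)
    print(response.parts, usage)


asyncio.run(main())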