pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -0,0 +1,344 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from typing import Any, Literal, Union, cast, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ArgsDict,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    EitherStreamedResponse,
+    Model,
+    cached_async_http_client,
+    check_allow_model_requests,
+)
+
+try:
+    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic.types import (
+        Message as AnthropicMessage,
+        MessageParam,
+        RawMessageDeltaEvent,
+        RawMessageStartEvent,
+        RawMessageStreamEvent,
+        TextBlock,
+        TextBlockParam,
+        ToolChoiceParam,
+        ToolParam,
+        ToolResultBlockParam,
+        ToolUseBlock,
+        ToolUseBlockParam,
+    )
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `anthropic` to use the Anthropic model, '
+        "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
+    ) from _import_error
+
+LatestAnthropicModelNames = Literal[
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+]
+"""Latest named Anthropic models."""
+
+AnthropicModelName = Union[str, LatestAnthropicModelNames]
+"""Possible Anthropic model names.
+
+Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+"""
+
+
+@dataclass(init=False)
+class AnthropicModel(Model):
+    """A model that uses the Anthropic API.
+
+    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+
+    !!! note
+        The `AnthropicModel` class does not yet support streaming responses.
+        We anticipate adding support for streaming responses in a near-term future release.
+    """
+
+    model_name: AnthropicModelName
+    client: AsyncAnthropic = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: AnthropicModelName,
+        *,
+        api_key: str | None = None,
+        anthropic_client: AsyncAnthropic | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Anthropic model.
+
+        Args:
+            model_name: The name of the Anthropic model to use. List of model names available
+                [here](https://docs.anthropic.com/en/docs/about-claude/models).
+            api_key: The API key to use for authentication; if not provided, the `ANTHROPIC_API_KEY` environment variable
+                will be used if available.
+            anthropic_client: An existing
+                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
+                client to use; if provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if anthropic_client is not None:
+            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
+            self.client = anthropic_client
+        elif http_client is not None:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
+        else:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return AnthropicAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return self.model_name
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
+
+
+@dataclass
+class AnthropicAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Anthropic models."""
+
+    client: AsyncAnthropic
+    model_name: str
+    allow_text_result: bool
+    tools: list[ToolParam]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._messages_create(messages, False, model_settings)
+        return self._process_response(response), _map_cost(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._messages_create(messages, True, model_settings)
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+    ) -> AsyncStream[RawMessageStreamEvent]:
+        pass
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> AnthropicMessage:
+        pass
+
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
+        # standalone function to make it easier to override
+        if not self.tools:
+            tool_choice: ToolChoiceParam | None = None
+        elif not self.allow_text_result:
+            tool_choice = {'type': 'any'}
+        else:
+            tool_choice = {'type': 'auto'}
+
+        system_prompt, anthropic_messages = self._map_message(messages)
+
+        model_settings = model_settings or {}
+
+        return await self.client.messages.create(
+            max_tokens=model_settings.get('max_tokens', 1024),
+            system=system_prompt or NOT_GIVEN,
+            messages=anthropic_messages,
+            model=self.model_name,
+            tools=self.tools or NOT_GIVEN,
+            tool_choice=tool_choice or NOT_GIVEN,
+            stream=stream,
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
+        )
+
+    @staticmethod
+    def _process_response(response: AnthropicMessage) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        items: list[ModelResponsePart] = []
+        for item in response.content:
+            if isinstance(item, TextBlock):
+                items.append(TextPart(item.text))
+            else:
+                assert isinstance(item, ToolUseBlock), 'unexpected item type'
+                items.append(
+                    ToolCallPart.from_dict(
+                        item.name,
+                        cast(dict[str, Any], item.input),
+                        item.id,
+                    )
+                )
+
+        return ModelResponse(items)
+
+    @staticmethod
+    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+        """TODO: Process a streamed response, and prepare a streaming response to return."""
+        # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
+        # Streamed responses will be supported in a future release.
+
+        raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
+
+        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
+        # RawMessageStartEvent
+        # RawMessageDeltaEvent
+        # RawMessageStopEvent
+        # RawContentBlockStartEvent
+        # RawContentBlockDeltaEvent
+        # RawContentBlockStopEvent
+        #
+        # We might refactor streaming internally before we implement this...
+
+    @staticmethod
+    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+        """Just maps a `pydantic_ai.Message` to an `anthropic.types.MessageParam`."""
+        system_prompt: str = ''
+        anthropic_messages: list[MessageParam] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        anthropic_messages.append(MessageParam(role='user', content=part.content))
+                    elif isinstance(part, ToolReturnPart):
+                        anthropic_messages.append(
+                            MessageParam(
+                                role='user',
+                                content=[
+                                    ToolResultBlockParam(
+                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                        type='tool_result',
+                                        content=part.model_response_str(),
+                                        is_error=False,
+                                    )
+                                ],
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        if part.tool_name is None:
+                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        else:
+                            anthropic_messages.append(
+                                MessageParam(
+                                    role='user',
+                                    content=[
+                                        ToolUseBlockParam(
+                                            id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                            input=part.model_response(),
+                                            name=part.tool_name,
+                                            type='tool_use',
+                                        ),
+                                    ],
+                                )
+                            )
+            elif isinstance(m, ModelResponse):
+                content: list[TextBlockParam | ToolUseBlockParam] = []
+                for item in m.parts:
+                    if isinstance(item, TextPart):
+                        content.append(TextBlockParam(text=item.content, type='text'))
+                    else:
+                        assert isinstance(item, ToolCallPart)
+                        content.append(_map_tool_call(item))
+                anthropic_messages.append(MessageParam(role='assistant', content=content))
+            else:
+                assert_never(m)
+        return system_prompt, anthropic_messages
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
+    return ToolUseBlockParam(
+        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+        type='tool_use',
+        name=t.tool_name,
+        input=t.args.args_dict,
+    )
+
+
+def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+    if isinstance(message, AnthropicMessage):
+        usage = message.usage
+    else:
+        if isinstance(message, RawMessageStartEvent):
+            usage = message.message.usage
+        elif isinstance(message, RawMessageDeltaEvent):
+            usage = message.usage
+        else:
+            # No usage information provided in:
+            # - RawMessageStopEvent
+            # - RawContentBlockStartEvent
+            # - RawContentBlockDeltaEvent
+            # - RawContentBlockStopEvent
+            usage = None
+
+    if usage is None:
+        return result.Cost()
+
+    request_tokens = getattr(usage, 'input_tokens', None)
+
+    return result.Cost(
+        # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
+        request_tokens=request_tokens,
+        response_tokens=usage.output_tokens,
+        total_tokens=(request_tokens or 0) + usage.output_tokens,
+    )
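
The module added above is, judging by its imports and `Model` base class, the new `pydantic_ai/models/anthropic.py`. For orientation, here is a minimal usage sketch of the new class; it assumes pydantic-ai 0.0.13 with the `anthropic` extra installed and `ANTHROPIC_API_KEY` set, and the `Agent` API is taken from the library's public docs rather than from this diff:

```python
# Minimal sketch, not part of the release: exercises AnthropicModel as defined
# in the diff above. Agent, run_sync, and result.data are assumed from the
# pydantic-ai 0.0.13 public API; the model name is one of the Literal values above.
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel

# With no explicit api_key/anthropic_client/http_client, __init__ falls back to
# a cached httpx client and the ANTHROPIC_API_KEY environment variable.
model = AnthropicModel('claude-3-5-sonnet-latest')
agent = Agent(model)

result = agent.run_sync('Tell me a one-line joke.')
print(result.data)
```

Note that only non-streamed requests work here: `_process_streamed_response` raises `RuntimeError`, matching the class docstring's note that streaming is not yet supported.
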
@@ -2,9 +2,9 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
+from collections.abc import AsyncIterator, Awaitable, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast
@@ -13,15 +13,22 @@ import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload
 
 from .. import _utils, result
-from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
-from . import (
-    AbstractToolDefinition,
-    AgentModel,
-    EitherStreamedResponse,
-    Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+from ..messages import (
+    ArgsJson,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
 
 
 @dataclass(init=False)
@@ -59,13 +66,13 @@ class FunctionModel(Model):
 
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        result_tools = list(result_tools) if result_tools is not None else None
         return FunctionAgentModel(
-            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
+            self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
         )
 
     def name(self) -> str:
@@ -84,7 +91,7 @@ class AgentInfo:
     This is passed as the second argument to functions used within [`FunctionModel`][pydantic_ai.models.function.FunctionModel].
     """
 
-    function_tools: Mapping[str, AbstractToolDefinition]
+    function_tools: list[ToolDefinition]
     """The function tools available on this agent.
 
     These are the tools registered via the [`tool`][pydantic_ai.Agent.tool] and
@@ -92,8 +99,10 @@ class AgentInfo:
     """
     allow_text_result: bool
     """Whether a plain text result is allowed."""
-    result_tools: list[AbstractToolDefinition] | None
+    result_tools: list[ToolDefinition]
     """The tools that can be called as the final result of the run."""
+    model_settings: ModelSettings | None
+    """The model settings passed to the run call."""
 
 
 @dataclass
@@ -112,10 +121,10 @@ class DeltaToolCall:
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
-FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]]
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
-StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
 """A function used to generate a streamed response.
 
 While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
@@ -133,18 +142,25 @@ class FunctionAgentModel(AgentModel):
     stream_function: StreamFunctionDef | None
     agent_info: AgentInfo
 
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        agent_info = replace(self.agent_info, model_settings=model_settings)
+
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
         if inspect.iscoroutinefunction(self.function):
-            response = await self.function(messages, self.agent_info)
+            response = await self.function(messages, agent_info)
         else:
-            response_ = await _utils.run_in_executor(self.function, messages, self.agent_info)
-            response = cast(ModelAnyResponse, response_)
+            response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+            assert isinstance(response_, ModelResponse), response_
+            response = response_
         # TODO is `messages` right here? Should it just be new messages?
         return response, _estimate_cost(chain(messages, [response]))
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         assert (
             self.stream_function is not None
         ), 'FunctionModel must receive a `stream_function` to support streamed requests'
@@ -212,13 +228,13 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
             else:
                 self._delta_tool_calls[key] = new
 
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        calls: list[ToolCall] = []
+    def get(self, *, final: bool = False) -> ModelResponse:
+        calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(ToolCall.from_json(c.name, c.json_args))
+                calls.append(ToolCallPart.from_json(c.name, c.json_args))
 
-        return ModelStructuredResponse(calls, timestamp=self._timestamp)
+        return ModelResponse(calls, timestamp=self._timestamp)
 
     def cost(self) -> result.Cost:
         return result.Cost()
@@ -227,32 +243,38 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def _estimate_cost(messages: Iterable[Message]) -> result.Cost:
+def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
     """Very rough guesstimate of the number of tokens associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
     """
     # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-
     request_tokens = 50
     response_tokens = 0
     for message in messages:
-        if message.role == 'system' or message.role == 'user':
-            request_tokens += _string_cost(message.content)
-        elif message.role == 'tool-return':
-            request_tokens += _string_cost(message.model_response_str())
-        elif message.role == 'retry-prompt':
-            request_tokens += _string_cost(message.model_response())
-        elif message.role == 'model-text-response':
-            response_tokens += _string_cost(message.content)
-        elif message.role == 'model-structured-response':
-            for call in message.calls:
-                if isinstance(call.args, ArgsJson):
-                    args_str = call.args.args_json
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                    request_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolReturnPart):
+                    request_tokens += _string_cost(part.model_response_str())
+                elif isinstance(part, RetryPromptPart):
+                    request_tokens += _string_cost(part.model_response())
                 else:
-                    args_str = pydantic_core.to_json(call.args.args_dict).decode()
-
-                response_tokens += 1 + _string_cost(args_str)
+                    assert_never(part)
+        elif isinstance(message, ModelResponse):
+            for part in message.parts:
+                if isinstance(part, TextPart):
+                    response_tokens += _string_cost(part.content)
+                elif isinstance(part, ToolCallPart):
+                    call = part
+                    if isinstance(call.args, ArgsJson):
+                        args_str = call.args.args_json
+                    else:
+                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
+                    response_tokens += 1 + _string_cost(args_str)
+                else:
+                    assert_never(part)
         else:
             assert_never(message)
     return result.Cost(
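
The `pydantic_ai/models/function.py` changes above migrate `FunctionModel` to the new `ModelMessage`/`ModelResponse` part-based message types and thread `model_settings` from each run into `AgentInfo` via `dataclasses.replace`. Here is a sketch of a test double against the updated signatures; the part constructors mirror the diff, while `Agent` and `run_sync` are assumed from the library's public API:

```python
# Hypothetical test double, not part of the release: a FunctionModel whose
# function matches the new FunctionDef signature from the diff above.
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def always_pong(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # As of this release, info.model_settings carries the settings passed to
    # the run call (None here, since none are passed below).
    return ModelResponse([TextPart('pong')])


agent = Agent(FunctionModel(always_pong))
print(agent.run_sync('ping').data)  # -> pong
```
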