pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the packages exactly as they appear in their public registries.

Potentially problematic release: this version of pydantic-ai-slim has been flagged by the registry.

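In summary, this release adds a new Anthropic integration as `pydantic_ai/models/anthropic.py` (the first hunk below) and updates `pydantic_ai/models/function.py` for a release-wide refactor: `result.Cost` is renamed to `result.Usage`, the role-tagged `Message` types are replaced by `ModelRequest`/`ModelResponse` with typed parts, and per-run `ModelSettings` are threaded through to models. Per the `ImportError` message in the new module, the Anthropic SDK ships as an optional dependency group: `pip install 'pydantic-ai-slim[anthropic]'`.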
pydantic_ai/models/anthropic.py (new file)

@@ -0,0 +1,344 @@
+ from __future__ import annotations as _annotations
+
+ from collections.abc import AsyncIterator
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from typing import Any, Literal, Union, cast, overload
+
+ from httpx import AsyncClient as AsyncHTTPClient
+ from typing_extensions import assert_never
+
+ from .. import result
+ from .._utils import guard_tool_call_id as _guard_tool_call_id
+ from ..messages import (
+     ArgsDict,
+     ModelMessage,
+     ModelRequest,
+     ModelResponse,
+     ModelResponsePart,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+ )
+ from ..settings import ModelSettings
+ from ..tools import ToolDefinition
+ from . import (
+     AgentModel,
+     EitherStreamedResponse,
+     Model,
+     cached_async_http_client,
+     check_allow_model_requests,
+ )
+
+ try:
+     from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+     from anthropic.types import (
+         Message as AnthropicMessage,
+         MessageParam,
+         RawMessageDeltaEvent,
+         RawMessageStartEvent,
+         RawMessageStreamEvent,
+         TextBlock,
+         TextBlockParam,
+         ToolChoiceParam,
+         ToolParam,
+         ToolResultBlockParam,
+         ToolUseBlock,
+         ToolUseBlockParam,
+     )
+ except ImportError as _import_error:
+     raise ImportError(
+         'Please install `anthropic` to use the Anthropic model, '
+         "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
+     ) from _import_error
+
+ LatestAnthropicModelNames = Literal[
+     'claude-3-5-haiku-latest',
+     'claude-3-5-sonnet-latest',
+     'claude-3-opus-latest',
+ ]
+ """Latest named Anthropic models."""
+
+ AnthropicModelName = Union[str, LatestAnthropicModelNames]
+ """Possible Anthropic model names.
+
+ Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
+ allow any name in the type hints.
+ See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+ """
+
+
+ @dataclass(init=False)
+ class AnthropicModel(Model):
+     """A model that uses the Anthropic API.
+
+     Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
+
+     Apart from `__init__`, all methods are private or match those of the base class.
+
+     !!! note
+         The `AnthropicModel` class does not yet support streaming responses.
+         We anticipate adding support for streaming responses in a future release.
+     """
+
+     model_name: AnthropicModelName
+     client: AsyncAnthropic = field(repr=False)
+
+     def __init__(
+         self,
+         model_name: AnthropicModelName,
+         *,
+         api_key: str | None = None,
+         anthropic_client: AsyncAnthropic | None = None,
+         http_client: AsyncHTTPClient | None = None,
+     ):
+         """Initialize an Anthropic model.
+
+         Args:
+             model_name: The name of the Anthropic model to use. List of model names available
+                 [here](https://docs.anthropic.com/en/docs/about-claude/models).
+             api_key: The API key to use for authentication; if not provided, the `ANTHROPIC_API_KEY`
+                 environment variable will be used if available.
+             anthropic_client: An existing
+                 [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
+                 client to use. If provided, `api_key` and `http_client` must be `None`.
+             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+         """
+         self.model_name = model_name
+         if anthropic_client is not None:
+             assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
+             assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
+             self.client = anthropic_client
+         elif http_client is not None:
+             self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
+         else:
+             self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+
+     async def agent_model(
+         self,
+         *,
+         function_tools: list[ToolDefinition],
+         allow_text_result: bool,
+         result_tools: list[ToolDefinition],
+     ) -> AgentModel:
+         check_allow_model_requests()
+         tools = [self._map_tool_definition(r) for r in function_tools]
+         if result_tools:
+             tools += [self._map_tool_definition(r) for r in result_tools]
+         return AnthropicAgentModel(
+             self.client,
+             self.model_name,
+             allow_text_result,
+             tools,
+         )
+
+     def name(self) -> str:
+         return self.model_name
+
+     @staticmethod
+     def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+         return {
+             'name': f.name,
+             'description': f.description,
+             'input_schema': f.parameters_json_schema,
+         }
+
+
+ @dataclass
+ class AnthropicAgentModel(AgentModel):
+     """Implementation of `AgentModel` for Anthropic models."""
+
+     client: AsyncAnthropic
+     model_name: str
+     allow_text_result: bool
+     tools: list[ToolParam]
+
+     async def request(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> tuple[ModelResponse, result.Usage]:
+         response = await self._messages_create(messages, False, model_settings)
+         return self._process_response(response), _map_usage(response)
+
+     @asynccontextmanager
+     async def request_stream(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> AsyncIterator[EitherStreamedResponse]:
+         response = await self._messages_create(messages, True, model_settings)
+         async with response:
+             yield await self._process_streamed_response(response)
+
+     @overload
+     async def _messages_create(
+         self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+     ) -> AsyncStream[RawMessageStreamEvent]:
+         pass
+
+     @overload
+     async def _messages_create(
+         self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+     ) -> AnthropicMessage:
+         pass
+
+     async def _messages_create(
+         self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
+         # standalone function to make it easier to override
+         if not self.tools:
+             tool_choice: ToolChoiceParam | None = None
+         elif not self.allow_text_result:
+             tool_choice = {'type': 'any'}
+         else:
+             tool_choice = {'type': 'auto'}
+
+         system_prompt, anthropic_messages = self._map_message(messages)
+
+         model_settings = model_settings or {}
+
+         return await self.client.messages.create(
+             max_tokens=model_settings.get('max_tokens', 1024),
+             system=system_prompt or NOT_GIVEN,
+             messages=anthropic_messages,
+             model=self.model_name,
+             tools=self.tools or NOT_GIVEN,
+             tool_choice=tool_choice or NOT_GIVEN,
+             stream=stream,
+             temperature=model_settings.get('temperature', NOT_GIVEN),
+             top_p=model_settings.get('top_p', NOT_GIVEN),
+             timeout=model_settings.get('timeout', NOT_GIVEN),
+         )
+
+     @staticmethod
+     def _process_response(response: AnthropicMessage) -> ModelResponse:
+         """Process a non-streamed response, and prepare a message to return."""
+         items: list[ModelResponsePart] = []
+         for item in response.content:
+             if isinstance(item, TextBlock):
+                 items.append(TextPart(item.text))
+             else:
+                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
+                 items.append(
+                     ToolCallPart.from_raw_args(
+                         item.name,
+                         cast(dict[str, Any], item.input),
+                         item.id,
+                     )
+                 )
+
+         return ModelResponse(items)
+
+     @staticmethod
+     async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+         """TODO: Process a streamed response, and prepare a streaming response to return."""
+         # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
+         # Streamed responses will be supported in a future release.
+
+         raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
+
+         # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+         # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
+         # RawMessageStartEvent
+         # RawMessageDeltaEvent
+         # RawMessageStopEvent
+         # RawContentBlockStartEvent
+         # RawContentBlockDeltaEvent
+         # RawContentBlockStopEvent
+         #
+         # We might refactor streaming internally before we implement this...
+
+     @staticmethod
+     def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+         """Map the `pydantic_ai` message history to a system prompt and a list of `anthropic.types.MessageParam`s."""
+         system_prompt: str = ''
+         anthropic_messages: list[MessageParam] = []
+         for m in messages:
+             if isinstance(m, ModelRequest):
+                 for part in m.parts:
+                     if isinstance(part, SystemPromptPart):
+                         system_prompt += part.content
+                     elif isinstance(part, UserPromptPart):
+                         anthropic_messages.append(MessageParam(role='user', content=part.content))
+                     elif isinstance(part, ToolReturnPart):
+                         anthropic_messages.append(
+                             MessageParam(
+                                 role='user',
+                                 content=[
+                                     ToolResultBlockParam(
+                                         tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                         type='tool_result',
+                                         content=part.model_response_str(),
+                                         is_error=False,
+                                     )
+                                 ],
+                             )
+                         )
+                     elif isinstance(part, RetryPromptPart):
+                         if part.tool_name is None:
+                             anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                         else:
+                             anthropic_messages.append(
+                                 MessageParam(
+                                     role='user',
+                                     content=[
+                                         ToolResultBlockParam(
+                                             tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                             type='tool_result',
+                                             content=part.model_response(),
+                                             is_error=True,
+                                         ),
+                                     ],
+                                 )
+                             )
+             elif isinstance(m, ModelResponse):
+                 content: list[TextBlockParam | ToolUseBlockParam] = []
+                 for item in m.parts:
+                     if isinstance(item, TextPart):
+                         content.append(TextBlockParam(text=item.content, type='text'))
+                     else:
+                         assert isinstance(item, ToolCallPart)
+                         content.append(_map_tool_call(item))
+                 anthropic_messages.append(MessageParam(role='assistant', content=content))
+             else:
+                 assert_never(m)
+         return system_prompt, anthropic_messages
+
+
+ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+     assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
+     return ToolUseBlockParam(
+         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+         type='tool_use',
+         name=t.tool_name,
+         input=t.args_as_dict(),
+     )
+
+
+ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
+     if isinstance(message, AnthropicMessage):
+         usage = message.usage
+     else:
+         if isinstance(message, RawMessageStartEvent):
+             usage = message.message.usage
+         elif isinstance(message, RawMessageDeltaEvent):
+             usage = message.usage
+         else:
+             # No usage information provided in:
+             # - RawMessageStopEvent
+             # - RawContentBlockStartEvent
+             # - RawContentBlockDeltaEvent
+             # - RawContentBlockStopEvent
+             usage = None
+
+     if usage is None:
+         return result.Usage()
+
+     # Usage coming from a RawMessageDeltaEvent has no input token data, hence this getattr
+     request_tokens = getattr(usage, 'input_tokens', None)
+
+     return result.Usage(
+         request_tokens=request_tokens,
+         response_tokens=usage.output_tokens,
+         total_tokens=(request_tokens or 0) + usage.output_tokens,
+     )
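The hunk above adds the complete `AnthropicModel` implementation. As a quick orientation, here is a minimal usage sketch, assuming the `Agent` API from the same release; the model name and prompt are illustrative only:

```python
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel

# With no explicit api_key, AsyncAnthropic falls back to the
# ANTHROPIC_API_KEY environment variable (see __init__ above).
model = AnthropicModel('claude-3-5-sonnet-latest')
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.data)
```

Note that streaming is not available yet: `_process_streamed_response` unconditionally raises `RuntimeError`, so `agent.run_stream(...)` would fail with this model.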
pydantic_ai/models/function.py

@@ -4,16 +4,27 @@ import inspect
  import re
  from collections.abc import AsyncIterator, Awaitable, Iterable
  from contextlib import asynccontextmanager
- from dataclasses import dataclass, field
+ from dataclasses import dataclass, field, replace
  from datetime import datetime
  from itertools import chain
  from typing import Callable, Union, cast

- import pydantic_core
  from typing_extensions import TypeAlias, assert_never, overload

  from .. import _utils, result
- from ..messages import ArgsJson, Message, ModelAnyResponse, ModelStructuredResponse, ToolCall
+ from ..messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelResponse,
+     ModelResponsePart,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+ )
+ from ..settings import ModelSettings
  from ..tools import ToolDefinition
  from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse

@@ -59,7 +70,7 @@ class FunctionModel(Model):
          result_tools: list[ToolDefinition],
      ) -> AgentModel:
          return FunctionAgentModel(
-             self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools)
+             self.function, self.stream_function, AgentInfo(function_tools, allow_text_result, result_tools, None)
          )

      def name(self) -> str:
@@ -88,6 +99,8 @@ class AgentInfo:
      """Whether a plain text result is allowed."""
      result_tools: list[ToolDefinition]
      """The tools that can be called as the final result of the run."""
+     model_settings: ModelSettings | None
+     """The model settings passed to the run call."""


  @dataclass
@@ -106,10 +119,10 @@ class DeltaToolCall:
  DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
  """A mapping of tool call IDs to incremental changes."""

- FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelAnyResponse, Awaitable[ModelAnyResponse]]]
+ FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
  """A function used to generate a non-streamed response."""

- StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+ StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
  """A function used to generate a streamed response.

  While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
@@ -127,18 +140,25 @@ class FunctionAgentModel(AgentModel):
      stream_function: StreamFunctionDef | None
      agent_info: AgentInfo

-     async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
+     async def request(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> tuple[ModelResponse, result.Usage]:
+         agent_info = replace(self.agent_info, model_settings=model_settings)
+
          assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
          if inspect.iscoroutinefunction(self.function):
-             response = await self.function(messages, self.agent_info)
+             response = await self.function(messages, agent_info)
          else:
-             response_ = await _utils.run_in_executor(self.function, messages, self.agent_info)
-             response = cast(ModelAnyResponse, response_)
+             response_ = await _utils.run_in_executor(self.function, messages, agent_info)
+             assert isinstance(response_, ModelResponse), response_
+             response = response_
          # TODO is `messages` right here? Should it just be new messages?
-         return response, _estimate_cost(chain(messages, [response]))
+         return response, _estimate_usage(chain(messages, [response]))

      @asynccontextmanager
-     async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
+     async def request_stream(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> AsyncIterator[EitherStreamedResponse]:
          assert (
              self.stream_function is not None
          ), 'FunctionModel must receive a `stream_function` to support streamed requests'
@@ -176,8 +196,8 @@ class FunctionStreamTextResponse(StreamTextResponse):
          yield from self._buffer
          self._buffer.clear()

-     def cost(self) -> result.Cost:
-         return result.Cost()
+     def usage(self) -> result.Usage:
+         return result.Usage()

      def timestamp(self) -> datetime:
          return self._timestamp
@@ -206,53 +226,55 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
          else:
              self._delta_tool_calls[key] = new

-     def get(self, *, final: bool = False) -> ModelStructuredResponse:
-         calls: list[ToolCall] = []
+     def get(self, *, final: bool = False) -> ModelResponse:
+         calls: list[ModelResponsePart] = []
          for c in self._delta_tool_calls.values():
              if c.name is not None and c.json_args is not None:
-                 calls.append(ToolCall.from_json(c.name, c.json_args))
+                 calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))

-         return ModelStructuredResponse(calls, timestamp=self._timestamp)
+         return ModelResponse(calls, timestamp=self._timestamp)

-     def cost(self) -> result.Cost:
-         return result.Cost()
+     def usage(self) -> result.Usage:
+         return _estimate_usage([self.get()])

      def timestamp(self) -> datetime:
          return self._timestamp


- def _estimate_cost(messages: Iterable[Message]) -> result.Cost:
-     """Very rough guesstimate of the number of tokens associate with a series of messages.
+ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+     """Very rough guesstimate of the token usage associated with a series of messages.

      This is designed to be used solely to give plausible numbers for testing!
      """
      # there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-
      request_tokens = 50
      response_tokens = 0
      for message in messages:
-         if message.role == 'system' or message.role == 'user':
-             request_tokens += _string_cost(message.content)
-         elif message.role == 'tool-return':
-             request_tokens += _string_cost(message.model_response_str())
-         elif message.role == 'retry-prompt':
-             request_tokens += _string_cost(message.model_response())
-         elif message.role == 'model-text-response':
-             response_tokens += _string_cost(message.content)
-         elif message.role == 'model-structured-response':
-             for call in message.calls:
-                 if isinstance(call.args, ArgsJson):
-                     args_str = call.args.args_json
+         if isinstance(message, ModelRequest):
+             for part in message.parts:
+                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                     request_tokens += _estimate_string_usage(part.content)
+                 elif isinstance(part, ToolReturnPart):
+                     request_tokens += _estimate_string_usage(part.model_response_str())
+                 elif isinstance(part, RetryPromptPart):
+                     request_tokens += _estimate_string_usage(part.model_response())
                  else:
-                     args_str = pydantic_core.to_json(call.args.args_dict).decode()
-
-                 response_tokens += 1 + _string_cost(args_str)
+                     assert_never(part)
+         elif isinstance(message, ModelResponse):
+             for part in message.parts:
+                 if isinstance(part, TextPart):
+                     response_tokens += _estimate_string_usage(part.content)
+                 elif isinstance(part, ToolCallPart):
+                     call = part
+                     response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
+                 else:
+                     assert_never(part)
          else:
              assert_never(message)
-     return result.Cost(
+     return result.Usage(
          request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
      )


- def _string_cost(content: str) -> int:
+ def _estimate_string_usage(content: str) -> int:
      return len(re.split(r'[\s",.:]+', content))
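These `function.py` changes mirror the same refactor: `_estimate_cost`/`result.Cost` become `_estimate_usage`/`result.Usage`, the old role-string dispatch is replaced by `isinstance` checks over typed message parts, and `AgentInfo` now carries the run's `ModelSettings` (injected via `dataclasses.replace` in `FunctionAgentModel.request`). Here is a minimal sketch of a `FunctionModel` under the new signatures, assuming `Agent.run_sync` accepts `model_settings` and that the run result exposes the renamed `usage()` method:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def my_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # AgentInfo.model_settings is populated from the run call and may be None.
    temperature = (info.model_settings or {}).get('temperature')
    return ModelResponse([TextPart(f'temperature={temperature}')])


agent = Agent(FunctionModel(my_model))
result = agent.run_sync('hello', model_settings={'temperature': 0.0})
print(result.data)     # -> temperature=0.0
print(result.usage())  # plausible token counts from _estimate_usage above
```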