pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/models/anthropic.py
@@ -1,14 +1,16 @@
  from __future__ import annotations as _annotations

- from collections.abc import AsyncIterator
+ from collections.abc import AsyncIterable, AsyncIterator
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from json import JSONDecodeError, loads as json_loads
  from typing import Any, Literal, Union, cast, overload

  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never

- from .. import result
+ from .. import UnexpectedModelBehavior, _utils, usage
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
      ArgsDict,
@@ -16,6 +18,7 @@ from ..messages import (
      ModelRequest,
      ModelResponse,
      ModelResponsePart,
+     ModelResponseStreamEvent,
      RetryPromptPart,
      SystemPromptPart,
      TextPart,
@@ -27,8 +30,8 @@ from ..settings import ModelSettings
  from ..tools import ToolDefinition
  from . import (
      AgentModel,
-     EitherStreamedResponse,
      Model,
+     StreamedResponse,
      cached_async_http_client,
      check_allow_model_requests,
  )
@@ -38,11 +41,16 @@ try:
      from anthropic.types import (
          Message as AnthropicMessage,
          MessageParam,
+         RawContentBlockDeltaEvent,
+         RawContentBlockStartEvent,
+         RawContentBlockStopEvent,
          RawMessageDeltaEvent,
          RawMessageStartEvent,
+         RawMessageStopEvent,
          RawMessageStreamEvent,
          TextBlock,
          TextBlockParam,
+         TextDelta,
          ToolChoiceParam,
          ToolParam,
          ToolResultBlockParam,
@@ -152,20 +160,20 @@ class AnthropicAgentModel(AgentModel):
      """Implementation of `AgentModel` for Anthropic models."""

      client: AsyncAnthropic
-     model_name: str
+     model_name: AnthropicModelName
      allow_text_result: bool
      tools: list[ToolParam]

      async def request(
          self, messages: list[ModelMessage], model_settings: ModelSettings | None
-     ) -> tuple[ModelResponse, result.Usage]:
+     ) -> tuple[ModelResponse, usage.Usage]:
          response = await self._messages_create(messages, False, model_settings)
          return self._process_response(response), _map_usage(response)

      @asynccontextmanager
      async def request_stream(
          self, messages: list[ModelMessage], model_settings: ModelSettings | None
-     ) -> AsyncIterator[EitherStreamedResponse]:
+     ) -> AsyncIterator[StreamedResponse]:
          response = await self._messages_create(messages, True, model_settings)
          async with response:
              yield await self._process_streamed_response(response)
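
With this change, `request_stream` yields the new unified `StreamedResponse` instead of the old `EitherStreamedResponse` union. A minimal consumption sketch (illustrative only: the `StreamedResponse` base class is not shown in this diff, and it is assumed here to be async-iterable over `ModelResponseStreamEvent` items, matching `_get_event_iterator` further down):

    # Hypothetical driver; `agent_model` and `messages` are supplied by the caller.
    async def print_stream(agent_model: AgentModel, messages: list[ModelMessage]) -> None:
        # request_stream is an async context manager (note @asynccontextmanager above)
        async with agent_model.request_stream(messages, None) as streamed:
            async for event in streamed:  # assumed: yields ModelResponseStreamEvent items
                print(event)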
@@ -186,16 +194,22 @@ class AnthropicAgentModel(AgentModel):
          self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
      ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
          # standalone function to make it easier to override
+         model_settings = model_settings or {}
+
+         tool_choice: ToolChoiceParam | None
+
          if not self.tools:
-             tool_choice: ToolChoiceParam | None = None
-         elif not self.allow_text_result:
-             tool_choice = {'type': 'any'}
+             tool_choice = None
          else:
-             tool_choice = {'type': 'auto'}
+             if not self.allow_text_result:
+                 tool_choice = {'type': 'any'}
+             else:
+                 tool_choice = {'type': 'auto'}

-         system_prompt, anthropic_messages = self._map_message(messages)
+             if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
+                 tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

-         model_settings = model_settings or {}
+         system_prompt, anthropic_messages = self._map_message(messages)

          return await self.client.messages.create(
              max_tokens=model_settings.get('max_tokens', 1024),
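
The rewritten `_messages_create` also threads the `parallel_tool_calls` model setting through to Anthropic's `disable_parallel_tool_use` flag, inverting the boolean and leaving the provider default untouched when the setting is absent. A standalone sketch of that tri-state mapping (the helper name is illustrative, not part of the diff):

    from typing import Any

    def apply_parallel_tool_calls(settings: dict[str, Any], tool_choice: dict[str, Any]) -> dict[str, Any]:
        # Absent/None -> provider default; True -> allow parallel calls; False -> disable them.
        if (allow := settings.get('parallel_tool_calls')) is not None:
            tool_choice['disable_parallel_tool_use'] = not allow
        return tool_choice

    assert apply_parallel_tool_calls({}, {'type': 'auto'}) == {'type': 'auto'}
    assert apply_parallel_tool_calls({'parallel_tool_calls': False}, {'type': 'auto'}) == {
        'type': 'auto',
        'disable_parallel_tool_use': True,
    }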
@@ -210,43 +224,33 @@
              timeout=model_settings.get('timeout', NOT_GIVEN),
          )

-     @staticmethod
-     def _process_response(response: AnthropicMessage) -> ModelResponse:
+     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
          """Process a non-streamed response, and prepare a message to return."""
          items: list[ModelResponsePart] = []
          for item in response.content:
              if isinstance(item, TextBlock):
-                 items.append(TextPart(item.text))
+                 items.append(TextPart(content=item.text))
              else:
                  assert isinstance(item, ToolUseBlock), 'unexpected item type'
                  items.append(
                      ToolCallPart.from_raw_args(
-                         item.name,
-                         cast(dict[str, Any], item.input),
-                         item.id,
+                         tool_name=item.name,
+                         args=cast(dict[str, Any], item.input),
+                         tool_call_id=item.id,
                      )
                  )

-         return ModelResponse(items)
+         return ModelResponse(items, model_name=self.model_name)

-     @staticmethod
-     async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
-         """TODO: Process a streamed response, and prepare a streaming response to return."""
-         # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
-         # Streamed responses will be supported in a future release.
-
-         raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
-
-         # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
-         # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
-         #   RawMessageStartEvent
-         #   RawMessageDeltaEvent
-         #   RawMessageStopEvent
-         #   RawContentBlockStartEvent
-         #   RawContentBlockDeltaEvent
-         #   RawContentBlockStopEvent
-         #
-         # We might refactor streaming internally before we implement this...
+     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
+         peekable_response = _utils.PeekableAsyncStream(response)
+         first_chunk = await peekable_response.peek()
+         if isinstance(first_chunk, _utils.Unset):
+             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+         # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
+         timestamp = datetime.now(tz=timezone.utc)
+         return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)

      @staticmethod
      def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
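
`_process_streamed_response` now peeks at the first event before constructing the response, so an empty stream surfaces as `UnexpectedModelBehavior` rather than an empty result. `_utils.PeekableAsyncStream` itself is not part of this diff; the sketch below is an illustrative equivalent of the peek-without-consuming idea, not the library's implementation:

    from collections.abc import AsyncIterator
    from typing import Any

    class Unset:
        """Sentinel type: distinguishes 'nothing buffered yet' from a buffered None."""

    UNSET = Unset()

    class PeekableStream:
        def __init__(self, source: AsyncIterator[Any]) -> None:
            self._source = source
            self._buffered: Any = UNSET

        async def peek(self) -> Any:
            # Pull one item into the buffer without consuming it for iteration.
            if isinstance(self._buffered, Unset):
                try:
                    self._buffered = await self._source.__anext__()
                except StopAsyncIteration:
                    return UNSET  # the stream ended before yielding anything
            return self._buffered

        def __aiter__(self) -> 'PeekableStream':
            return self

        async def __anext__(self) -> Any:
            if not isinstance(self._buffered, Unset):
                item, self._buffered = self._buffered, UNSET
                return item
            return await self._source.__anext__()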
@@ -315,30 +319,90 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
      )


- def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
+ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
      if isinstance(message, AnthropicMessage):
-         usage = message.usage
+         response_usage = message.usage
      else:
          if isinstance(message, RawMessageStartEvent):
-             usage = message.message.usage
+             response_usage = message.message.usage
          elif isinstance(message, RawMessageDeltaEvent):
-             usage = message.usage
+             response_usage = message.usage
          else:
              # No usage information provided in:
              # - RawMessageStopEvent
              # - RawContentBlockStartEvent
              # - RawContentBlockDeltaEvent
              # - RawContentBlockStopEvent
-             usage = None
+             response_usage = None

-     if usage is None:
-         return result.Usage()
+     if response_usage is None:
+         return usage.Usage()

-     request_tokens = getattr(usage, 'input_tokens', None)
+     request_tokens = getattr(response_usage, 'input_tokens', None)

-     return result.Usage(
+     return usage.Usage(
          # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
          request_tokens=request_tokens,
-         response_tokens=usage.output_tokens,
-         total_tokens=(request_tokens or 0) + usage.output_tokens,
+         response_tokens=response_usage.output_tokens,
+         total_tokens=(request_tokens or 0) + response_usage.output_tokens,
      )
+
+
+ @dataclass
+ class AnthropicStreamedResponse(StreamedResponse):
+     """Implementation of `StreamedResponse` for Anthropic models."""
+
+     _response: AsyncIterable[RawMessageStreamEvent]
+     _timestamp: datetime
+
+     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+         current_block: TextBlock | ToolUseBlock | None = None
+         current_json: str = ''
+
+         async for event in self._response:
+             self._usage += _map_usage(event)
+
+             if isinstance(event, RawContentBlockStartEvent):
+                 current_block = event.content_block
+                 if isinstance(current_block, TextBlock) and current_block.text:
+                     yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
+                 elif isinstance(current_block, ToolUseBlock):
+                     maybe_event = self._parts_manager.handle_tool_call_delta(
+                         vendor_part_id=current_block.id,
+                         tool_name=current_block.name,
+                         args=cast(dict[str, Any], current_block.input),
+                         tool_call_id=current_block.id,
+                     )
+                     if maybe_event is not None:
+                         yield maybe_event
+
+             elif isinstance(event, RawContentBlockDeltaEvent):
+                 if isinstance(event.delta, TextDelta):
+                     yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
+                 elif (
+                     current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
+                 ):
+                     # Try to parse the JSON immediately, otherwise cache the value for later. This handles
+                     # cases where the JSON is not currently valid but will be valid once we stream more tokens.
+                     try:
+                         parsed_args = json_loads(current_json + event.delta.partial_json)
+                         current_json = ''
+                     except JSONDecodeError:
+                         current_json += event.delta.partial_json
+                         continue
+
+                     # For tool calls, we need to handle partial JSON updates
+                     maybe_event = self._parts_manager.handle_tool_call_delta(
+                         vendor_part_id=current_block.id,
+                         tool_name='',
+                         args=parsed_args,
+                         tool_call_id=current_block.id,
+                     )
+                     if maybe_event is not None:
+                         yield maybe_event
+
+             elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
+                 current_block = None
+
+     def timestamp(self) -> datetime:
+         return self._timestamp
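
The subtlest part of `_get_event_iterator` above is the `input_json_delta` handling: Anthropic streams tool-call arguments as JSON fragments that are usually invalid on their own, so the iterator buffers fragments and only emits a tool-call delta once the accumulated string parses. An isolated, runnable reduction of that buffering loop (the function name and sample fragments are illustrative):

    from json import JSONDecodeError, loads as json_loads

    def parse_streamed_json(fragments: list[str]) -> list[dict]:
        emitted: list[dict] = []
        buffered = ''
        for fragment in fragments:
            try:
                emitted.append(json_loads(buffered + fragment))
                buffered = ''  # parsed successfully, so reset the buffer
            except JSONDecodeError:
                buffered += fragment  # not yet valid JSON, keep accumulating
        return emitted

    # '{"city": "Par' alone fails to parse, so nothing is emitted until the second fragment:
    assert parse_streamed_json(['{"city": "Par', 'is"}']) == [{'city': 'Paris'}]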
pydantic_ai/models/cohere.py (new file)
@@ -0,0 +1,278 @@
+ from __future__ import annotations as _annotations
+
+ from collections.abc import Iterable
+ from dataclasses import dataclass, field
+ from itertools import chain
+ from typing import Literal, TypeAlias, Union
+
+ from cohere import TextAssistantMessageContentItem
+ from typing_extensions import assert_never
+
+ from .. import result
+ from .._utils import guard_tool_call_id as _guard_tool_call_id
+ from ..messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelResponse,
+     ModelResponsePart,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+ )
+ from ..settings import ModelSettings
+ from ..tools import ToolDefinition
+ from . import (
+     AgentModel,
+     Model,
+     check_allow_model_requests,
+ )
+
+ try:
+     from cohere import (
+         AssistantChatMessageV2,
+         AsyncClientV2,
+         ChatMessageV2,
+         ChatResponse,
+         SystemChatMessageV2,
+         ToolCallV2,
+         ToolCallV2Function,
+         ToolChatMessageV2,
+         ToolV2,
+         ToolV2Function,
+         UserChatMessageV2,
+     )
+     from cohere.v2.client import OMIT
+ except ImportError as _import_error:
+     raise ImportError(
+         'Please install `cohere` to use the Cohere model, '
+         "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
+     ) from _import_error
+
+ CohereModelName: TypeAlias = Union[
+     str,
+     Literal[
+         'c4ai-aya-expanse-32b',
+         'c4ai-aya-expanse-8b',
+         'command',
+         'command-light',
+         'command-light-nightly',
+         'command-nightly',
+         'command-r',
+         'command-r-03-2024',
+         'command-r-08-2024',
+         'command-r-plus',
+         'command-r-plus-04-2024',
+         'command-r-plus-08-2024',
+         'command-r7b-12-2024',
+     ],
+ ]
+
+
+ @dataclass(init=False)
+ class CohereModel(Model):
+     """A model that uses the Cohere API.
+
+     Internally, this uses the [Cohere Python client](
+     https://github.com/cohere-ai/cohere-python) to interact with the API.
+
+     Apart from `__init__`, all methods are private or match those of the base class.
+     """
+
+     model_name: CohereModelName
+     client: AsyncClientV2 = field(repr=False)
+
+     def __init__(
+         self,
+         model_name: CohereModelName,
+         *,
+         api_key: str | None = None,
+         cohere_client: AsyncClientV2 | None = None,
+     ):
+         """Initialize a Cohere model.
+
+         Args:
+             model_name: The name of the Cohere model to use. List of model names
+                 available [here](https://docs.cohere.com/docs/models#command).
+             api_key: The API key to use for authentication, if not provided, the
+                 `COHERE_API_KEY` environment variable will be used if available.
+             cohere_client: An existing Cohere async client to use. If provided,
+                 `api_key` must be `None`.
+         """
+         self.model_name: CohereModelName = model_name
+         if cohere_client is not None:
+             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
+             self.client = cohere_client
+         else:
+             self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+
+     async def agent_model(
+         self,
+         *,
+         function_tools: list[ToolDefinition],
+         allow_text_result: bool,
+         result_tools: list[ToolDefinition],
+     ) -> AgentModel:
+         check_allow_model_requests()
+         tools = [self._map_tool_definition(r) for r in function_tools]
+         if result_tools:
+             tools += [self._map_tool_definition(r) for r in result_tools]
+         return CohereAgentModel(
+             self.client,
+             self.model_name,
+             allow_text_result,
+             tools,
+         )
+
+     def name(self) -> str:
+         return f'cohere:{self.model_name}'
+
+     @staticmethod
+     def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+         return ToolV2(
+             type='function',
+             function=ToolV2Function(
+                 name=f.name,
+                 description=f.description,
+                 parameters=f.parameters_json_schema,
+             ),
+         )
+
+
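`CohereModel` follows the same constructor pattern as the other providers in this package. A hedged usage sketch (the module path `pydantic_ai.models.cohere` is inferred from the relative imports above, and `COHERE_API_KEY` is assumed to be set in the environment):

    from pydantic_ai.models.cohere import CohereModel

    model = CohereModel('command-r-plus')  # reads COHERE_API_KEY from the environment
    print(model.name())  # -> 'cohere:command-r-plus'

    # Or with a preconfigured client (api_key must then be omitted):
    # from cohere import AsyncClientV2
    # model = CohereModel('command-r-plus', cohere_client=AsyncClientV2(api_key='...'))
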
+ @dataclass
+ class CohereAgentModel(AgentModel):
+     """Implementation of `AgentModel` for Cohere models."""
+
+     client: AsyncClientV2
+     model_name: CohereModelName
+     allow_text_result: bool
+     tools: list[ToolV2]
+
+     async def request(
+         self, messages: list[ModelMessage], model_settings: ModelSettings | None
+     ) -> tuple[ModelResponse, result.Usage]:
+         response = await self._chat(messages, model_settings)
+         return self._process_response(response), _map_usage(response)
+
+     async def _chat(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+     ) -> ChatResponse:
+         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
+         model_settings = model_settings or {}
+         return await self.client.chat(
+             model=self.model_name,
+             messages=cohere_messages,
+             tools=self.tools or OMIT,
+             max_tokens=model_settings.get('max_tokens', OMIT),
+             temperature=model_settings.get('temperature', OMIT),
+             p=model_settings.get('top_p', OMIT),
+         )
+
+     def _process_response(self, response: ChatResponse) -> ModelResponse:
+         """Process a non-streamed response, and prepare a message to return."""
+         parts: list[ModelResponsePart] = []
+         if response.message.content is not None and len(response.message.content) > 0:
+             # While Cohere's API returns a list, it only does that for future proofing
+             # and currently only one item is being returned.
+             choice = response.message.content[0]
+             parts.append(TextPart(choice.text))
+         for c in response.message.tool_calls or []:
+             if c.function and c.function.name and c.function.arguments:
+                 parts.append(
+                     ToolCallPart.from_raw_args(
+                         tool_name=c.function.name,
+                         args=c.function.arguments,
+                         tool_call_id=c.id,
+                     )
+                 )
+         return ModelResponse(parts=parts, model_name=self.model_name)
+
+     @classmethod
+     def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+         """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
+         if isinstance(message, ModelRequest):
+             yield from cls._map_user_message(message)
+         elif isinstance(message, ModelResponse):
+             texts: list[str] = []
+             tool_calls: list[ToolCallV2] = []
+             for item in message.parts:
+                 if isinstance(item, TextPart):
+                     texts.append(item.content)
+                 elif isinstance(item, ToolCallPart):
+                     tool_calls.append(_map_tool_call(item))
+                 else:
+                     assert_never(item)
+             message_param = AssistantChatMessageV2(role='assistant')
+             if texts:
+                 message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+             if tool_calls:
+                 message_param.tool_calls = tool_calls
+             yield message_param
+         else:
+             assert_never(message)
+
+     @classmethod
+     def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
+         for part in message.parts:
+             if isinstance(part, SystemPromptPart):
+                 yield SystemChatMessageV2(role='system', content=part.content)
+             elif isinstance(part, UserPromptPart):
+                 yield UserChatMessageV2(role='user', content=part.content)
+             elif isinstance(part, ToolReturnPart):
+                 yield ToolChatMessageV2(
+                     role='tool',
+                     tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                     content=part.model_response_str(),
+                 )
+             elif isinstance(part, RetryPromptPart):
+                 if part.tool_name is None:
+                     yield UserChatMessageV2(role='user', content=part.model_response())
+                 else:
+                     yield ToolChatMessageV2(
+                         role='tool',
+                         tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
+                         content=part.model_response(),
+                     )
+             else:
+                 assert_never(part)
+
+
+ def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+     return ToolCallV2(
+         id=_guard_tool_call_id(t=t, model_source='Cohere'),
+         type='function',
+         function=ToolCallV2Function(
+             name=t.tool_name,
+             arguments=t.args_as_json_str(),
+         ),
+     )
+
+
+ def _map_usage(response: ChatResponse) -> result.Usage:
+     usage = response.usage
+     if usage is None:
+         return result.Usage()
+     else:
+         details: dict[str, int] = {}
+         if usage.billed_units is not None:
+             if usage.billed_units.input_tokens:
+                 details['input_tokens'] = int(usage.billed_units.input_tokens)
+             if usage.billed_units.output_tokens:
+                 details['output_tokens'] = int(usage.billed_units.output_tokens)
+             if usage.billed_units.search_units:
+                 details['search_units'] = int(usage.billed_units.search_units)
+             if usage.billed_units.classifications:
+                 details['classifications'] = int(usage.billed_units.classifications)
+
+         request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
+         response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
+         return result.Usage(
+             request_tokens=request_tokens,
+             response_tokens=response_tokens,
+             total_tokens=(request_tokens or 0) + (response_tokens or 0),
+             details=details,
+         )
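
`_map_usage` keeps two accountings separate: raw token counts populate `request_tokens`, `response_tokens`, and `total_tokens`, while billed units (which can differ from raw counts) go into `details`. A worked example with illustrative numbers, assuming `Usage` is importable from `pydantic_ai.result` as the `from .. import result` above suggests:

    from pydantic_ai.result import Usage

    request_tokens, response_tokens = 120, 80  # illustrative raw token counts, not real API output
    u = Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details={'input_tokens': 100, 'output_tokens': 80},  # billed units per the mapping above
    )
    assert u.total_tokens == 200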