pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, result
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -41,7 +39,6 @@ try:
     from groq import NOT_GIVEN, AsyncGroq, AsyncStream
    from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
-    from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -157,14 +154,14 @@ class GroqAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -200,7 +197,7 @@ class GroqAgentModel(AgentModel):
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -210,45 +207,32 @@ class GroqAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
+            items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
-        return ModelResponse(items, timestamp=timestamp)
+                items.append(
+                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
+                )
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
 
-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-
-
-
-
-
-
-
-
-
-
-
-        if chunk.choices:
-            delta = chunk.choices[0].delta
-
-            if delta.content is not None:
-                return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
-            elif delta.tool_calls is not None:
-                return GroqStreamStructuredResponse(
-                    response,
-                    {c.index: c for c in delta.tool_calls},
-                    timestamp,
-                    start_usage,
-                )
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        return GroqStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )
 
     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
@@ -301,90 +285,36 @@ class GroqAgentModel(AgentModel):
 
 
 @dataclass
-class GroqStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for Groq models."""
+class GroqStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Groq models."""
 
-    _first: str | None
-    _response: AsyncStream[ChatCompletionChunk]
+    _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
-    _usage: result.Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        # we don't raise StopAsyncIteration on the last chunk because usage comes after this
-        if choice.finish_reason is None:
-            assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-        if choice.delta.content is not None:
-            self._buffer.append(choice.delta.content)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
 
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-        for new in choice.delta.tool_calls or []:
-            if current := self._delta_tool_calls.get(new.index):
-                if current.function is None:
-                    current.function = new.function
-                elif new.function is not None:
-                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-            else:
-                self._delta_tool_calls[new.index] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        items: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if f := c.function:
-                if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
-
-        return ModelResponse(items, timestamp=self._timestamp)
-
-    def usage(self) -> Usage:
-        return self._usage
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            # Handle the tool calls
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -398,18 +328,18 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     )
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) ->
-
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+    response_usage = None
     if isinstance(completion, ChatCompletion):
-
+        response_usage = completion.usage
     elif completion.x_groq is not None:
-
+        response_usage = completion.x_groq.usage
 
-    if
-        return
+    if response_usage is None:
+        return usage.Usage()
 
-    return
-        request_tokens=
-        response_tokens=
-        total_tokens=
+    return usage.Usage(
+        request_tokens=response_usage.prompt_tokens,
+        response_tokens=response_usage.completion_tokens,
+        total_tokens=response_usage.total_tokens,
     )
pydantic_ai/models/mistral.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import os
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -12,7 +12,7 @@ import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior
+from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
     ArgsJson,
@@ -20,6 +20,7 @@ from ..messages import (
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -32,11 +33,10 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
+    check_allow_model_requests,
 )
 
 try:
@@ -131,6 +131,7 @@ class MistralModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
+        check_allow_model_requests()
         return MistralAgentModel(
             self.client,
             self.model_name,
@@ -148,7 +149,7 @@ class MistralAgentModel(AgentModel):
     """Implementation of `AgentModel` for Mistral models."""
 
     client: Mistral
-    model_name:
+    model_name: MistralModelName
     allow_text_result: bool
     function_tools: list[ToolDefinition]
     result_tools: list[ToolDefinition]
@@ -164,7 +165,7 @@ class MistralAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
         response = await self._stream_completions_create(messages, model_settings)
         async with response:
@@ -266,8 +267,7 @@ class MistralAgentModel(AgentModel):
         ]
         return tools if tools else None
 
-    @staticmethod
-    def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
+    def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         assert response.choices, 'Unexpected empty response choice.'
@@ -282,58 +282,37 @@ class MistralAgentModel(AgentModel):
 
         parts: list[ModelResponsePart] = []
         if text := _map_content(content):
-            parts.append(TextPart(text))
+            parts.append(TextPart(content=text))
 
         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
-                tool = _map_mistral_to_pydantic_tool_call(tool_call)
+                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)
 
-        return ModelResponse(parts, timestamp=timestamp)
+        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
 
-    @staticmethod
     async def _process_streamed_response(
+        self,
         result_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
-    ) -> EitherStreamedResponse:
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-
-
-
-
-        try:
-            event = await response.__anext__()
-            chunk = event.data
-        except StopAsyncIteration as e:
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-
-
-
-
-        else:
-            timestamp = _now_utc()
-
-        if chunk.choices:
-            delta = chunk.choices[0].delta
-            content = _map_content(delta.content)
-
-            tool_calls: list[MistralToolCall] | None = None
-            if delta.tool_calls:
-                tool_calls = delta.tool_calls
-
-            if tool_calls or content and result_tools:
-                return MistralStreamStructuredResponse(
-                    {c.id if c.id else 'null': c for c in tool_calls or []},
-                    {c.name: c for c in result_tools},
-                    response,
-                    content,
-                    timestamp,
-                    start_usage,
-                )
+        if first_chunk.data.created:
+            timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
+        else:
+            timestamp = datetime.now(tz=timezone.utc)
 
-
-
+        return MistralStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=timestamp,
+            _result_tools={c.name: c for c in result_tools},
+        )
 
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -467,108 +446,73 @@ class MistralAgentModel(AgentModel):
         assert_never(message)
 
 
-
-class MistralStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for Mistral models."""
-
-    _first: str | None
-    _response: MistralEventStreamAsync[MistralCompletionEvent]
-    _timestamp: datetime
-    _usage: Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None and len(self._first) > 0:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk.data)
-
-        try:
-            choice = chunk.data.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        content = choice.delta.content
-        if choice.finish_reason is None:
-            assert content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-
-        if text := _map_content(content):
-            self._buffer.append(text)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
+MistralToolCallId = Union[str, None]
 
 
 @dataclass
-class MistralStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for Mistral models."""
+class MistralStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Mistral models."""
 
-
-    _result_tools: dict[str, ToolDefinition]
-    _response: MistralEventStreamAsync[MistralCompletionEvent]
-    _delta_content: str | None
+    _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk.data)
-
-        try:
-            choice = chunk.data.choices[0]
-
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        content = choice.delta.content
-        if self._result_tools:
-            if text := _map_content(content):
-                self._delta_content = (self._delta_content or '') + text
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        calls: list[ModelResponsePart] = []
-        if self._function_tools and self._result_tools or self._function_tools:
-            for tool_call in self._function_tools.values():
-                tool = _map_mistral_to_pydantic_tool_call(tool_call)
-                calls.append(tool)
-
-        elif self._delta_content and self._result_tools:
-            output_json: dict[str, Any] | None = pydantic_core.from_json(
-                self._delta_content, allow_partial='trailing-strings'
-            )
+    _result_tools: dict[str, ToolDefinition]
 
-
-            for result_tool in self._result_tools.values():
-                # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
-                # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                # Example with BaseModel and required fields.
-                if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
-                    continue
+    _delta_content: str = field(default='', init=False)
 
-
-
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        chunk: MistralCompletionEvent
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk.data)
 
-
+            try:
+                choice = chunk.data.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            text = _map_content(content)
+            if text:
+                # Attempt to produce a result tool call from the received text
+                if self._result_tools:
+                    self._delta_content += text
+                    maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
+                    if maybe_tool_call_part:
+                        yield self._parts_manager.handle_tool_call_part(
+                            vendor_part_id='result',
+                            tool_name=maybe_tool_call_part.tool_name,
+                            args=maybe_tool_call_part.args_as_dict(),
+                            tool_call_id=maybe_tool_call_part.tool_call_id,
+                        )
+                else:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
 
-
-
+            # Handle the explicit tool calls
+            for index, dtc in enumerate(choice.delta.tool_calls or []):
+                # It seems that mistral just sends full tool calls, so we just use them directly, rather than building
+                yield self._parts_manager.handle_tool_call_part(
+                    vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
+                )
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
+    @staticmethod
+    def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
+        output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
+        if output_json:
+            for result_tool in result_tools.values():
+                # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
+                # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
+                # Example with BaseModel and required fields.
+                if not MistralStreamedResponse._validate_required_json_schema(
+                    output_json, result_tool.parameters_json_schema
+                ):
+                    continue
+
+                # The following part_id will be thrown away
+                return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
+
     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
@@ -587,20 +531,20 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
             if not isinstance(json_dict[param], list):
                 return False
             for item in json_dict[param]:
-                if not isinstance(item,
+                if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
                     return False
-        elif param_type and not isinstance(json_dict[param],
+        elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
             return False
 
         if isinstance(json_dict[param], dict) and 'properties' in param_schema:
             nested_schema = param_schema
-            if not
+            if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema):
                 return False
 
     return True
 
 
-
+VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
     'string': str,
     'integer': int,
     'number': float,
pydantic_ai/models/ollama.py
CHANGED
@@ -10,6 +10,7 @@ from . import (
     AgentModel,
     Model,
     cached_async_http_client,
+    check_allow_model_requests,
 )
 
 try:
@@ -25,6 +26,7 @@ from .openai import OpenAIModel
 
 CommonOllamaModelNames = Literal[
     'codellama',
+    'deepseek-r1',
     'gemma',
     'gemma2',
     'llama3',
@@ -36,6 +38,7 @@ CommonOllamaModelNames = Literal[
     'mistral-nemo',
     'mixtral',
     'phi3',
+    'phi4',
     'qwq',
     'qwen',
     'qwen2',
@@ -109,6 +112,7 @@ class OllamaModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
    ) -> AgentModel:
+        check_allow_model_requests()
         return await self.openai_model.agent_model(
             function_tools=function_tools,
             allow_text_result=allow_text_result,