pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_griffe.py +23 -4
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_system_prompt.py +1 -0
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +332 -124
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +224 -9
- pydantic_ai/models/__init__.py +59 -82
- pydantic_ai/models/anthropic.py +22 -22
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +86 -125
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/models/vertexai.py +1 -1
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.17.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.17.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, result
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -41,7 +39,6 @@ try:
     from groq import NOT_GIVEN, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
-    from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -157,14 +154,14 @@ class GroqAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -217,38 +214,23 @@ class GroqAgentModel(AgentModel):
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
+            items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+                items.append(
+                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
+                )
         return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        timestamp: datetime | None = None
-        start_usage = Usage()
-
-        while True:
-            try:
-                chunk = await response.__anext__()
-            except StopAsyncIteration as e:
-                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_usage += _map_usage(chunk)
-
-            if chunk.choices:
-                delta = chunk.choices[0].delta
-
-                if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
-                elif delta.tool_calls is not None:
-                    return GroqStreamStructuredResponse(
-                        response,
-                        {c.index: c for c in delta.tool_calls},
-                        timestamp,
-                        start_usage,
-                    )
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
 
     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
@@ -301,90 +283,36 @@ class GroqAgentModel(AgentModel):
 
 
 @dataclass
-class GroqStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for Groq models."""
+class GroqStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Groq models."""
 
-    _first: str | None
-    _response: AsyncStream[ChatCompletionChunk]
+    _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
-    _usage: result.Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
 
-
-
-
-        if choice.delta.content is not None:
-            self._buffer.append(choice.delta.content)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)
 
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GroqStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for Groq models."""
-
-    _response: AsyncStream[ChatCompletionChunk]
-    _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
-    _timestamp: datetime
-    _usage: result.Usage
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-        for new in choice.delta.tool_calls or []:
-            if current := self._delta_tool_calls.get(new.index):
-                if current.function is None:
-                    current.function = new.function
-                elif new.function is not None:
-                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-            else:
-                self._delta_tool_calls[new.index] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        items: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if f := c.function:
-                if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
-
-        return ModelResponse(items, timestamp=self._timestamp)
-
-    def usage(self) -> Usage:
-        return self._usage
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            # Handle the tool calls
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -398,18 +326,18 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     )
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
-    usage = None
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+    response_usage = None
 if isinstance(completion, ChatCompletion):
-        usage = completion.usage
+        response_usage = completion.usage
     elif completion.x_groq is not None:
-        usage = completion.x_groq.usage
+        response_usage = completion.x_groq.usage
 
-    if usage is None:
-        return result.Usage()
+    if response_usage is None:
+        return usage.Usage()
 
-    return result.Usage(
-        request_tokens=usage.prompt_tokens,
-        response_tokens=usage.completion_tokens,
-        total_tokens=usage.total_tokens,
+    return usage.Usage(
+        request_tokens=response_usage.prompt_tokens,
+        response_tokens=response_usage.completion_tokens,
+        total_tokens=response_usage.total_tokens,
     )
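Both this Groq implementation and the Mistral one below replace the per-model StreamTextResponse/StreamStructuredResponse pair with a single StreamedResponse that yields ModelResponseStreamEvents through a parts manager, and both rely on _utils.PeekableAsyncStream so that _process_streamed_response can inspect the first chunk (for its timestamp) without consuming it. The following is a minimal, self-contained sketch of that peekable-stream pattern; the class and sentinel names mirror the diff, but the implementation is an illustration, not the package's actual _utils code.

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator


class Unset:
    """Sentinel type returned by peek() when the stream is already exhausted (illustrative stand-in)."""


UNSET = Unset()


class PeekableAsyncStream:
    """Wraps an async iterator so the next item can be inspected without consuming it."""

    def __init__(self, source: AsyncIterator[str]) -> None:
        self._source = source
        self._peeked: str | Unset = UNSET

    async def peek(self) -> str | Unset:
        # Pull one item ahead of time and cache it; iteration replays it first.
        if isinstance(self._peeked, Unset):
            try:
                self._peeked = await self._source.__anext__()
            except StopAsyncIteration:
                return UNSET
        return self._peeked

    def __aiter__(self) -> AsyncIterator[str]:
        return self._iterate()

    async def _iterate(self) -> AsyncIterator[str]:
        # Replay the cached item (if any) before draining the underlying stream.
        if not isinstance(self._peeked, Unset):
            item, self._peeked = self._peeked, UNSET
            yield item
        async for item in self._source:
            yield item


async def main() -> None:
    async def chunks() -> AsyncIterator[str]:
        for piece in ('hello', ' ', 'world'):
            yield piece

    stream = PeekableAsyncStream(chunks())
    first = await stream.peek()
    if isinstance(first, Unset):
        raise RuntimeError('stream ended before the first chunk')
    print('peeked:', first)                    # 'hello', not yet consumed
    print('all:', [c async for c in stream])   # ['hello', ' ', 'world']


asyncio.run(main())

With this in place, peek() returning the sentinel corresponds directly to the isinstance(first_chunk, _utils.Unset) guard in both _process_streamed_response implementations.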
pydantic_ai/models/mistral.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import os
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -12,7 +12,7 @@ import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior
+from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
     ArgsJson,
@@ -20,6 +20,7 @@ from ..messages import (
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -32,10 +33,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
 )
 
@@ -164,7 +163,7 @@ class MistralAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
        response = await self._stream_completions_create(messages, model_settings)
         async with response:
@@ -282,11 +281,11 @@ class MistralAgentModel(AgentModel):
 
         parts: list[ModelResponsePart] = []
         if text := _map_content(content):
-            parts.append(TextPart(text))
+            parts.append(TextPart(content=text))
 
         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
-                tool = _map_mistral_to_pydantic_tool_call(tool_call)
+                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)
 
         return ModelResponse(parts, timestamp=timestamp)
@@ -295,45 +294,19 @@ class MistralAgentModel(AgentModel):
     async def _process_streamed_response(
         result_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
-    ) -> EitherStreamedResponse:
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        start_usage = Usage()
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-
-        while True:
-            try:
-                event = await response.__anext__()
-                chunk = event.data
-            except StopAsyncIteration as e:
-                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-
-            start_usage += _map_usage(chunk)
-
-            if chunk.created:
-                timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            else:
-                timestamp = _now_utc()
-
-            if chunk.choices:
-                delta = chunk.choices[0].delta
-                content = _map_content(delta.content)
-
-                tool_calls: list[MistralToolCall] | None = None
-                if delta.tool_calls:
-                    tool_calls = delta.tool_calls
-
-                if tool_calls or content and result_tools:
-                    return MistralStreamStructuredResponse(
-                        {c.id if c.id else 'null': c for c in tool_calls or []},
-                        {c.name: c for c in result_tools},
-                        response,
-                        content,
-                        timestamp,
-                        start_usage,
-                    )
+        if first_chunk.data.created:
+            timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
+        else:
+            timestamp = datetime.now(tz=timezone.utc)
 
-
-            return MistralStreamTextResponse(content, response, timestamp, start_usage)
+        return MistralStreamedResponse(peekable_response, timestamp, {c.name: c for c in result_tools})
 
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -467,108 +440,73 @@ class MistralAgentModel(AgentModel):
         assert_never(message)
 
 
-@dataclass
-class MistralStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for Mistral models."""
-
-    _first: str | None
-    _response: MistralEventStreamAsync[MistralCompletionEvent]
-    _timestamp: datetime
-    _usage: Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None and len(self._first) > 0:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk.data)
-
-        try:
-            choice = chunk.data.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        content = choice.delta.content
-        if choice.finish_reason is None:
-            assert content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-
-        if text := _map_content(content):
-            self._buffer.append(text)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
+MistralToolCallId = Union[str, None]
 
 
 @dataclass
-class MistralStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for Mistral models."""
+class MistralStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Mistral models."""
 
-    _function_tools: dict[str, MistralToolCall]
-    _result_tools: dict[str, ToolDefinition]
-    _response: MistralEventStreamAsync[MistralCompletionEvent]
-    _delta_content: str | None
+    _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-    _usage: Usage
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk.data)
-
-        try:
-            choice = chunk.data.choices[0]
-
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        content = choice.delta.content
-        if self._result_tools:
-            if text := _map_content(content):
-                self._delta_content = (self._delta_content or '') + text
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        calls: list[ModelResponsePart] = []
-        if self._function_tools and self._result_tools or self._function_tools:
-            for tool_call in self._function_tools.values():
-                tool = _map_mistral_to_pydantic_tool_call(tool_call)
-                calls.append(tool)
-
-        elif self._delta_content and self._result_tools:
-            output_json: dict[str, Any] | None = pydantic_core.from_json(
-                self._delta_content, allow_partial='trailing-strings'
-            )
+    _result_tools: dict[str, ToolDefinition]
 
-            if output_json:
-                for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
-                    # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    # Example with BaseModel and required fields.
-                    if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
-                        continue
+    _delta_content: str = field(default='', init=False)
 
-
-
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        chunk: MistralCompletionEvent
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk.data)
 
-
+            try:
+                choice = chunk.data.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            text = _map_content(content)
+            if text:
+                # Attempt to produce a result tool call from the received text
+                if self._result_tools:
+                    self._delta_content += text
+                    maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
+                    if maybe_tool_call_part:
+                        yield self._parts_manager.handle_tool_call_part(
+                            vendor_part_id='result',
+                            tool_name=maybe_tool_call_part.tool_name,
+                            args=maybe_tool_call_part.args_as_dict(),
+                            tool_call_id=maybe_tool_call_part.tool_call_id,
+                        )
+                else:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
 
-
-
+            # Handle the explicit tool calls
+            for index, dtc in enumerate(choice.delta.tool_calls or []):
+                # It seems that mistral just sends full tool calls, so we just use them directly, rather than building
+                yield self._parts_manager.handle_tool_call_part(
+                    vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
+                )
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
+    @staticmethod
+    def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
+        output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
+        if output_json:
+            for result_tool in result_tools.values():
+                # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
+                # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
+                # Example with BaseModel and required fields.
+                if not MistralStreamedResponse._validate_required_json_schema(
+                    output_json, result_tool.parameters_json_schema
+                ):
+                    continue
+
+                # The following part_id will be thrown away
+                return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
+
     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
@@ -587,20 +525,20 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
                 if not isinstance(json_dict[param], list):
                     return False
                 for item in json_dict[param]:
-                    if not isinstance(item,
+                    if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
                         return False
-            elif param_type and not isinstance(json_dict[param],
+            elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
                 return False
 
             if isinstance(json_dict[param], dict) and 'properties' in param_schema:
                 nested_schema = param_schema
-                if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
+                if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema):
                    return False
 
         return True
 
 
-
+VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
     'string': str,
     'integer': int,
     'number': float,
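A note on the partial-JSON trick above: MistralStreamedResponse._try_get_result_tool_from_text leans on pydantic_core.from_json(..., allow_partial='trailing-strings'), which parses JSON that is still arriving and keeps a truncated trailing string instead of failing, so each accumulated delta can be checked against the result tools' required schema before a tool call part is emitted. A short hedged demo of that parsing behaviour (the payload is invented for illustration):

import pydantic_core

# A result-tool payload cut off mid-stream, the way deltas accumulate in _delta_content:
partial = '{"city": "London", "country": "United Ki'

# allow_partial='trailing-strings' tolerates the unterminated document and keeps the
# truncated trailing string, where a strict parse would raise a ValueError.
print(pydantic_core.from_json(partial, allow_partial='trailing-strings'))
# -> {'city': 'London', 'country': 'United Ki'}

Once the parsed dict satisfies a result tool's required JSON schema, the diff's _validate_required_json_schema check passes and the accumulated text is promoted to a ToolCallPart.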