pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -1,6 +1,6 @@
  from __future__ import annotations as _annotations

- from collections.abc import AsyncIterator, Iterable
+ from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, overload
  from httpx import AsyncClient as AsyncHTTPClient
  from typing_extensions import assert_never

- from .. import UnexpectedModelBehavior, _utils, result
+ from .. import UnexpectedModelBehavior, _utils, usage
  from .._utils import guard_tool_call_id as _guard_tool_call_id
  from ..messages import (
  ModelMessage,
  ModelRequest,
  ModelResponse,
  ModelResponsePart,
+ ModelResponseStreamEvent,
  RetryPromptPart,
  SystemPromptPart,
  TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
  ToolReturnPart,
  UserPromptPart,
  )
- from ..result import Usage
  from ..settings import ModelSettings
  from ..tools import ToolDefinition
  from . import (
  AgentModel,
- EitherStreamedResponse,
  Model,
- StreamStructuredResponse,
- StreamTextResponse,
+ StreamedResponse,
  cached_async_http_client,
  check_allow_model_requests,
  )
@@ -41,7 +39,6 @@ try:
  from groq import NOT_GIVEN, AsyncGroq, AsyncStream
  from groq.types import chat
  from groq.types.chat import ChatCompletion, ChatCompletionChunk
- from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
  except ImportError as _import_error:
  raise ImportError(
  'Please install `groq` to use the Groq model, '
@@ -157,14 +154,14 @@ class GroqAgentModel(AgentModel):

  async def request(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
- ) -> tuple[ModelResponse, result.Usage]:
+ ) -> tuple[ModelResponse, usage.Usage]:
  response = await self._completions_create(messages, False, model_settings)
  return self._process_response(response), _map_usage(response)

  @asynccontextmanager
  async def request_stream(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
- ) -> AsyncIterator[EitherStreamedResponse]:
+ ) -> AsyncIterator[StreamedResponse]:
  response = await self._completions_create(messages, True, model_settings)
  async with response:
  yield await self._process_streamed_response(response)
@@ -200,7 +197,7 @@ class GroqAgentModel(AgentModel):
  model=str(self.model_name),
  messages=groq_messages,
  n=1,
- parallel_tool_calls=True if self.tools else NOT_GIVEN,
+ parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
  tools=self.tools or NOT_GIVEN,
  tool_choice=tool_choice or NOT_GIVEN,
  stream=stream,
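
The change above means Groq requests now honour a `parallel_tool_calls` entry in the per-request model settings, falling back to the previous behaviour (enabled whenever tools are registered) when the setting is absent. A rough sketch of how that might be exercised from the agent API; the model name and the exact `ModelSettings` fields accepted by this release are assumptions:

```python
# Hedged sketch: turning off parallel tool calls for a Groq-backed agent run.
# Assumes the 'parallel_tool_calls' key in model_settings reaches the
# model_settings.get(...) call shown in the hunk above; the model name is
# illustrative only.
from pydantic_ai import Agent

agent = Agent('groq:llama-3.3-70b-versatile')

result = agent.run_sync(
    'What is the capital of France?',
    model_settings={'parallel_tool_calls': False},  # omit to keep the old default
)
print(result.data)
```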
@@ -210,45 +207,32 @@ class GroqAgentModel(AgentModel):
  timeout=model_settings.get('timeout', NOT_GIVEN),
  )

- @staticmethod
- def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
  """Process a non-streamed response, and prepare a message to return."""
  timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
  choice = response.choices[0]
  items: list[ModelResponsePart] = []
  if choice.message.content is not None:
- items.append(TextPart(choice.message.content))
+ items.append(TextPart(content=choice.message.content))
  if choice.message.tool_calls is not None:
  for c in choice.message.tool_calls:
- items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
- return ModelResponse(items, timestamp=timestamp)
+ items.append(
+ ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
+ )
+ return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

- @staticmethod
- async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
  """Process a streamed response, and prepare a streaming response to return."""
- timestamp: datetime | None = None
- start_usage = Usage()
- # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
- while True:
- try:
- chunk = await response.__anext__()
- except StopAsyncIteration as e:
- raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
- timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
- start_usage += _map_usage(chunk)
-
- if chunk.choices:
- delta = chunk.choices[0].delta
-
- if delta.content is not None:
- return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
- elif delta.tool_calls is not None:
- return GroqStreamStructuredResponse(
- response,
- {c.index: c for c in delta.tool_calls},
- timestamp,
- start_usage,
- )
+ peekable_response = _utils.PeekableAsyncStream(response)
+ first_chunk = await peekable_response.peek()
+ if isinstance(first_chunk, _utils.Unset):
+ raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+ return GroqStreamedResponse(
+ _response=peekable_response,
+ _model_name=self.model_name,
+ _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+ )

  @classmethod
  def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
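
Instead of manually looping over the stream to find a first chunk with content or tool calls, `_process_streamed_response` now wraps the stream in `_utils.PeekableAsyncStream` and only peeks at the first chunk to read its timestamp, leaving consumption to `GroqStreamedResponse`. The real helper lives in `pydantic_ai._utils`; a minimal sketch of the peek-without-consuming idea it relies on might look like this:

```python
# Minimal sketch of a peekable async stream, assuming roughly the shape used
# above; pydantic_ai._utils.PeekableAsyncStream is the real implementation and
# may differ in detail.
from __future__ import annotations

from collections.abc import AsyncIterable, AsyncIterator
from typing import TypeVar, Union

T = TypeVar('T')


class Unset:
    """Sentinel returned by peek() when the stream is already exhausted."""


class PeekableAsyncStream(AsyncIterable[T]):
    def __init__(self, source: AsyncIterable[T]) -> None:
        self._it = source.__aiter__()
        self._peeked: Union[T, Unset] = Unset()

    async def peek(self) -> Union[T, Unset]:
        # Pull the first item once and cache it so later iteration still sees it.
        if isinstance(self._peeked, Unset):
            try:
                self._peeked = await self._it.__anext__()
            except StopAsyncIteration:
                return Unset()
        return self._peeked

    async def __aiter__(self) -> AsyncIterator[T]:
        # Yield the cached first item (if any), then the rest of the stream.
        if not isinstance(self._peeked, Unset):
            yield self._peeked
            self._peeked = Unset()
        async for item in self._it:
            yield item
```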
@@ -301,90 +285,36 @@ class GroqAgentModel(AgentModel):


  @dataclass
- class GroqStreamTextResponse(StreamTextResponse):
- """Implementation of `StreamTextResponse` for Groq models."""
+ class GroqStreamedResponse(StreamedResponse):
+ """Implementation of `StreamedResponse` for Groq models."""

- _first: str | None
- _response: AsyncStream[ChatCompletionChunk]
+ _response: AsyncIterable[ChatCompletionChunk]
  _timestamp: datetime
- _usage: result.Usage
- _buffer: list[str] = field(default_factory=list, init=False)
-
- async def __anext__(self) -> None:
- if self._first is not None:
- self._buffer.append(self._first)
- self._first = None
- return None
-
- chunk = await self._response.__anext__()
- self._usage = _map_usage(chunk)
-
- try:
- choice = chunk.choices[0]
- except IndexError:
- raise StopAsyncIteration()
-
- # we don't raise StopAsyncIteration on the last chunk because usage comes after this
- if choice.finish_reason is None:
- assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
- if choice.delta.content is not None:
- self._buffer.append(choice.delta.content)
-
- def get(self, *, final: bool = False) -> Iterable[str]:
- yield from self._buffer
- self._buffer.clear()
-
- def usage(self) -> Usage:
- return self._usage
-
- def timestamp(self) -> datetime:
- return self._timestamp

+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+ async for chunk in self._response:
+ self._usage += _map_usage(chunk)

- @dataclass
- class GroqStreamStructuredResponse(StreamStructuredResponse):
- """Implementation of `StreamStructuredResponse` for Groq models."""
-
- _response: AsyncStream[ChatCompletionChunk]
- _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
- _timestamp: datetime
- _usage: result.Usage
-
- async def __anext__(self) -> None:
- chunk = await self._response.__anext__()
- self._usage = _map_usage(chunk)
-
- try:
- choice = chunk.choices[0]
- except IndexError:
- raise StopAsyncIteration()
-
- if choice.finish_reason is not None:
- raise StopAsyncIteration()
-
- assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
- for new in choice.delta.tool_calls or []:
- if current := self._delta_tool_calls.get(new.index):
- if current.function is None:
- current.function = new.function
- elif new.function is not None:
- current.function.name = _utils.add_optional(current.function.name, new.function.name)
- current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
- else:
- self._delta_tool_calls[new.index] = new
-
- def get(self, *, final: bool = False) -> ModelResponse:
- items: list[ModelResponsePart] = []
- for c in self._delta_tool_calls.values():
- if f := c.function:
- if f.name is not None and f.arguments is not None:
- items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
-
- return ModelResponse(items, timestamp=self._timestamp)
-
- def usage(self) -> Usage:
- return self._usage
+ try:
+ choice = chunk.choices[0]
+ except IndexError:
+ continue
+
+ # Handle the text part of the response
+ content = choice.delta.content
+ if content is not None:
+ yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+ # Handle the tool calls
+ for dtc in choice.delta.tool_calls or []:
+ maybe_event = self._parts_manager.handle_tool_call_delta(
+ vendor_part_id=dtc.index,
+ tool_name=dtc.function and dtc.function.name,
+ args=dtc.function and dtc.function.arguments,
+ tool_call_id=dtc.id,
+ )
+ if maybe_event is not None:
+ yield maybe_event

  def timestamp(self) -> datetime:
  return self._timestamp
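
With this hunk, the separate text and structured stream classes collapse into a single `GroqStreamedResponse` that yields `ModelResponseStreamEvent`s, delegating the bookkeeping of partial text and tool-call deltas to a parts manager keyed by `vendor_part_id`. The snippet below is a simplified, standalone illustration of that delta-merging idea; it is not pydantic-ai's actual parts manager, whose event types and methods live in the library itself:

```python
# Simplified stand-in for the delta merging the parts manager performs above:
# deltas arrive keyed by a vendor part id and are folded into complete parts.
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToolCallAccumulator:
    tool_name: str = ''
    args_json: str = ''
    tool_call_id: str | None = None


@dataclass
class SimplePartsManager:
    text: dict[object, str] = field(default_factory=dict)
    tool_calls: dict[object, ToolCallAccumulator] = field(default_factory=dict)

    def handle_text_delta(self, vendor_part_id: object, content: str) -> None:
        # Append each text delta to the part identified by vendor_part_id.
        self.text[vendor_part_id] = self.text.get(vendor_part_id, '') + content

    def handle_tool_call_delta(
        self, vendor_part_id: object, tool_name: str | None, args: str | None, tool_call_id: str | None
    ) -> None:
        # Fold partial tool-call fragments into one accumulator per part id.
        acc = self.tool_calls.setdefault(vendor_part_id, ToolCallAccumulator())
        if tool_name:
            acc.tool_name += tool_name
        if args:
            acc.args_json += args
        if tool_call_id:
            acc.tool_call_id = tool_call_id


manager = SimplePartsManager()
manager.handle_text_delta('content', 'The weather in ')
manager.handle_text_delta('content', 'Paris is sunny.')
manager.handle_tool_call_delta(0, 'get_weather', '{"city": ', 'call_1')
manager.handle_tool_call_delta(0, None, '"Paris"}', None)
print(manager.text['content'])
print(manager.tool_calls[0])
```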
@@ -398,18 +328,18 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
  )


- def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
- usage = None
+ def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+ response_usage = None
  if isinstance(completion, ChatCompletion):
- usage = completion.usage
+ response_usage = completion.usage
  elif completion.x_groq is not None:
- usage = completion.x_groq.usage
+ response_usage = completion.x_groq.usage

- if usage is None:
- return result.Usage()
+ if response_usage is None:
+ return usage.Usage()

- return result.Usage(
- request_tokens=usage.prompt_tokens,
- response_tokens=usage.completion_tokens,
- total_tokens=usage.total_tokens,
+ return usage.Usage(
+ request_tokens=response_usage.prompt_tokens,
+ response_tokens=response_usage.completion_tokens,
+ total_tokens=response_usage.total_tokens,
  )
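
`_map_usage` now builds `usage.Usage` (the `pydantic_ai.usage` module replaces the old `result.Usage` import), and the streamed path accumulates it chunk by chunk with `+=`, as seen in `_get_event_iterator` above. A tiny hedged sketch of that accumulation, assuming `Usage` exposes the token fields used in this hunk and supports addition:

```python
# Hedged sketch: summing per-chunk usage the way the streamed Groq path does.
# Field names follow the hunk above; the exact addition semantics of
# pydantic_ai.usage.Usage (e.g. a request counter) are assumptions.
from pydantic_ai.usage import Usage

total = Usage()
for chunk_usage in (
    Usage(request_tokens=12, response_tokens=3, total_tokens=15),
    Usage(request_tokens=0, response_tokens=7, total_tokens=7),
):
    total += chunk_usage

print(total)  # expected to report 10 response tokens and 22 total tokens
```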
@@ -1,7 +1,7 @@
  from __future__ import annotations as _annotations

  import os
- from collections.abc import AsyncIterator, Iterable
+ from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
@@ -12,7 +12,7 @@ import pydantic_core
  from httpx import AsyncClient as AsyncHTTPClient, Timeout
  from typing_extensions import assert_never

- from .. import UnexpectedModelBehavior
+ from .. import UnexpectedModelBehavior, _utils
  from .._utils import now_utc as _now_utc
  from ..messages import (
  ArgsJson,
@@ -20,6 +20,7 @@ from ..messages import (
  ModelRequest,
  ModelResponse,
  ModelResponsePart,
+ ModelResponseStreamEvent,
  RetryPromptPart,
  SystemPromptPart,
  TextPart,
@@ -32,11 +33,10 @@ from ..settings import ModelSettings
  from ..tools import ToolDefinition
  from . import (
  AgentModel,
- EitherStreamedResponse,
  Model,
- StreamStructuredResponse,
- StreamTextResponse,
+ StreamedResponse,
  cached_async_http_client,
+ check_allow_model_requests,
  )

  try:
@@ -131,6 +131,7 @@ class MistralModel(Model):
  result_tools: list[ToolDefinition],
  ) -> AgentModel:
  """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
+ check_allow_model_requests()
  return MistralAgentModel(
  self.client,
  self.model_name,
@@ -148,7 +149,7 @@ class MistralAgentModel(AgentModel):
  """Implementation of `AgentModel` for Mistral models."""

  client: Mistral
- model_name: str
+ model_name: MistralModelName
  allow_text_result: bool
  function_tools: list[ToolDefinition]
  result_tools: list[ToolDefinition]
@@ -164,7 +165,7 @@ class MistralAgentModel(AgentModel):
  @asynccontextmanager
  async def request_stream(
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
- ) -> AsyncIterator[EitherStreamedResponse]:
+ ) -> AsyncIterator[StreamedResponse]:
  """Make a streaming request to the model from Pydantic AI call."""
  response = await self._stream_completions_create(messages, model_settings)
  async with response:
@@ -266,8 +267,7 @@ class MistralAgentModel(AgentModel):
  ]
  return tools if tools else None

- @staticmethod
- def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
+ def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
  """Process a non-streamed response, and prepare a message to return."""
  assert response.choices, 'Unexpected empty response choice.'

@@ -282,58 +282,37 @@ class MistralAgentModel(AgentModel):

  parts: list[ModelResponsePart] = []
  if text := _map_content(content):
- parts.append(TextPart(text))
+ parts.append(TextPart(content=text))

  if isinstance(tool_calls, list):
  for tool_call in tool_calls:
- tool = _map_mistral_to_pydantic_tool_call(tool_call)
+ tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
  parts.append(tool)

- return ModelResponse(parts, timestamp=timestamp)
+ return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)

- @staticmethod
  async def _process_streamed_response(
+ self,
  result_tools: list[ToolDefinition],
  response: MistralEventStreamAsync[MistralCompletionEvent],
- ) -> EitherStreamedResponse:
+ ) -> StreamedResponse:
  """Process a streamed response, and prepare a streaming response to return."""
- start_usage = Usage()
-
- # Iterate until we get either `tool_calls` or `content` from the first chunk.
- while True:
- try:
- event = await response.__anext__()
- chunk = event.data
- except StopAsyncIteration as e:
- raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
+ peekable_response = _utils.PeekableAsyncStream(response)
+ first_chunk = await peekable_response.peek()
+ if isinstance(first_chunk, _utils.Unset):
+ raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

- start_usage += _map_usage(chunk)
-
- if chunk.created:
- timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
- else:
- timestamp = _now_utc()
-
- if chunk.choices:
- delta = chunk.choices[0].delta
- content = _map_content(delta.content)
-
- tool_calls: list[MistralToolCall] | None = None
- if delta.tool_calls:
- tool_calls = delta.tool_calls
-
- if tool_calls or content and result_tools:
- return MistralStreamStructuredResponse(
- {c.id if c.id else 'null': c for c in tool_calls or []},
- {c.name: c for c in result_tools},
- response,
- content,
- timestamp,
- start_usage,
- )
+ if first_chunk.data.created:
+ timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
+ else:
+ timestamp = datetime.now(tz=timezone.utc)

- elif content:
- return MistralStreamTextResponse(content, response, timestamp, start_usage)
+ return MistralStreamedResponse(
+ _response=peekable_response,
+ _model_name=self.model_name,
+ _timestamp=timestamp,
+ _result_tools={c.name: c for c in result_tools},
+ )

  @staticmethod
  def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -467,108 +446,73 @@ class MistralAgentModel(AgentModel):
  assert_never(message)


- @dataclass
- class MistralStreamTextResponse(StreamTextResponse):
- """Implementation of `StreamTextResponse` for Mistral models."""
-
- _first: str | None
- _response: MistralEventStreamAsync[MistralCompletionEvent]
- _timestamp: datetime
- _usage: Usage
- _buffer: list[str] = field(default_factory=list, init=False)
-
- async def __anext__(self) -> None:
- if self._first is not None and len(self._first) > 0:
- self._buffer.append(self._first)
- self._first = None
- return None
-
- chunk = await self._response.__anext__()
- self._usage += _map_usage(chunk.data)
-
- try:
- choice = chunk.data.choices[0]
- except IndexError:
- raise StopAsyncIteration()
-
- content = choice.delta.content
- if choice.finish_reason is None:
- assert content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-
- if text := _map_content(content):
- self._buffer.append(text)
-
- def get(self, *, final: bool = False) -> Iterable[str]:
- yield from self._buffer
- self._buffer.clear()
-
- def usage(self) -> Usage:
- return self._usage
-
- def timestamp(self) -> datetime:
- return self._timestamp
+ MistralToolCallId = Union[str, None]


  @dataclass
- class MistralStreamStructuredResponse(StreamStructuredResponse):
- """Implementation of `StreamStructuredResponse` for Mistral models."""
+ class MistralStreamedResponse(StreamedResponse):
+ """Implementation of `StreamedResponse` for Mistral models."""

- _function_tools: dict[str, MistralToolCall]
- _result_tools: dict[str, ToolDefinition]
- _response: MistralEventStreamAsync[MistralCompletionEvent]
- _delta_content: str | None
+ _response: AsyncIterable[MistralCompletionEvent]
  _timestamp: datetime
- _usage: Usage
-
- async def __anext__(self) -> None:
- chunk = await self._response.__anext__()
- self._usage += _map_usage(chunk.data)
-
- try:
- choice = chunk.data.choices[0]
-
- except IndexError:
- raise StopAsyncIteration()
-
- if choice.finish_reason is not None:
- raise StopAsyncIteration()
-
- content = choice.delta.content
- if self._result_tools:
- if text := _map_content(content):
- self._delta_content = (self._delta_content or '') + text
-
- def get(self, *, final: bool = False) -> ModelResponse:
- calls: list[ModelResponsePart] = []
- if self._function_tools and self._result_tools or self._function_tools:
- for tool_call in self._function_tools.values():
- tool = _map_mistral_to_pydantic_tool_call(tool_call)
- calls.append(tool)
-
- elif self._delta_content and self._result_tools:
- output_json: dict[str, Any] | None = pydantic_core.from_json(
- self._delta_content, allow_partial='trailing-strings'
- )
+ _result_tools: dict[str, ToolDefinition]

- if output_json:
- for result_tool in self._result_tools.values():
- # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
- # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
- # Example with BaseModel and required fields.
- if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
- continue
+ _delta_content: str = field(default='', init=False)

- tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
- calls.append(tool)
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+ chunk: MistralCompletionEvent
+ async for chunk in self._response:
+ self._usage += _map_usage(chunk.data)

- return ModelResponse(calls, timestamp=self._timestamp)
+ try:
+ choice = chunk.data.choices[0]
+ except IndexError:
+ continue
+
+ # Handle the text part of the response
+ content = choice.delta.content
+ text = _map_content(content)
+ if text:
+ # Attempt to produce a result tool call from the received text
+ if self._result_tools:
+ self._delta_content += text
+ maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
+ if maybe_tool_call_part:
+ yield self._parts_manager.handle_tool_call_part(
+ vendor_part_id='result',
+ tool_name=maybe_tool_call_part.tool_name,
+ args=maybe_tool_call_part.args_as_dict(),
+ tool_call_id=maybe_tool_call_part.tool_call_id,
+ )
+ else:
+ yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)

- def usage(self) -> Usage:
- return self._usage
+ # Handle the explicit tool calls
+ for index, dtc in enumerate(choice.delta.tool_calls or []):
+ # It seems that mistral just sends full tool calls, so we just use them directly, rather than building
+ yield self._parts_manager.handle_tool_call_part(
+ vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
+ )

  def timestamp(self) -> datetime:
  return self._timestamp

+ @staticmethod
+ def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
+ output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
+ if output_json:
+ for result_tool in result_tools.values():
+ # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
+ # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
+ # Example with BaseModel and required fields.
+ if not MistralStreamedResponse._validate_required_json_schema(
+ output_json, result_tool.parameters_json_schema
+ ):
+ continue
+
+ # The following part_id will be thrown away
+ return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
+
  @staticmethod
  def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
  """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
@@ -587,20 +531,20 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
  if not isinstance(json_dict[param], list):
  return False
  for item in json_dict[param]:
- if not isinstance(item, VALIDE_JSON_TYPE_MAPPING[param_items_type]):
+ if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
  return False
- elif param_type and not isinstance(json_dict[param], VALIDE_JSON_TYPE_MAPPING[param_type]):
+ elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
  return False

  if isinstance(json_dict[param], dict) and 'properties' in param_schema:
  nested_schema = param_schema
- if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
+ if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema):
  return False

  return True


- VALIDE_JSON_TYPE_MAPPING: dict[str, Any] = {
+ VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
  'string': str,
  'integer': int,
  'number': float,
@@ -10,6 +10,7 @@ from . import (
  AgentModel,
  Model,
  cached_async_http_client,
+ check_allow_model_requests,
  )

  try:
@@ -25,6 +26,7 @@ from .openai import OpenAIModel

  CommonOllamaModelNames = Literal[
  'codellama',
+ 'deepseek-r1',
  'gemma',
  'gemma2',
  'llama3',
@@ -36,6 +38,7 @@ CommonOllamaModelNames = Literal[
  'mistral-nemo',
  'mixtral',
  'phi3',
+ 'phi4',
  'qwq',
  'qwen',
  'qwen2',
@@ -109,6 +112,7 @@ class OllamaModel(Model):
  allow_text_result: bool,
  result_tools: list[ToolDefinition],
  ) -> AgentModel:
+ check_allow_model_requests()
  return await self.openai_model.agent_model(
  function_tools=function_tools,
  allow_text_result=allow_text_result,
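
Both `MistralModel.agent_model` and `OllamaModel.agent_model` now call `check_allow_model_requests()` before doing any work, and the Ollama name literal gains 'deepseek-r1' and 'phi4'. A hedged sketch of how that check typically surfaces in a test setting; the `ALLOW_MODEL_REQUESTS` flag, constructor arguments, and raised exception type are assumptions based on the imports shown in this diff:

```python
# Hedged sketch: with model requests globally disallowed (common in test
# suites), building an agent model for an Ollama model should now fail fast.
from pydantic_ai import Agent, models
from pydantic_ai.models.ollama import OllamaModel

models.ALLOW_MODEL_REQUESTS = False  # assumed flag checked by check_allow_model_requests()

agent = Agent(OllamaModel(model_name='phi4'))  # 'phi4' is one of the newly listed names
try:
    agent.run_sync('hello')
except RuntimeError as exc:
    print(exc)  # the request is refused before any HTTP call is made
```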