pydantic-ai-slim 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_output.py +19 -7
- pydantic_ai/_parts_manager.py +8 -10
- pydantic_ai/_tool_manager.py +18 -1
- pydantic_ai/ag_ui.py +32 -17
- pydantic_ai/agent/abstract.py +8 -0
- pydantic_ai/durable_exec/dbos/_agent.py +5 -2
- pydantic_ai/durable_exec/temporal/_agent.py +1 -1
- pydantic_ai/messages.py +30 -6
- pydantic_ai/models/anthropic.py +55 -25
- pydantic_ai/models/bedrock.py +82 -31
- pydantic_ai/models/cohere.py +39 -13
- pydantic_ai/models/function.py +8 -1
- pydantic_ai/models/google.py +62 -33
- pydantic_ai/models/groq.py +35 -7
- pydantic_ai/models/huggingface.py +27 -5
- pydantic_ai/models/mistral.py +54 -20
- pydantic_ai/models/openai.py +88 -45
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/bedrock.py +9 -1
- pydantic_ai/settings.py +1 -0
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.3.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.3.dist-info}/RECORD +25 -25
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -23,6 +23,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -100,6 +101,14 @@ but allow any name in the type hints.
 See <https://console.groq.com/docs/models> for an up to date date list of models and more details.
 """

+_FINISH_REASON_MAP: dict[Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason] = {
+    'stop': 'stop',
+    'length': 'length',
+    'tool_calls': 'tool_call',
+    'content_filter': 'content_filter',
+    'function_call': 'tool_call',
+}
+

 class GroqModelSettings(ModelSettings, total=False):
     """Settings used for a Groq model request."""
@@ -186,7 +195,13 @@ class GroqModel(Model):
                     tool_name=error.error.failed_generation.name,
                     args=error.error.failed_generation.arguments,
                 )
-                return ModelResponse(
+                return ModelResponse(
+                    parts=[tool_call_part],
+                    model_name=e.model_name,
+                    timestamp=_utils.now_utc(),
+                    provider_name=self._provider.name,
+                    finish_reason='error',
+                )
             except ValidationError:
                 pass
             raise
@@ -298,16 +313,16 @@ class GroqModel(Model):
                 tool_call_id = generate_tool_call_id()
                 items.append(
                     BuiltinToolCallPart(
-                        tool_name=tool.type, args=tool.arguments, provider_name=
+                        tool_name=tool.type, args=tool.arguments, provider_name=self.system, tool_call_id=tool_call_id
                     )
                 )
                 items.append(
                     BuiltinToolReturnPart(
-                        provider_name=
+                        provider_name=self.system, tool_name=tool.type, content=tool.output, tool_call_id=tool_call_id
                     )
                 )
-        # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
         if choice.message.reasoning is not None:
+            # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
             items.append(ThinkingPart(content=choice.message.reasoning))
         if choice.message.content is not None:
             # NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
@@ -315,6 +330,10 @@ class GroqModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
+
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
@@ -322,6 +341,8 @@ class GroqModel(Model):
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     async def _process_streamed_response(
@@ -338,7 +359,7 @@ class GroqModel(Model):
         return GroqStreamedResponse(
             model_request_parameters=model_request_parameters,
             _response=peekable_response,
-            _model_name=
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _timestamp=number_to_datetime(first_chunk.created),
             _provider_name=self._provider.name,
@@ -376,8 +397,8 @@ class GroqModel(Model):
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
             elif isinstance(item, ThinkingPart):
-
-
+                start_tag, end_tag = self.profile.thinking_tags
+                texts.append('\n'.join([start_tag, item.content, end_tag]))
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from groq
                 pass
@@ -497,11 +518,18 @@ class GroqStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk)

+            if chunk.id:  # pragma: no branch
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue

+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
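For context, a minimal sketch of how the finish-reason plumbing above surfaces to callers of a Groq-backed agent; the model name and the message-inspection loop are illustrative assumptions, not taken from this diff:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('groq:llama-3.3-70b-versatile')  # hypothetical model choice
result = agent.run_sync('Say hello')

for message in result.all_messages():
    if isinstance(message, ModelResponse):
        print(message.finish_reason)     # normalized: 'stop', 'length', 'tool_call', 'content_filter' or 'error'
        print(message.provider_details)  # raw provider value preserved, e.g. {'finish_reason': 'stop'}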
pydantic_ai/models/huggingface.py
CHANGED
@@ -20,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -58,6 +59,7 @@ try:
         ChatCompletionOutput,
         ChatCompletionOutputMessage,
         ChatCompletionStreamOutput,
+        TextGenerationOutputFinishReason,
     )
     from huggingface_hub.errors import HfHubHTTPError

@@ -94,6 +96,12 @@ HuggingFaceModelName = str | LatestHuggingFaceModelNames
 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
 """

+_FINISH_REASON_MAP: dict[TextGenerationOutputFinishReason, FinishReason] = {
+    'length': 'length',
+    'eos_token': 'stop',
+    'stop_sequence': 'stop',
+}
+

 class HuggingFaceModelSettings(ModelSettings, total=False):
     """Settings used for a Hugging Face model request."""
@@ -266,6 +274,11 @@ class HuggingFaceModel(Model):
         if tool_calls is not None:
             for c in tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(cast(TextGenerationOutputFinishReason, raw_finish_reason), None)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
@@ -273,6 +286,8 @@ class HuggingFaceModel(Model):
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     async def _process_streamed_response(
@@ -288,7 +303,7 @@ class HuggingFaceModel(Model):

         return HuggingFaceStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
@@ -316,10 +331,8 @@ class HuggingFaceModel(Model):
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
             elif isinstance(item, ThinkingPart):
-
-
-                # texts.append(f'<think>\n{item.content}\n</think>')
-                pass
+                start_tag, end_tag = self.profile.thinking_tags
+                texts.append('\n'.join([start_tag, item.content, end_tag]))
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from huggingface
                 pass
@@ -445,11 +458,20 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk)

+            if chunk.id:  # pragma: no branch
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue

+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(
+                    cast(TextGenerationOutputFinishReason, raw_finish_reason), None
+                )
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
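A standalone sketch of the mapping pattern used in both the synchronous and streamed paths above; the values are invented for illustration. Unmapped provider reasons stay available verbatim in provider_details, while the portable finish_reason is simply left unset because dict.get returns None for missing keys:

_FINISH_REASON_MAP = {'length': 'length', 'eos_token': 'stop', 'stop_sequence': 'stop'}

raw_finish_reason = 'eos_token'                             # as reported by the provider
provider_details = {'finish_reason': raw_finish_reason}     # raw value, always kept
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)   # 'stop'; None for an unknown value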
pydantic_ai/models/mistral.py
CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._run_context import RunContext
-from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..exceptions import UserError
 from ..messages import (
@@ -21,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -61,12 +61,15 @@ try:
         ImageURLChunk as MistralImageURLChunk,
         Mistral,
         OptionalNullable as MistralOptionalNullable,
+        ReferenceChunk as MistralReferenceChunk,
         TextChunk as MistralTextChunk,
+        ThinkChunk as MistralThinkChunk,
         ToolChoiceEnum as MistralToolChoiceEnum,
     )
     from mistralai.models import (
         ChatCompletionResponse as MistralChatCompletionResponse,
         CompletionEvent as MistralCompletionEvent,
+        FinishReason as MistralFinishReason,
         Messages as MistralMessages,
         SDKError,
         Tool as MistralTool,
@@ -98,6 +101,14 @@ allow any name in the type hints.
 Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_overview/) for a full list.
 """

+_FINISH_REASON_MAP: dict[MistralFinishReason, FinishReason] = {
+    'stop': 'stop',
+    'length': 'length',
+    'model_length': 'length',
+    'error': 'error',
+    'tool_calls': 'tool_call',
+}
+

 class MistralModelSettings(ModelSettings, total=False):
     """Settings used for a Mistral model request."""
@@ -339,14 +350,21 @@ class MistralModel(Model):
         tool_calls = choice.message.tool_calls

         parts: list[ModelResponsePart] = []
-
-
+        text, thinking = _map_content(content)
+        for thought in thinking:
+            parts.append(ThinkingPart(content=thought))
+        if text:
+            parts.append(TextPart(content=text))

         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
                 tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)

+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=parts,
             usage=_map_usage(response),
@@ -354,6 +372,8 @@ class MistralModel(Model):
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     async def _process_streamed_response(
@@ -377,7 +397,7 @@ class MistralModel(Model):
         return MistralStreamedResponse(
             model_request_parameters=model_request_parameters,
             _response=peekable_response,
-            _model_name=
+            _model_name=first_chunk.data.model,
             _timestamp=timestamp,
             _provider_name=self._provider.name,
         )
@@ -503,16 +523,14 @@ class MistralModel(Model):
                 mistral_messages.extend(self._map_user_message(message))
             elif isinstance(message, ModelResponse):
                 content_chunks: list[MistralContentChunk] = []
+                thinking_chunks: list[MistralTextChunk | MistralReferenceChunk] = []
                 tool_calls: list[MistralToolCall] = []

                 for part in message.parts:
                     if isinstance(part, TextPart):
                         content_chunks.append(MistralTextChunk(text=part.content))
                     elif isinstance(part, ThinkingPart):
-
-                        # please open an issue. The below code is the code to send thinking to the provider.
-                        # content_chunks.append(MistralTextChunk(text=f'<think>{part.content}</think>'))
-                        pass
+                        thinking_chunks.append(MistralTextChunk(text=part.content))
                     elif isinstance(part, ToolCallPart):
                         tool_calls.append(self._map_tool_call(part))
                     elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
@@ -520,6 +538,8 @@ class MistralModel(Model):
                         pass
                     else:
                         assert_never(part)
+                if thinking_chunks:
+                    content_chunks.insert(0, MistralThinkChunk(thinking=thinking_chunks))
                 mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
             else:
                 assert_never(message)
@@ -595,14 +615,23 @@ class MistralStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk.data)

+            if chunk.data.id:  # pragma: no branch
+                self.provider_response_id = chunk.data.id
+
             try:
                 choice = chunk.data.choices[0]
             except IndexError:
                 continue

+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
-            text = _map_content(content)
+            text, thinking = _map_content(content)
+            for thought in thinking:
+                self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought)
             if text:
                 # Attempt to produce an output tool call from the received text
                 output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
@@ -715,32 +744,37 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
         return RequestUsage(
-            input_tokens=response.usage.prompt_tokens,
-            output_tokens=response.usage.completion_tokens,
+            input_tokens=response.usage.prompt_tokens or 0,
+            output_tokens=response.usage.completion_tokens or 0,
         )
     else:
-        return RequestUsage()
+        return RequestUsage()


-def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
+def _map_content(content: MistralOptionalNullable[MistralContent]) -> tuple[str | None, list[str]]:
     """Maps the delta content from a Mistral Completion Chunk to a string or None."""
-
+    text: str | None = None
+    thinking: list[str] = []

     if isinstance(content, MistralUnset) or not content:
-
+        return None, []
     elif isinstance(content, list):
         for chunk in content:
             if isinstance(chunk, MistralTextChunk):
-
+                text = text or '' + chunk.text
+            elif isinstance(chunk, MistralThinkChunk):
+                for thought in chunk.thinking:
+                    if thought.type == 'text':  # pragma: no branch
+                        thinking.append(thought.text)
             else:
                 assert False, (  # pragma: no cover
                     f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
                 )
     elif isinstance(content, str):
-
+        text = content

     # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
-    if
-
+    if text and len(text) == 0:  # pragma: no cover
+        text = None

-    return
+    return text, thinking
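A small sketch of the response shape the new _map_content tuple feeds into: Mistral think chunks become ThinkingPart entries placed ahead of the TextPart. The content strings are invented for illustration:

from pydantic_ai.messages import ModelResponse, TextPart, ThinkingPart

response = ModelResponse(
    parts=[
        ThinkingPart(content='I should add 2 and 2.'),
        TextPart(content='The answer is 4.'),
    ],
)
print([type(p).__name__ for p in response.parts])  # ['ThinkingPart', 'TextPart']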
pydantic_ai/models/openai.py
CHANGED
@@ -4,7 +4,7 @@ import base64
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Any, Literal, cast, overload

@@ -31,6 +31,7 @@ from ..messages import (
     ModelResponse,
     ModelResponsePart,
     ModelResponseStreamEvent,
+    PartStartEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -73,6 +74,7 @@ try:
     )
     from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
     from openai.types.responses.response_input_param import FunctionCallOutput, Message
+    from openai.types.responses.response_reasoning_item_param import Summary
     from openai.types.responses.response_status import ResponseStatus
     from openai.types.shared import ReasoningEffort
     from openai.types.shared_params import Reasoning
@@ -491,9 +493,17 @@ class OpenAIChatModel(Model):

         choice = response.choices[0]
         items: list[ModelResponsePart] = []
-        # The `reasoning_content` is only present in DeepSeek models.
+        # The `reasoning_content` field is only present in DeepSeek models.
+        # https://api-docs.deepseek.com/guides/reasoning_model
         if reasoning_content := getattr(choice.message, 'reasoning_content', None):
-            items.append(ThinkingPart(content=reasoning_content))
+            items.append(ThinkingPart(id='reasoning_content', content=reasoning_content, provider_name=self.system))
+
+        # NOTE: We don't currently handle OpenRouter `reasoning_details`:
+        # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
+        # NOTE: We don't currently handle OpenRouter/gpt-oss `reasoning`:
+        # - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api
+        # - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens
+        # If you need this, please file an issue.

         vendor_details: dict[str, Any] = {}

@@ -513,7 +523,10 @@ class OpenAIChatModel(Model):
             ]

         if choice.message.content is not None:
-            items.extend(
+            items.extend(
+                (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
+                for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
+            )
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 if isinstance(c, ChatCompletionMessageFunctionToolCall):
@@ -527,10 +540,9 @@ class OpenAIChatModel(Model):
                     part.tool_call_id = _guard_tool_call_id(part)
                     items.append(part)

-
-
-
-        finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
+        raw_finish_reason = choice.finish_reason
+        vendor_details['finish_reason'] = raw_finish_reason
+        finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)

         return ModelResponse(
             parts=items,
@@ -556,7 +568,7 @@ class OpenAIChatModel(Model):

         return OpenAIStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.created),
@@ -569,6 +581,12 @@ class OpenAIChatModel(Model):
     def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None:
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):  # pragma: no branch
+                if not OpenAIModelProfile.from_profile(self.profile).openai_chat_supports_web_search:
+                    raise UserError(
+                        f'WebSearchTool is not supported with `OpenAIChatModel` and model {self.model_name!r}. '
+                        f'Please use `OpenAIResponsesModel` instead.'
+                    )
+
                 if tool.user_location:
                     return WebSearchOptions(
                         search_context_size=tool.search_context_size,
@@ -580,7 +598,7 @@ class OpenAIChatModel(Model):
                 return WebSearchOptions(search_context_size=tool.search_context_size)
             else:
                 raise UserError(
-                    f'`{tool.__class__.__name__}` is not supported by `
+                    f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.'
                 )

     async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
@@ -597,10 +615,11 @@ class OpenAIChatModel(Model):
                     if isinstance(item, TextPart):
                         texts.append(item.content)
                     elif isinstance(item, ThinkingPart):
-                        # NOTE:
-                        #
-                        #
-
+                        # NOTE: DeepSeek `reasoning_content` field should NOT be sent back per https://api-docs.deepseek.com/guides/reasoning_model,
+                        # but we currently just send it in `<think>` tags anyway as we don't want DeepSeek-specific checks here.
+                        # If you need this changed, please file an issue.
+                        start_tag, end_tag = self.profile.thinking_tags
+                        texts.append('\n'.join([start_tag, item.content, end_tag]))
                     elif isinstance(item, ToolCallPart):
                         tool_calls.append(self._map_tool_call(item))
                     # OpenAI doesn't return built-in tool calls
@@ -838,16 +857,27 @@ class OpenAIResponsesModel(Model):
         timestamp = number_to_datetime(response.created_at)
         items: list[ModelResponsePart] = []
         for item in response.output:
-            if item.
+            if isinstance(item, responses.ResponseReasoningItem):
+                signature = item.encrypted_content
                 for summary in item.summary:
-                    #
-                    #
-                    items.append(
-
+                    # We use the same id for all summaries so that we can merge them on the round trip.
+                    # We only need to store the signature once.
+                    items.append(
+                        ThinkingPart(
+                            content=summary.text,
+                            id=item.id,
+                            signature=signature,
+                            provider_name=self.system if signature else None,
+                        )
+                    )
+                    signature = None
+                # NOTE: We don't currently handle the raw CoT from gpt-oss `reasoning_text`: https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot
+                # If you need this, please file an issue.
+            elif isinstance(item, responses.ResponseOutputMessage):
                 for content in item.content:
-                    if content.
+                    if isinstance(content, responses.ResponseOutputText):  # pragma: no branch
                         items.append(TextPart(content.text))
-            elif item.
+            elif isinstance(item, responses.ResponseFunctionToolCall):
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))

         finish_reason: FinishReason | None = None
@@ -882,7 +912,7 @@ class OpenAIResponsesModel(Model):
         assert isinstance(first_chunk, responses.ResponseCreatedEvent)
         return OpenAIResponsesStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=
+            _model_name=first_chunk.response.model,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.response.created_at),
             _provider_name=self._provider.name,
@@ -974,6 +1004,7 @@ class OpenAIResponsesModel(Model):
             reasoning=reasoning,
             user=model_settings.get('openai_user', NOT_GIVEN),
             text=text or NOT_GIVEN,
+            include=['reasoning.encrypted_content'],
             extra_headers=extra_headers,
             extra_body=model_settings.get('extra_body'),
         )
@@ -1035,7 +1066,7 @@ class OpenAIResponsesModel(Model):
             ),
         }

-    async def _map_messages(
+    async def _map_messages(  # noqa: C901
         self, messages: list[ModelMessage]
     ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
@@ -1072,33 +1103,30 @@ class OpenAIResponsesModel(Model):
                     else:
                         assert_never(part)
             elif isinstance(message, ModelResponse):
-
+                reasoning_item: responses.ResponseReasoningItemParam | None = None
                 for item in message.parts:
                     if isinstance(item, TextPart):
                         openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content))
                     elif isinstance(item, ToolCallPart):
                         openai_messages.append(self._map_tool_call(item))
-                    # OpenAI doesn't return built-in tool calls
                     elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
+                        # We don't currently track built-in tool calls from OpenAI
                         pass
                     elif isinstance(item, ThinkingPart):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        # )
-                        # )
-                        pass
+                        if reasoning_item is not None and item.id == reasoning_item['id']:
+                            reasoning_item['summary'] = [
+                                *reasoning_item['summary'],
+                                Summary(text=item.content, type='summary_text'),
+                            ]
+                            continue
+
+                        reasoning_item = responses.ResponseReasoningItemParam(
+                            id=item.id or _utils.generate_tool_call_id(),
+                            summary=[Summary(text=item.content, type='summary_text')],
+                            encrypted_content=item.signature if item.provider_name == self.system else None,
+                            type='reasoning',
+                        )
+                        openai_messages.append(reasoning_item)
                     else:
                         assert_never(item)
             else:
@@ -1231,12 +1259,19 @@ class OpenAIStreamedResponse(StreamedResponse):
                         ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                     )
                     if maybe_event is not None:  # pragma: no branch
+                        if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
+                            maybe_event.part.id = 'content'
+                            maybe_event.part.provider_name = self.provider_name
                         yield maybe_event

-            #
+            # The `reasoning_content` field is only present in DeepSeek models.
+            # https://api-docs.deepseek.com/guides/reasoning_model
             if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
                 yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id='reasoning_content',
+                    vendor_part_id='reasoning_content',
+                    id='reasoning_content',
+                    content=reasoning_content,
+                    provider_name=self.provider_name,
                 )

             for dtc in choice.delta.tool_calls or []:
@@ -1340,7 +1375,15 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     )

                 elif isinstance(chunk, responses.ResponseOutputItemDoneEvent):
-
+                    if isinstance(chunk.item, responses.ResponseReasoningItem):
+                        # Add the signature to the part corresponding to the first summary item
+                        signature = chunk.item.encrypted_content
+                        yield self._parts_manager.handle_thinking_delta(
+                            vendor_part_id=f'{chunk.item.id}-0',
+                            id=chunk.item.id,
+                            signature=signature,
+                            provider_name=self.provider_name if signature else None,
+                        )
                     pass

                 elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):