pydantic-ai-slim 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +50 -31
- pydantic_ai/_output.py +19 -7
- pydantic_ai/_parts_manager.py +8 -10
- pydantic_ai/_tool_manager.py +21 -0
- pydantic_ai/ag_ui.py +32 -17
- pydantic_ai/agent/__init__.py +3 -0
- pydantic_ai/agent/abstract.py +8 -0
- pydantic_ai/durable_exec/dbos/__init__.py +6 -0
- pydantic_ai/durable_exec/dbos/_agent.py +721 -0
- pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
- pydantic_ai/durable_exec/dbos/_model.py +137 -0
- pydantic_ai/durable_exec/dbos/_utils.py +10 -0
- pydantic_ai/durable_exec/temporal/_agent.py +1 -1
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +42 -6
- pydantic_ai/models/__init__.py +8 -0
- pydantic_ai/models/anthropic.py +79 -25
- pydantic_ai/models/bedrock.py +82 -31
- pydantic_ai/models/cohere.py +39 -13
- pydantic_ai/models/function.py +8 -1
- pydantic_ai/models/google.py +105 -37
- pydantic_ai/models/groq.py +35 -7
- pydantic_ai/models/huggingface.py +27 -5
- pydantic_ai/models/instrumented.py +27 -14
- pydantic_ai/models/mistral.py +54 -20
- pydantic_ai/models/openai.py +151 -57
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/bedrock.py +20 -4
- pydantic_ai/settings.py +1 -0
- pydantic_ai/tools.py +11 -0
- pydantic_ai/toolsets/function.py +7 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/METADATA +8 -6
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/RECORD +36 -31
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -23,6 +23,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -100,6 +101,14 @@ but allow any name in the type hints.
 See <https://console.groq.com/docs/models> for an up to date date list of models and more details.
 """
 
+_FINISH_REASON_MAP: dict[Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason] = {
+    'stop': 'stop',
+    'length': 'length',
+    'tool_calls': 'tool_call',
+    'content_filter': 'content_filter',
+    'function_call': 'tool_call',
+}
+
 
 class GroqModelSettings(ModelSettings, total=False):
     """Settings used for a Groq model request."""
@@ -186,7 +195,13 @@ class GroqModel(Model):
                     tool_name=error.error.failed_generation.name,
                     args=error.error.failed_generation.arguments,
                 )
-                return ModelResponse(
+                return ModelResponse(
+                    parts=[tool_call_part],
+                    model_name=e.model_name,
+                    timestamp=_utils.now_utc(),
+                    provider_name=self._provider.name,
+                    finish_reason='error',
+                )
             except ValidationError:
                 pass
             raise
@@ -298,16 +313,16 @@ class GroqModel(Model):
                 tool_call_id = generate_tool_call_id()
                 items.append(
                     BuiltinToolCallPart(
-                        tool_name=tool.type, args=tool.arguments, provider_name=
+                        tool_name=tool.type, args=tool.arguments, provider_name=self.system, tool_call_id=tool_call_id
                     )
                 )
                 items.append(
                     BuiltinToolReturnPart(
-                        provider_name=
+                        provider_name=self.system, tool_name=tool.type, content=tool.output, tool_call_id=tool_call_id
                     )
                 )
-        # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
         if choice.message.reasoning is not None:
+            # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
             items.append(ThinkingPart(content=choice.message.reasoning))
         if choice.message.content is not None:
             # NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
@@ -315,6 +330,10 @@ class GroqModel(Model):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
+
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
@@ -322,6 +341,8 @@ class GroqModel(Model):
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     async def _process_streamed_response(
@@ -338,7 +359,7 @@ class GroqModel(Model):
         return GroqStreamedResponse(
             model_request_parameters=model_request_parameters,
             _response=peekable_response,
-            _model_name=
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _timestamp=number_to_datetime(first_chunk.created),
             _provider_name=self._provider.name,
@@ -376,8 +397,8 @@ class GroqModel(Model):
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
             elif isinstance(item, ThinkingPart):
-
-
+                start_tag, end_tag = self.profile.thinking_tags
+                texts.append('\n'.join([start_tag, item.content, end_tag]))
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from groq
                 pass
@@ -497,11 +518,18 @@ class GroqStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk)
 
+            if chunk.id:  # pragma: no branch
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue
 
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
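The recurring pattern in this release, visible above for Groq and repeated for the other providers below, is that the raw provider finish reason is kept in `provider_details` while a normalized value is exposed as `ModelResponse.finish_reason`. As a rough illustration only (not part of the diff; the model name is an example), the new fields could be inspected like this:

```python
# Hypothetical usage sketch for the finish_reason/provider_details fields added above.
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('groq:llama-3.3-70b-versatile')  # example model name, not taken from the diff

result = agent.run_sync('Say hello.')
for message in result.all_messages():
    if isinstance(message, ModelResponse):
        print(message.finish_reason)  # normalized value, e.g. 'stop' or 'tool_call'
        print((message.provider_details or {}).get('finish_reason'))  # raw provider value
```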
pydantic_ai/models/huggingface.py
CHANGED

@@ -20,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -58,6 +59,7 @@ try:
     ChatCompletionOutput,
     ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
+    TextGenerationOutputFinishReason,
 )
 from huggingface_hub.errors import HfHubHTTPError
 
@@ -94,6 +96,12 @@ HuggingFaceModelName = str | LatestHuggingFaceModelNames
 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
 """
 
+_FINISH_REASON_MAP: dict[TextGenerationOutputFinishReason, FinishReason] = {
+    'length': 'length',
+    'eos_token': 'stop',
+    'stop_sequence': 'stop',
+}
+
 
 class HuggingFaceModelSettings(ModelSettings, total=False):
     """Settings used for a Hugging Face model request."""
@@ -266,6 +274,11 @@ class HuggingFaceModel(Model):
         if tool_calls is not None:
             for c in tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(cast(TextGenerationOutputFinishReason, raw_finish_reason), None)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
@@ -273,6 +286,8 @@ class HuggingFaceModel(Model):
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     async def _process_streamed_response(
@@ -288,7 +303,7 @@ class HuggingFaceModel(Model):
 
         return HuggingFaceStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=
+            _model_name=first_chunk.model,
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
@@ -316,10 +331,8 @@ class HuggingFaceModel(Model):
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
             elif isinstance(item, ThinkingPart):
-
-
-                # texts.append(f'<think>\n{item.content}\n</think>')
-                pass
+                start_tag, end_tag = self.profile.thinking_tags
+                texts.append('\n'.join([start_tag, item.content, end_tag]))
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from huggingface
                 pass
@@ -445,11 +458,20 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk)
 
+            if chunk.id:  # pragma: no branch
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue
 
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(
+                    cast(TextGenerationOutputFinishReason, raw_finish_reason), None
+                )
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
pydantic_ai/models/instrumented.py
CHANGED

@@ -221,7 +221,10 @@ class InstrumentationSettings:
                     _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
                 )
             elif isinstance(message, ModelResponse):  # pragma: no branch
-
+                otel_message = _otel_messages.OutputMessage(role='assistant', parts=message.otel_message_parts(self))
+                if message.finish_reason is not None:
+                    otel_message['finish_reason'] = message.finish_reason
+                result.append(otel_message)
         return result
 
     def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
@@ -246,12 +249,10 @@ class InstrumentationSettings:
         else:
             output_messages = self.messages_to_otel_messages([response])
             assert len(output_messages) == 1
-            output_message =
-            if response.provider_details and 'finish_reason' in response.provider_details:
-                output_message['finish_reason'] = response.provider_details['finish_reason']
+            output_message = output_messages[0]
         instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
         system_instructions_attributes = self.system_instructions_attributes(instructions)
-        attributes = {
+        attributes: dict[str, AttributeValue] = {
            'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
            'gen_ai.output.messages': json.dumps([output_message]),
            **system_instructions_attributes,
@@ -420,17 +421,25 @@ class InstrumentedModel(WrapperModel):
                 return
 
             self.instrumentation_settings.handle_messages(messages, response, system, span)
+
+            attributes_to_set = {
+                **response.usage.opentelemetry_attributes(),
+                'gen_ai.response.model': response_model,
+            }
             try:
-
+                attributes_to_set['operation.cost'] = float(response.cost().total_price)
             except LookupError:
-
-
-
-
-                '
-
-
-
+                # The cost of this provider/model is unknown, which is common.
+                pass
+            except Exception as e:
+                warnings.warn(
+                    f'Failed to get cost from response: {type(e).__name__}: {e}', CostCalculationFailedWarning
+                )
+            if response.provider_response_id is not None:
+                attributes_to_set['gen_ai.response.id'] = response.provider_response_id
+            if response.finish_reason is not None:
+                attributes_to_set['gen_ai.response.finish_reasons'] = [response.finish_reason]
+            span.set_attributes(attributes_to_set)
             span.update_name(f'{operation} {request_model}')
 
             yield finish
@@ -478,3 +487,7 @@ class InstrumentedModel(WrapperModel):
             return str(value)
         except Exception as e:
             return f'Unable to serialize: {e}'
+
+
+class CostCalculationFailedWarning(Warning):
+    """Warning raised when cost calculation fails."""
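The diff above also introduces `CostCalculationFailedWarning`, emitted when cost lookup raises an unexpected error. A minimal, illustrative way to silence it, assuming the class remains importable from `pydantic_ai.models.instrumented` as defined in the diff:

```python
# Illustrative sketch: suppress the new warning if cost calculation failures are expected.
import warnings

from pydantic_ai.models.instrumented import CostCalculationFailedWarning

warnings.filterwarnings('ignore', category=CostCalculationFailedWarning)
```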
pydantic_ai/models/mistral.py
CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._run_context import RunContext
-from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..exceptions import UserError
 from ..messages import (
@@ -21,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -61,12 +61,15 @@ try:
     ImageURLChunk as MistralImageURLChunk,
     Mistral,
     OptionalNullable as MistralOptionalNullable,
+    ReferenceChunk as MistralReferenceChunk,
     TextChunk as MistralTextChunk,
+    ThinkChunk as MistralThinkChunk,
     ToolChoiceEnum as MistralToolChoiceEnum,
 )
 from mistralai.models import (
     ChatCompletionResponse as MistralChatCompletionResponse,
     CompletionEvent as MistralCompletionEvent,
+    FinishReason as MistralFinishReason,
     Messages as MistralMessages,
     SDKError,
     Tool as MistralTool,
@@ -98,6 +101,14 @@ allow any name in the type hints.
 Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_overview/) for a full list.
 """
 
+_FINISH_REASON_MAP: dict[MistralFinishReason, FinishReason] = {
+    'stop': 'stop',
+    'length': 'length',
+    'model_length': 'length',
+    'error': 'error',
+    'tool_calls': 'tool_call',
+}
+
 
 class MistralModelSettings(ModelSettings, total=False):
     """Settings used for a Mistral model request."""
@@ -339,14 +350,21 @@ class MistralModel(Model):
         tool_calls = choice.message.tool_calls
 
         parts: list[ModelResponsePart] = []
-
-
+        text, thinking = _map_content(content)
+        for thought in thinking:
+            parts.append(ThinkingPart(content=thought))
+        if text:
+            parts.append(TextPart(content=text))
 
         if isinstance(tool_calls, list):
             for tool_call in tool_calls:
                 tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                 parts.append(tool)
 
+        raw_finish_reason = choice.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=parts,
             usage=_map_usage(response),
@@ -354,6 +372,8 @@ class MistralModel(Model):
             timestamp=timestamp,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     async def _process_streamed_response(
@@ -377,7 +397,7 @@ class MistralModel(Model):
         return MistralStreamedResponse(
             model_request_parameters=model_request_parameters,
             _response=peekable_response,
-            _model_name=
+            _model_name=first_chunk.data.model,
             _timestamp=timestamp,
             _provider_name=self._provider.name,
         )
@@ -503,16 +523,14 @@ class MistralModel(Model):
             mistral_messages.extend(self._map_user_message(message))
         elif isinstance(message, ModelResponse):
             content_chunks: list[MistralContentChunk] = []
+            thinking_chunks: list[MistralTextChunk | MistralReferenceChunk] = []
             tool_calls: list[MistralToolCall] = []
 
             for part in message.parts:
                 if isinstance(part, TextPart):
                     content_chunks.append(MistralTextChunk(text=part.content))
                 elif isinstance(part, ThinkingPart):
-
-                    # please open an issue. The below code is the code to send thinking to the provider.
-                    # content_chunks.append(MistralTextChunk(text=f'<think>{part.content}</think>'))
-                    pass
+                    thinking_chunks.append(MistralTextChunk(text=part.content))
                 elif isinstance(part, ToolCallPart):
                     tool_calls.append(self._map_tool_call(part))
                 elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
@@ -520,6 +538,8 @@ class MistralModel(Model):
                     pass
                 else:
                     assert_never(part)
+            if thinking_chunks:
+                content_chunks.insert(0, MistralThinkChunk(thinking=thinking_chunks))
             mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
         else:
             assert_never(message)
@@ -595,14 +615,23 @@ class MistralStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk.data)
 
+            if chunk.data.id:  # pragma: no branch
+                self.provider_response_id = chunk.data.id
+
             try:
                 choice = chunk.data.choices[0]
             except IndexError:
                 continue
 
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
-            text = _map_content(content)
+            text, thinking = _map_content(content)
+            for thought in thinking:
+                self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought)
             if text:
                 # Attempt to produce an output tool call from the received text
                 output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
@@ -715,32 +744,37 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
         return RequestUsage(
-            input_tokens=response.usage.prompt_tokens,
-            output_tokens=response.usage.completion_tokens,
+            input_tokens=response.usage.prompt_tokens or 0,
+            output_tokens=response.usage.completion_tokens or 0,
         )
     else:
-        return RequestUsage()
+        return RequestUsage()
 
 
-def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
+def _map_content(content: MistralOptionalNullable[MistralContent]) -> tuple[str | None, list[str]]:
     """Maps the delta content from a Mistral Completion Chunk to a string or None."""
-
+    text: str | None = None
+    thinking: list[str] = []
 
     if isinstance(content, MistralUnset) or not content:
-
+        return None, []
     elif isinstance(content, list):
         for chunk in content:
             if isinstance(chunk, MistralTextChunk):
-
+                text = text or '' + chunk.text
+            elif isinstance(chunk, MistralThinkChunk):
+                for thought in chunk.thinking:
+                    if thought.type == 'text':  # pragma: no branch
+                        thinking.append(thought.text)
            else:
                 assert False, (  # pragma: no cover
                     f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
                 )
     elif isinstance(content, str):
-
+        text = content
 
     # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
-    if
-
+    if text and len(text) == 0:  # pragma: no cover
+        text = None
 
-    return
+    return text, thinking