pydantic-ai-slim 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- pydantic_ai/_otel_messages.py +67 -0
- pydantic_ai/agent/__init__.py +11 -4
- pydantic_ai/builtin_tools.py +1 -0
- pydantic_ai/durable_exec/temporal/_model.py +4 -0
- pydantic_ai/messages.py +109 -18
- pydantic_ai/models/__init__.py +27 -9
- pydantic_ai/models/anthropic.py +20 -8
- pydantic_ai/models/bedrock.py +16 -10
- pydantic_ai/models/cohere.py +3 -1
- pydantic_ai/models/function.py +5 -0
- pydantic_ai/models/gemini.py +8 -1
- pydantic_ai/models/google.py +21 -4
- pydantic_ai/models/groq.py +8 -0
- pydantic_ai/models/huggingface.py +8 -0
- pydantic_ai/models/instrumented.py +103 -42
- pydantic_ai/models/mistral.py +8 -0
- pydantic_ai/models/openai.py +80 -36
- pydantic_ai/models/test.py +7 -0
- pydantic_ai/profiles/__init__.py +1 -1
- pydantic_ai/profiles/harmony.py +13 -0
- pydantic_ai/profiles/openai.py +6 -1
- pydantic_ai/profiles/qwen.py +8 -0
- pydantic_ai/providers/__init__.py +5 -1
- pydantic_ai/providers/anthropic.py +11 -8
- pydantic_ai/providers/azure.py +1 -1
- pydantic_ai/providers/cerebras.py +96 -0
- pydantic_ai/providers/cohere.py +2 -2
- pydantic_ai/providers/deepseek.py +4 -4
- pydantic_ai/providers/fireworks.py +3 -3
- pydantic_ai/providers/github.py +4 -4
- pydantic_ai/providers/grok.py +3 -3
- pydantic_ai/providers/groq.py +3 -3
- pydantic_ai/providers/heroku.py +3 -3
- pydantic_ai/providers/mistral.py +3 -3
- pydantic_ai/providers/moonshotai.py +3 -6
- pydantic_ai/providers/ollama.py +1 -1
- pydantic_ai/providers/openrouter.py +4 -4
- pydantic_ai/providers/together.py +3 -3
- pydantic_ai/providers/vercel.py +4 -4
- pydantic_ai/retries.py +154 -42
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/RECORD +45 -42
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py
CHANGED
@@ -395,6 +395,7 @@ class GoogleModel(Model):
         return _process_response_from_parts(
             parts,
             response.model_version or self._model_name,
+            self._provider.name,
             usage,
             vendor_id=vendor_id,
             vendor_details=vendor_details,
@@ -414,6 +415,7 @@ class GoogleModel(Model):
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
+            _provider_name=self._provider.name,
         )

     async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
@@ -523,6 +525,7 @@ class GeminiStreamedResponse(StreamedResponse):
     _model_name: GoogleModelName
     _response: AsyncIterator[GenerateContentResponse]
     _timestamp: datetime
+    _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
@@ -531,7 +534,10 @@ class GeminiStreamedResponse(StreamedResponse):
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
             if candidate.content is None or candidate.content.parts is None:
-                if candidate.finish_reason == 'SAFETY':
+                if candidate.finish_reason == 'STOP':  # pragma: no cover
+                    # Normal completion - skip this chunk
+                    continue
+                elif candidate.finish_reason == 'SAFETY':  # pragma: no cover
                     raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
                 else:  # pragma: no cover
                     raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
@@ -561,6 +567,11 @@ class GeminiStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -596,6 +607,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
+    provider_name: str,
     usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
@@ -633,7 +645,12 @@ def _process_response_from_parts(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(
-        parts=items,
+        parts=items,
+        model_name=model_name,
+        usage=usage,
+        provider_request_id=vendor_id,
+        provider_details=vendor_details,
+        provider_name=provider_name,
     )


@@ -661,7 +678,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     if cached_content_token_count := metadata.cached_content_token_count:
         details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.thoughts_token_count:
+    if thoughts_token_count := (metadata.thoughts_token_count or 0):
         details['thoughts_tokens'] = thoughts_token_count

     if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
@@ -694,7 +711,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:

     return usage.RequestUsage(
         input_tokens=metadata.prompt_token_count or 0,
-        output_tokens=metadata.candidates_token_count or 0,
+        output_tokens=(metadata.candidates_token_count or 0) + thoughts_token_count,
         cache_read_tokens=cached_content_token_count or 0,
         input_audio_tokens=input_audio_tokens,
         output_audio_tokens=output_audio_tokens,
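Note: the last two hunks above change how Gemini usage is reported: `thoughts_token_count` now defaults to 0 when the metadata field is absent, and those reasoning tokens are counted in `output_tokens` as well as being reported under `details`. A minimal sketch of the new accounting, using hypothetical token counts and only the metadata fields shown in the hunks:

# Hypothetical values standing in for google-genai usage metadata fields.
candidates_token_count = 150  # may be None on real responses
thoughts_token_count = 40     # may be None on real responses

thoughts = thoughts_token_count or 0
details = {'thoughts_tokens': thoughts} if thoughts else {}
output_tokens = (candidates_token_count or 0) + thoughts  # 190: completion plus reasoning tokens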
pydantic_ai/models/groq.py
CHANGED
@@ -290,6 +290,7 @@ class GroqModel(Model):
             model_name=response.model,
             timestamp=timestamp,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )

     async def _process_streamed_response(
@@ -309,6 +310,7 @@ class GroqModel(Model):
             _model_name=self._model_name,
             _model_profile=self.profile,
             _timestamp=number_to_datetime(first_chunk.created),
+            _provider_name=self._provider.name,
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -444,6 +446,7 @@ class GroqStreamedResponse(StreamedResponse):
     _model_profile: ModelProfile
     _response: AsyncIterable[chat.ChatCompletionChunk]
     _timestamp: datetime
+    _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
@@ -482,6 +485,11 @@ class GroqStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
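Note: the Groq hunks above follow the same pattern applied to every provider model in this release: the provider name is captured when the response object is constructed and exposed through a read-only property on the streamed response. A minimal illustrative sketch of that pattern (the class and fields below are placeholders, not the library's exact definitions):

from dataclasses import dataclass
from datetime import datetime


@dataclass
class ExampleStreamedResponse:
    """Illustrative stand-in for the StreamedResponse subclasses patched in this release."""

    _model_name: str
    _provider_name: str
    _timestamp: datetime

    @property
    def provider_name(self) -> str:
        """Get the provider name."""
        return self._provider_name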
pydantic_ai/models/huggingface.py
CHANGED

@@ -272,6 +272,7 @@ class HuggingFaceModel(Model):
             model_name=response.model,
             timestamp=timestamp,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )

     async def _process_streamed_response(
@@ -291,6 +292,7 @@ class HuggingFaceModel(Model):
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+            _provider_name=self._provider.name,
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
@@ -437,6 +439,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
     _model_profile: ModelProfile
     _response: AsyncIterable[ChatCompletionStreamOutput]
     _timestamp: datetime
+    _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
@@ -474,6 +477,11 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai/models/instrumented.py
CHANGED

@@ -1,10 +1,11 @@
 from __future__ import annotations

+import itertools
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, cast
 from urllib.parse import urlparse

 from opentelemetry._events import (
@@ -18,8 +19,14 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter

+from .. import _otel_messages
 from .._run_context import RunContext
-from ..messages import
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    SystemPromptPart,
+)
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel
@@ -80,6 +87,8 @@ class InstrumentationSettings:
     event_logger: EventLogger = field(repr=False)
     event_mode: Literal['attributes', 'logs'] = 'attributes'
     include_binary_content: bool = True
+    include_content: bool = True
+    version: Literal[1, 2] = 1

     def __init__(
         self,
@@ -90,6 +99,7 @@ class InstrumentationSettings:
         event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
+        version: Literal[1, 2] = 1,
     ):
         """Create instrumentation options.

@@ -109,6 +119,10 @@ class InstrumentationSettings:
             include_binary_content: Whether to include binary content in the instrumentation events.
             include_content: Whether to include prompts, completions, and tool call arguments and responses
                 in the instrumentation events.
+            version: Version of the data format.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
+                Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
+                Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
         """
         from pydantic_ai import __version__

@@ -122,6 +136,7 @@ class InstrumentationSettings:
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
+        self.version = version

         # As specified in the OpenTelemetry GenAI metrics spec:
         # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
@@ -179,6 +194,90 @@ class InstrumentationSettings:
             event.body = InstrumentedModel.serialize_any(event.body)
         return events

+    def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_messages.ChatMessage]:
+        result: list[_otel_messages.ChatMessage] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)):
+                    message_parts: list[_otel_messages.MessagePart] = []
+                    for part in group:
+                        if hasattr(part, 'otel_message_parts'):
+                            message_parts.extend(part.otel_message_parts(self))
+                    result.append(
+                        _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
+                    )
+            elif isinstance(message, ModelResponse):  # pragma: no branch
+                result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self)))
+        return result
+
+    def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
+        if self.version == 1:
+            events = self.messages_to_otel_events(input_messages)
+            for event in self.messages_to_otel_events([response]):
+                events.append(
+                    Event(
+                        'gen_ai.choice',
+                        body={
+                            'index': 0,
+                            'message': event.body,
+                        },
+                    )
+                )
+            for event in events:
+                event.attributes = {
+                    GEN_AI_SYSTEM_ATTRIBUTE: system,
+                    **(event.attributes or {}),
+                }
+            self._emit_events(span, events)
+        else:
+            output_messages = self.messages_to_otel_messages([response])
+            assert len(output_messages) == 1
+            output_message = cast(_otel_messages.OutputMessage, output_messages[0])
+            if response.provider_details and 'finish_reason' in response.provider_details:
+                output_message['finish_reason'] = response.provider_details['finish_reason']
+            instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+            attributes = {
+                'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
+                'gen_ai.output.messages': json.dumps([output_message]),
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'gen_ai.input.messages': {'type': 'array'},
+                            'gen_ai.output.messages': {'type': 'array'},
+                            **({'gen_ai.system_instructions': {'type': 'array'}} if instructions else {}),
+                            'model_request_parameters': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
+            if instructions is not None:
+                attributes['gen_ai.system_instructions'] = json.dumps(
+                    [_otel_messages.TextPart(type='text', content=instructions)]
+                )
+            span.set_attributes(attributes)
+
+    def _emit_events(self, span: Span, events: list[Event]) -> None:
+        if self.event_mode == 'logs':
+            for event in events:
+                self.event_logger.emit(event)
+        else:
+            attr_name = 'events'
+            span.set_attributes(
+                {
+                    attr_name: json.dumps([InstrumentedModel.event_to_dict(event) for event in events]),
+                    'logfire.json_schema': json.dumps(
+                        {
+                            'type': 'object',
+                            'properties': {
+                                attr_name: {'type': 'array'},
+                                'model_request_parameters': {'type': 'object'},
+                            },
+                        }
+                    ),
+                }
+            )
+

 GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
 GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
@@ -269,7 +368,7 @@ class InstrumentedModel(WrapperModel):
                 # FallbackModel updates these span attributes.
                 attributes.update(getattr(span, 'attributes', {}))
                 request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
-                system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
+                system = cast(str, attributes[GEN_AI_SYSTEM_ATTRIBUTE])

                 response_model = response.model_name or request_model

@@ -297,18 +396,7 @@ class InstrumentedModel(WrapperModel):
                 if not span.is_recording():
                     return

-
-                for event in self.instrumentation_settings.messages_to_otel_events([response]):
-                    events.append(
-                        Event(
-                            'gen_ai.choice',
-                            body={
-                                # TODO finish_reason
-                                'index': 0,
-                                'message': event.body,
-                            },
-                        )
-                    )
+                self.instrumentation_settings.handle_messages(messages, response, system, span)
                 span.set_attributes(
                     {
                         **response.usage.opentelemetry_attributes(),
@@ -316,12 +404,6 @@ class InstrumentedModel(WrapperModel):
                     }
                 )
                 span.update_name(f'{operation} {request_model}')
-                for event in events:
-                    event.attributes = {
-                        GEN_AI_SYSTEM_ATTRIBUTE: system,
-                        **(event.attributes or {}),
-                    }
-                self._emit_events(span, events)

             yield finish
         finally:
@@ -330,27 +412,6 @@ class InstrumentedModel(WrapperModel):
             # to prevent them from being redundantly recorded in the span itself by logfire.
             record_metrics()

-    def _emit_events(self, span: Span, events: list[Event]) -> None:
-        if self.instrumentation_settings.event_mode == 'logs':
-            for event in events:
-                self.instrumentation_settings.event_logger.emit(event)
-        else:
-            attr_name = 'events'
-            span.set_attributes(
-                {
-                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
-                    'logfire.json_schema': json.dumps(
-                        {
-                            'type': 'object',
-                            'properties': {
-                                attr_name: {'type': 'array'},
-                                'model_request_parameters': {'type': 'object'},
-                            },
-                        }
-                    ),
-                }
-            )
-
     @staticmethod
     def model_attributes(model: Model):
         attributes: dict[str, AttributeValue] = {
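Note: the new `version` parameter on `InstrumentationSettings` selects between the legacy event-based GenAI format and the experimental attribute-based format described in the docstring above. A minimal sketch of opting in, assuming the settings object is then passed to whatever instrumentation entry point the application already uses (not shown in this diff):

from pydantic_ai.models.instrumented import InstrumentationSettings

# version=2 is experimental: messages are recorded in the span attributes
# gen_ai.input.messages and gen_ai.output.messages instead of as events.
settings = InstrumentationSettings(
    include_content=True,          # keep prompts/completions in telemetry
    include_binary_content=False,  # drop raw binary payloads
    version=2,
)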
pydantic_ai/models/mistral.py
CHANGED
@@ -353,6 +353,7 @@ class MistralModel(Model):
             model_name=response.model,
             timestamp=timestamp,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )

     async def _process_streamed_response(
@@ -378,6 +379,7 @@ class MistralModel(Model):
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
+            _provider_name=self._provider.name,
         )

     @staticmethod
@@ -584,6 +586,7 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
+    _provider_name: str

     _delta_content: str = field(default='', init=False)

@@ -631,6 +634,11 @@ class MistralStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""