pydantic-ai-slim 0.7.3__py3-none-any.whl → 0.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +2 -1
- pydantic_ai/_otel_messages.py +67 -0
- pydantic_ai/builtin_tools.py +10 -1
- pydantic_ai/durable_exec/temporal/_model.py +4 -0
- pydantic_ai/messages.py +109 -18
- pydantic_ai/models/__init__.py +7 -0
- pydantic_ai/models/anthropic.py +18 -6
- pydantic_ai/models/bedrock.py +15 -9
- pydantic_ai/models/cohere.py +3 -1
- pydantic_ai/models/function.py +5 -0
- pydantic_ai/models/gemini.py +8 -1
- pydantic_ai/models/google.py +37 -13
- pydantic_ai/models/groq.py +8 -0
- pydantic_ai/models/huggingface.py +8 -0
- pydantic_ai/models/instrumented.py +103 -42
- pydantic_ai/models/mistral.py +8 -0
- pydantic_ai/models/openai.py +18 -0
- pydantic_ai/models/test.py +7 -0
- pydantic_ai/providers/anthropic.py +11 -8
- pydantic_ai/tools.py +5 -2
- pydantic_ai/usage.py +1 -1
- {pydantic_ai_slim-0.7.3.dist-info → pydantic_ai_slim-0.7.5.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.7.3.dist-info → pydantic_ai_slim-0.7.5.dist-info}/RECORD +26 -25
- {pydantic_ai_slim-0.7.3.dist-info → pydantic_ai_slim-0.7.5.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.3.dist-info → pydantic_ai_slim-0.7.5.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.3.dist-info → pydantic_ai_slim-0.7.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py
CHANGED
@@ -13,7 +13,7 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
 from .._run_context import RunContext
-from ..builtin_tools import CodeExecutionTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, UrlContextTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -72,6 +72,7 @@ try:
         ToolConfigDict,
         ToolDict,
         ToolListUnionDict,
+        UrlContextDict,
     )

     from ..providers.google import GoogleProvider
@@ -218,7 +219,7 @@ class GoogleModel(Model):
            )
            if self._provider.name != 'google-gla':
                # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
-                config.update(
+                config.update(  # pragma: lax no cover
                    system_instruction=generation_config.get('system_instruction'),
                    tools=cast(list[ToolDict], generation_config.get('tools')),
                    # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
@@ -270,6 +271,8 @@ class GoogleModel(Model):
        for tool in model_request_parameters.builtin_tools:
            if isinstance(tool, WebSearchTool):
                tools.append(ToolDict(google_search=GoogleSearchDict()))
+            elif isinstance(tool, UrlContextTool):
+                tools.append(ToolDict(url_context=UrlContextDict()))
            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
            else:  # pragma: no cover
@@ -374,23 +377,25 @@ class GoogleModel(Model):
    def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
        if not response.candidates or len(response.candidates) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
-
-
+        candidate = response.candidates[0]
+        if candidate.content is None or candidate.content.parts is None:
+            if candidate.finish_reason == 'SAFETY':
                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
            else:
                raise UnexpectedModelBehavior(
                    'Content field missing from Gemini response', str(response)
                )  # pragma: no cover
-        parts =
+        parts = candidate.content.parts or []
        vendor_id = response.response_id or None
        vendor_details: dict[str, Any] | None = None
-        finish_reason =
+        finish_reason = candidate.finish_reason
        if finish_reason:  # pragma: no branch
            vendor_details = {'finish_reason': finish_reason.value}
        usage = _metadata_as_usage(response)
        return _process_response_from_parts(
            parts,
            response.model_version or self._model_name,
+            self._provider.name,
            usage,
            vendor_id=vendor_id,
            vendor_details=vendor_details,
@@ -410,6 +415,7 @@ class GoogleModel(Model):
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=first_chunk.create_time or _utils.now_utc(),
+            _provider_name=self._provider.name,
        )

    async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
@@ -519,6 +525,7 @@ class GeminiStreamedResponse(StreamedResponse):
    _model_name: GoogleModelName
    _response: AsyncIterator[GenerateContentResponse]
    _timestamp: datetime
+    _provider_name: str

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
@@ -526,10 +533,16 @@ class GeminiStreamedResponse(StreamedResponse):

            assert chunk.candidates is not None
            candidate = chunk.candidates[0]
-            if candidate.content is None:
-
-
-
+            if candidate.content is None or candidate.content.parts is None:
+                if candidate.finish_reason == 'STOP':  # pragma: no cover
+                    # Normal completion - skip this chunk
+                    continue
+                elif candidate.finish_reason == 'SAFETY':  # pragma: no cover
+                    raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
+                else:  # pragma: no cover
+                    raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+            parts = candidate.content.parts or []
+            for part in parts:
                if part.text is not None:
                    if part.thought:
                        yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
@@ -554,6 +567,11 @@ class GeminiStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
@@ -589,6 +607,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
+    provider_name: str,
     usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
@@ -626,7 +645,12 @@ def _process_response_from_parts(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(
-        parts=items,
+        parts=items,
+        model_name=model_name,
+        usage=usage,
+        provider_request_id=vendor_id,
+        provider_details=vendor_details,
+        provider_name=provider_name,
    )


@@ -654,7 +678,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
    if cached_content_token_count := metadata.cached_content_token_count:
        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.thoughts_token_count:
+    if thoughts_token_count := (metadata.thoughts_token_count or 0):
        details['thoughts_tokens'] = thoughts_token_count

    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
@@ -687,7 +711,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:

    return usage.RequestUsage(
        input_tokens=metadata.prompt_token_count or 0,
-        output_tokens=metadata.candidates_token_count or 0,
+        output_tokens=(metadata.candidates_token_count or 0) + thoughts_token_count,
        cache_read_tokens=cached_content_token_count or 0,
        input_audio_tokens=input_audio_tokens,
        output_audio_tokens=output_audio_tokens,
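The google.py changes above add `UrlContextTool` as a third supported builtin tool and make Gemini "thoughts" tokens count toward `output_tokens`. Below is a minimal usage sketch, not part of the diff: the model name and prompt are illustrative assumptions, and it assumes the existing `builtin_tools` parameter of `Agent` from earlier releases.

# Sketch only; model name and prompt are illustrative, not taken from the package.
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import UrlContextTool, WebSearchTool
from pydantic_ai.models.google import GoogleModel

model = GoogleModel('gemini-2.0-flash')  # hypothetical model choice
agent = Agent(model, builtin_tools=[UrlContextTool(), WebSearchTool()])
# result = agent.run_sync('Summarize the page at https://example.com')
# With the _metadata_as_usage change above, the reported output token count
# now also includes any thoughts_tokens returned by Gemini.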
pydantic_ai/models/groq.py
CHANGED
@@ -290,6 +290,7 @@ class GroqModel(Model):
            model_name=response.model,
            timestamp=timestamp,
            provider_request_id=response.id,
+            provider_name=self._provider.name,
        )

    async def _process_streamed_response(
@@ -309,6 +310,7 @@ class GroqModel(Model):
            _model_name=self._model_name,
            _model_profile=self.profile,
            _timestamp=number_to_datetime(first_chunk.created),
+            _provider_name=self._provider.name,
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -444,6 +446,7 @@ class GroqStreamedResponse(StreamedResponse):
    _model_profile: ModelProfile
    _response: AsyncIterable[chat.ChatCompletionChunk]
    _timestamp: datetime
+    _provider_name: str

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
@@ -482,6 +485,11 @@ class GroqStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
pydantic_ai/models/huggingface.py
CHANGED

@@ -272,6 +272,7 @@ class HuggingFaceModel(Model):
            model_name=response.model,
            timestamp=timestamp,
            provider_request_id=response.id,
+            provider_name=self._provider.name,
        )

    async def _process_streamed_response(
@@ -291,6 +292,7 @@ class HuggingFaceModel(Model):
            _model_profile=self.profile,
            _response=peekable_response,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+            _provider_name=self._provider.name,
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
@@ -437,6 +439,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
    _model_profile: ModelProfile
    _response: AsyncIterable[ChatCompletionStreamOutput]
    _timestamp: datetime
+    _provider_name: str

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
@@ -474,6 +477,11 @@ class HuggingFaceStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
pydantic_ai/models/instrumented.py
CHANGED

@@ -1,10 +1,11 @@
 from __future__ import annotations

+import itertools
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, cast
 from urllib.parse import urlparse

 from opentelemetry._events import (
@@ -18,8 +19,14 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provide
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter

+from .. import _otel_messages
 from .._run_context import RunContext
-from ..messages import
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    SystemPromptPart,
+)
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel
@@ -80,6 +87,8 @@ class InstrumentationSettings:
    event_logger: EventLogger = field(repr=False)
    event_mode: Literal['attributes', 'logs'] = 'attributes'
    include_binary_content: bool = True
+    include_content: bool = True
+    version: Literal[1, 2] = 1

    def __init__(
        self,
@@ -90,6 +99,7 @@ class InstrumentationSettings:
        event_logger_provider: EventLoggerProvider | None = None,
        include_binary_content: bool = True,
        include_content: bool = True,
+        version: Literal[1, 2] = 1,
    ):
        """Create instrumentation options.

@@ -109,6 +119,10 @@ class InstrumentationSettings:
            include_binary_content: Whether to include binary content in the instrumentation events.
            include_content: Whether to include prompts, completions, and tool call arguments and responses
                in the instrumentation events.
+            version: Version of the data format.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
+                Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
+                Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
        """
        from pydantic_ai import __version__

@@ -122,6 +136,7 @@ class InstrumentationSettings:
        self.event_mode = event_mode
        self.include_binary_content = include_binary_content
        self.include_content = include_content
+        self.version = version

        # As specified in the OpenTelemetry GenAI metrics spec:
        # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
@@ -179,6 +194,90 @@ class InstrumentationSettings:
            event.body = InstrumentedModel.serialize_any(event.body)
        return events

+    def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_messages.ChatMessage]:
+        result: list[_otel_messages.ChatMessage] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)):
+                    message_parts: list[_otel_messages.MessagePart] = []
+                    for part in group:
+                        if hasattr(part, 'otel_message_parts'):
+                            message_parts.extend(part.otel_message_parts(self))
+                    result.append(
+                        _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
+                    )
+            elif isinstance(message, ModelResponse):  # pragma: no branch
+                result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self)))
+        return result
+
+    def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
+        if self.version == 1:
+            events = self.messages_to_otel_events(input_messages)
+            for event in self.messages_to_otel_events([response]):
+                events.append(
+                    Event(
+                        'gen_ai.choice',
+                        body={
+                            'index': 0,
+                            'message': event.body,
+                        },
+                    )
+                )
+            for event in events:
+                event.attributes = {
+                    GEN_AI_SYSTEM_ATTRIBUTE: system,
+                    **(event.attributes or {}),
+                }
+            self._emit_events(span, events)
+        else:
+            output_messages = self.messages_to_otel_messages([response])
+            assert len(output_messages) == 1
+            output_message = cast(_otel_messages.OutputMessage, output_messages[0])
+            if response.provider_details and 'finish_reason' in response.provider_details:
+                output_message['finish_reason'] = response.provider_details['finish_reason']
+            instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+            attributes = {
+                'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
+                'gen_ai.output.messages': json.dumps([output_message]),
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'gen_ai.input.messages': {'type': 'array'},
+                            'gen_ai.output.messages': {'type': 'array'},
+                            **({'gen_ai.system_instructions': {'type': 'array'}} if instructions else {}),
+                            'model_request_parameters': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
+            if instructions is not None:
+                attributes['gen_ai.system_instructions'] = json.dumps(
+                    [_otel_messages.TextPart(type='text', content=instructions)]
+                )
+            span.set_attributes(attributes)
+
+    def _emit_events(self, span: Span, events: list[Event]) -> None:
+        if self.event_mode == 'logs':
+            for event in events:
+                self.event_logger.emit(event)
+        else:
+            attr_name = 'events'
+            span.set_attributes(
+                {
+                    attr_name: json.dumps([InstrumentedModel.event_to_dict(event) for event in events]),
+                    'logfire.json_schema': json.dumps(
+                        {
+                            'type': 'object',
+                            'properties': {
+                                attr_name: {'type': 'array'},
+                                'model_request_parameters': {'type': 'object'},
+                            },
+                        }
+                    ),
+                }
+            )
+

 GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
 GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
@@ -269,7 +368,7 @@ class InstrumentedModel(WrapperModel):
            # FallbackModel updates these span attributes.
            attributes.update(getattr(span, 'attributes', {}))
            request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
-            system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
+            system = cast(str, attributes[GEN_AI_SYSTEM_ATTRIBUTE])

            response_model = response.model_name or request_model

@@ -297,18 +396,7 @@ class InstrumentedModel(WrapperModel):
            if not span.is_recording():
                return

-
-            for event in self.instrumentation_settings.messages_to_otel_events([response]):
-                events.append(
-                    Event(
-                        'gen_ai.choice',
-                        body={
-                            # TODO finish_reason
-                            'index': 0,
-                            'message': event.body,
-                        },
-                    )
-                )
+            self.instrumentation_settings.handle_messages(messages, response, system, span)
            span.set_attributes(
                {
                    **response.usage.opentelemetry_attributes(),
@@ -316,12 +404,6 @@ class InstrumentedModel(WrapperModel):
                }
            )
            span.update_name(f'{operation} {request_model}')
-            for event in events:
-                event.attributes = {
-                    GEN_AI_SYSTEM_ATTRIBUTE: system,
-                    **(event.attributes or {}),
-                }
-            self._emit_events(span, events)

            yield finish
        finally:
@@ -330,27 +412,6 @@ class InstrumentedModel(WrapperModel):
            # to prevent them from being redundantly recorded in the span itself by logfire.
            record_metrics()

-    def _emit_events(self, span: Span, events: list[Event]) -> None:
-        if self.instrumentation_settings.event_mode == 'logs':
-            for event in events:
-                self.instrumentation_settings.event_logger.emit(event)
-        else:
-            attr_name = 'events'
-            span.set_attributes(
-                {
-                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
-                    'logfire.json_schema': json.dumps(
-                        {
-                            'type': 'object',
-                            'properties': {
-                                attr_name: {'type': 'array'},
-                                'model_request_parameters': {'type': 'object'},
-                            },
-                        }
-                    ),
-                }
-            )
-
    @staticmethod
    def model_attributes(model: Model):
        attributes: dict[str, AttributeValue] = {
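The instrumented.py changes introduce an opt-in version=2 telemetry format: instead of emitting per-message events, the whole exchange is written to the `gen_ai.input.messages` and `gen_ai.output.messages` span attributes. A minimal sketch of opting in (not from the diff; the model string is an illustrative assumption):

from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

# version defaults to 1 (the legacy event-based format); 2 is experimental per the docstring above.
settings = InstrumentationSettings(version=2)
agent = Agent('openai:gpt-4o', instrument=settings)  # illustrative model string
# Spans for this agent's model requests will carry JSON-encoded
# 'gen_ai.input.messages' and 'gen_ai.output.messages' attributes.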
pydantic_ai/models/mistral.py
CHANGED
@@ -353,6 +353,7 @@ class MistralModel(Model):
            model_name=response.model,
            timestamp=timestamp,
            provider_request_id=response.id,
+            provider_name=self._provider.name,
        )

    async def _process_streamed_response(
@@ -378,6 +379,7 @@ class MistralModel(Model):
            _response=peekable_response,
            _model_name=self._model_name,
            _timestamp=timestamp,
+            _provider_name=self._provider.name,
        )

    @staticmethod
@@ -584,6 +586,7 @@ class MistralStreamedResponse(StreamedResponse):
    _model_name: MistralModelName
    _response: AsyncIterable[MistralCompletionEvent]
    _timestamp: datetime
+    _provider_name: str

    _delta_content: str = field(default='', init=False)

@@ -631,6 +634,11 @@ class MistralStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
pydantic_ai/models/openai.py
CHANGED
@@ -500,6 +500,7 @@ class OpenAIModel(Model):
            timestamp=timestamp,
            provider_details=vendor_details,
            provider_request_id=response.id,
+            provider_name=self._provider.name,
        )

    async def _process_streamed_response(
@@ -519,6 +520,7 @@ class OpenAIModel(Model):
            _model_profile=self.profile,
            _response=peekable_response,
            _timestamp=number_to_datetime(first_chunk.created),
+            _provider_name=self._provider.name,
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -571,6 +573,8 @@ class OpenAIModel(Model):
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param['content'] = '\n\n'.join(texts)
+            else:
+                message_param['content'] = None
            if tool_calls:
                message_param['tool_calls'] = tool_calls
            openai_messages.append(message_param)
@@ -803,6 +807,7 @@ class OpenAIResponsesModel(Model):
            model_name=response.model,
            provider_request_id=response.id,
            timestamp=timestamp,
+            provider_name=self._provider.name,
        )

    async def _process_streamed_response(
@@ -822,6 +827,7 @@ class OpenAIResponsesModel(Model):
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=number_to_datetime(first_chunk.response.created_at),
+            _provider_name=self._provider.name,
        )

    @overload
@@ -1137,6 +1143,7 @@ class OpenAIStreamedResponse(StreamedResponse):
    _model_profile: ModelProfile
    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime
+    _provider_name: str

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
@@ -1180,6 +1187,11 @@ class OpenAIStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
@@ -1193,6 +1205,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
    _model_name: OpenAIModelName
    _response: AsyncIterable[responses.ResponseStreamEvent]
    _timestamp: datetime
+    _provider_name: str

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
        async for chunk in self._response:
@@ -1313,6 +1326,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
pydantic_ai/models/test.py
CHANGED
@@ -131,6 +131,7 @@ class TestModel(Model):
            _model_name=self._model_name,
            _structured_response=model_response,
            _messages=messages,
+            _provider_name=self._system,
        )

    @property
@@ -263,6 +264,7 @@ class TestStreamedResponse(StreamedResponse):
    _model_name: str
    _structured_response: ModelResponse
    _messages: InitVar[Iterable[ModelMessage]]
+    _provider_name: str
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -305,6 +307,11 @@ class TestStreamedResponse(StreamedResponse):
        """Get the model name of the response."""
        return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
pydantic_ai/providers/anthropic.py
CHANGED

@@ -1,9 +1,10 @@
 from __future__ import annotations as _annotations

 import os
-from typing import overload
+from typing import Union, overload

 import httpx
+from typing_extensions import TypeAlias

 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
@@ -12,15 +13,18 @@ from pydantic_ai.profiles.anthropic import anthropic_model_profile
 from pydantic_ai.providers import Provider

 try:
-    from anthropic import AsyncAnthropic
-except ImportError as _import_error:
+    from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
+except ImportError as _import_error:
    raise ImportError(
        'Please install the `anthropic` package to use the Anthropic provider, '
        'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
    ) from _import_error


-
+AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
+
+
+class AnthropicProvider(Provider[AsyncAnthropicClient]):
    """Provider for Anthropic API."""

    @property
@@ -32,14 +36,14 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
        return str(self._client.base_url)

    @property
-    def client(self) ->
+    def client(self) -> AsyncAnthropicClient:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        return anthropic_model_profile(model_name)

    @overload
-    def __init__(self, *, anthropic_client:
+    def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...

    @overload
    def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
@@ -48,7 +52,7 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
        self,
        *,
        api_key: str | None = None,
-        anthropic_client:
+        anthropic_client: AsyncAnthropicClient | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Anthropic provider.
@@ -71,7 +75,6 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
                'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
                'to use the Anthropic provider.'
            )
-
        if http_client is not None:
            self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
        else: