pydantic-ai-slim 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +8 -6
- pydantic_ai/_cli.py +32 -24
- pydantic_ai/_output.py +7 -7
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/agent.py +19 -13
- pydantic_ai/direct.py +2 -0
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/messages.py +29 -11
- pydantic_ai/models/__init__.py +42 -5
- pydantic_ai/models/anthropic.py +17 -12
- pydantic_ai/models/bedrock.py +10 -9
- pydantic_ai/models/cohere.py +4 -4
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +1 -1
- pydantic_ai/models/gemini.py +26 -22
- pydantic_ai/models/google.py +570 -0
- pydantic_ai/models/groq.py +12 -6
- pydantic_ai/models/instrumented.py +43 -33
- pydantic_ai/models/mistral.py +15 -9
- pydantic_ai/models/openai.py +45 -7
- pydantic_ai/models/test.py +1 -1
- pydantic_ai/models/wrapper.py +1 -1
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +143 -0
- pydantic_ai/providers/google_vertex.py +3 -3
- pydantic_ai/providers/openrouter.py +69 -0
- pydantic_ai/result.py +13 -21
- pydantic_ai/tools.py +2 -2
- pydantic_ai/usage.py +1 -1
- {pydantic_ai_slim-0.2.3.dist-info → pydantic_ai_slim-0.2.5.dist-info}/METADATA +7 -5
- pydantic_ai_slim-0.2.5.dist-info/RECORD +59 -0
- pydantic_ai_slim-0.2.5.dist-info/licenses/LICENSE +21 -0
- pydantic_ai_slim-0.2.3.dist-info/RECORD +0 -55
- {pydantic_ai_slim-0.2.3.dist-info → pydantic_ai_slim-0.2.5.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.3.dist-info → pydantic_ai_slim-0.2.5.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/instrumented.py
CHANGED
@@ -77,6 +77,7 @@ class InstrumentationSettings:
     tracer: Tracer = field(repr=False)
     event_logger: EventLogger = field(repr=False)
     event_mode: Literal['attributes', 'logs'] = 'attributes'
+    include_binary_content: bool = True

     def __init__(
         self,
@@ -84,6 +85,7 @@ class InstrumentationSettings:
         event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
         event_logger_provider: EventLoggerProvider | None = None,
+        include_binary_content: bool = True,
     ):
         """Create instrumentation options.

@@ -97,6 +99,7 @@ class InstrumentationSettings:
                 If not provided, the global event logger provider is used.
                 Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
                 This is only used if `event_mode='logs'`.
+            include_binary_content: Whether to include binary content in the instrumentation events.
         """
         from pydantic_ai import __version__

@@ -105,6 +108,40 @@ class InstrumentationSettings:
         self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
         self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
         self.event_mode = event_mode
+        self.include_binary_content = include_binary_content
+
+    def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
+        """Convert a list of model messages to OpenTelemetry events.
+
+        Args:
+            messages: The messages to convert.
+
+        Returns:
+            A list of OpenTelemetry events.
+        """
+        events: list[Event] = []
+        instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
+        if instructions is not None:
+            events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'}))
+
+        for message_index, message in enumerate(messages):
+            message_events: list[Event] = []
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
+                    if hasattr(part, 'otel_event'):
+                        message_events.append(part.otel_event(self))
+            elif isinstance(message, ModelResponse):  # pragma: no branch
+                message_events = message.otel_events()
+            for event in message_events:
+                event.attributes = {
+                    'gen_ai.message.index': message_index,
+                    **(event.attributes or {}),
+                }
+            events.extend(message_events)
+
+        for event in events:
+            event.body = InstrumentedModel.serialize_any(event.body)
+        return events


 GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
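A minimal usage sketch of the new option (agent and model names are illustrative, not part of the diff; assumes the `Agent(instrument=...)` hook available in this release):

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentationSettings

    # Opt out of recording binary content (e.g. image bytes) in OTel events.
    settings = InstrumentationSettings(include_binary_content=False)
    agent = Agent('openai:gpt-4o', instrument=settings)

    # messages_to_otel_events is now an instance method rather than a static one,
    # so message parts can consult include_binary_content via the settings object:
    # events = settings.messages_to_otel_events(message_history)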
@@ -155,7 +192,7 @@ class InstrumentedModel(WrapperModel):
             ) as response_stream:
                 yield response_stream
         finally:
-            if response_stream:
+            if response_stream:  # pragma: no branch
                 finish(response_stream.get())

     @contextmanager
@@ -193,8 +230,8 @@ class InstrumentedModel(WrapperModel):
         if not span.is_recording():
             return

-        events = self.messages_to_otel_events(messages)
-        for event in self.messages_to_otel_events([response]):
+        events = self.settings.messages_to_otel_events(messages)
+        for event in self.settings.messages_to_otel_events([response]):
             events.append(
                 Event(
                     'gen_ai.choice',
@@ -253,9 +290,9 @@ class InstrumentedModel(WrapperModel):
         except Exception:  # pragma: no cover
             pass
         else:
-            if parsed.hostname:
+            if parsed.hostname:  # pragma: no branch
                 attributes['server.address'] = parsed.hostname
-            if parsed.port:
+            if parsed.port:  # pragma: no branch
                 attributes['server.port'] = parsed.port

         return attributes
@@ -263,40 +300,13 @@ class InstrumentedModel(WrapperModel):
     @staticmethod
     def event_to_dict(event: Event) -> dict[str, Any]:
         if not event.body:
-            body = {}
+            body = {}  # pragma: no cover
         elif isinstance(event.body, Mapping):
             body = event.body  # type: ignore
         else:
             body = {'body': event.body}
         return {**body, **(event.attributes or {})}

-    @staticmethod
-    def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
-        events: list[Event] = []
-        last_model_request: ModelRequest | None = None
-        for message_index, message in enumerate(messages):
-            message_events: list[Event] = []
-            if isinstance(message, ModelRequest):
-                last_model_request = message
-                for part in message.parts:
-                    if hasattr(part, 'otel_event'):
-                        message_events.append(part.otel_event())
-            elif isinstance(message, ModelResponse):
-                message_events = message.otel_events()
-            for event in message_events:
-                event.attributes = {
-                    'gen_ai.message.index': message_index,
-                    **(event.attributes or {}),
-                }
-            events.extend(message_events)
-        if last_model_request and last_model_request.instructions:
-            events.insert(
-                0, Event('gen_ai.system.message', body={'content': last_model_request.instructions, 'role': 'system'})
-            )
-        for event in events:
-            event.body = InstrumentedModel.serialize_any(event.body)
-        return events
-
     @staticmethod
     def serialize_any(value: Any) -> str:
         try:
pydantic_ai/models/mistral.py
CHANGED
@@ -71,7 +71,7 @@ try:
     from mistralai.models.usermessage import UserMessage as MistralUserMessage
     from mistralai.types.basemodel import Unset as MistralUnset
     from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
-except ImportError as e:
+except ImportError as e:  # pragma: lax no cover
     raise ImportError(
         'Please install `mistral` to use the Mistral model, '
         'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -208,7 +208,7 @@ class MistralModel(Model):
         except SDKError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise
+            raise  # pragma: lax no cover

         assert response, 'A unexpected empty response from Mistral.'
         return response
@@ -325,7 +325,9 @@ class MistralModel(Model):
             tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
             parts.append(tool)

-        return ModelResponse(parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+        )

     async def _process_streamed_response(
         self,
@@ -336,7 +338,9 @@ class MistralModel(Model):
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Streamed response ended without content or tool calls'
+            )

         if first_chunk.data.created:
             timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
@@ -437,7 +441,7 @@ class MistralModel(Model):
         """Convert a timeout to milliseconds."""
         if timeout is None:
             return None
-        if isinstance(timeout, float):
+        if isinstance(timeout, float):  # pragma: no cover
            return int(1000 * timeout)
        raise NotImplementedError('Timeout object is not yet supported for MistralModel.')

@@ -454,7 +458,7 @@ class MistralModel(Model):
             )
         elif isinstance(part, RetryPromptPart):
             if part.tool_name is None:
-                yield MistralUserMessage(content=part.model_response())
+                yield MistralUserMessage(content=part.model_response())  # pragma: no cover
             else:
                 yield MistralToolMessage(
                     tool_call_id=part.tool_call_id,
@@ -519,7 +523,7 @@ class MistralModel(Model):
             else:
                 raise RuntimeError('Only image binary content is supported for Mistral.')
         elif isinstance(item, DocumentUrl):
-            raise RuntimeError('DocumentUrl is not supported in Mistral.')
+            raise RuntimeError('DocumentUrl is not supported in Mistral.')  # pragma: no cover
         elif isinstance(item, VideoUrl):
             raise RuntimeError('VideoUrl is not supported in Mistral.')
         else:  # pragma: no cover
@@ -663,7 +667,7 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
             details=None,
         )
     else:
-        return Usage()
+        return Usage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
@@ -677,7 +681,9 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
             if isinstance(chunk, MistralTextChunk):
                 output = output or '' + chunk.text
             else:
-                assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
+                assert False, (  # pragma: no cover
+                    f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
+                )
     elif isinstance(content, str):
         output = content

pydantic_ai/models/openai.py
CHANGED
@@ -104,6 +104,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
     result in faster responses and fewer tokens used on reasoning in a response.
     """

+    openai_logprobs: bool
+    """Include log probabilities in the response."""
+
+    openai_top_logprobs: int
+    """Include log probabilities of the top n tokens in the response."""
+
     openai_user: str
     """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.
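A short sketch of how these settings would be supplied (model name illustrative; `OpenAIModelSettings` is a `total=False` TypedDict, so both keys are optional):

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModelSettings

    # Ask OpenAI for per-token log probabilities plus the top 3 alternatives
    # for each token; these are forwarded to the SDK's `logprobs` and
    # `top_logprobs` parameters in the request-building hunk below.
    settings = OpenAIModelSettings(openai_logprobs=True, openai_top_logprobs=3)
    agent = Agent('openai:gpt-4o', model_settings=settings)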
@@ -287,6 +293,8 @@ class OpenAIModel(Model):
                 frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                 logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
+                logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
+                top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
                 user=model_settings.get('openai_user', NOT_GIVEN),
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
@@ -294,26 +302,54 @@ class OpenAIModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise
+            raise  # pragma: lax no cover

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
+        vendor_details: dict[str, Any] | None = None
+
+        # Add logprobs to vendor_details if available
+        if choice.logprobs is not None and choice.logprobs.content:
+            # Convert logprobs to a serializable format
+            vendor_details = {
+                'logprobs': [
+                    {
+                        'token': lp.token,
+                        'bytes': lp.bytes,
+                        'logprob': lp.logprob,
+                        'top_logprobs': [
+                            {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                        ],
+                    }
+                    for lp in choice.logprobs.content
+                ],
+            }
+
         if choice.message.content is not None:
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            vendor_details=vendor_details,
+            vendor_id=response.id,
+        )

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Streamed response ended without content or tool calls'
+            )

         return OpenAIStreamedResponse(
             _model_name=self._model_name,
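Reading the captured log probabilities back is then a dictionary walk over `vendor_details`; a hedged sketch continuing the agent above (the prompt and access pattern are illustrative):

    result = agent.run_sync('What is the capital of France?')
    response = result.all_messages()[-1]  # the final ModelResponse
    if response.vendor_details is not None:
        for lp in response.vendor_details['logprobs']:
            # Each entry mirrors the structure built above:
            # token, bytes, logprob, and a list of top_logprobs.
            print(lp['token'], lp['logprob'])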
@@ -399,7 +435,9 @@ class OpenAIModel(Model):
             )
         elif isinstance(part, RetryPromptPart):
             if part.tool_name is None:
-                yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                yield chat.ChatCompletionUserMessageParam(  # pragma: no cover
+                    role='user', content=part.model_response()
+                )
             else:
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
@@ -637,7 +675,7 @@ class OpenAIResponsesModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise
+            raise  # pragma: lax no cover

     def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
         reasoning_effort = model_settings.get('openai_reasoning_effort', None)
@@ -867,7 +905,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     args=chunk.delta,
                     tool_call_id=chunk.item_id,
                 )
-                if maybe_event is not None:
+                if maybe_event is not None:  # pragma: no branch
                     yield maybe_event

             elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDoneEvent):
@@ -1031,7 +1069,7 @@ class _OpenAIJsonSchema(WalkJsonSchema):
                 notes.append(f'{key}={value}')
             notes_string = ', '.join(notes)
             schema['description'] = notes_string if not description else f'{description} ({notes_string})'
-        elif self.strict is None:
+        elif self.strict is None:  # pragma: no branch
             self.is_strict_compatible = False

         schema_type = schema.get('type')
pydantic_ai/models/test.py
CHANGED
@@ -169,7 +169,7 @@ class TestModel(Model):
             model_name=self._model_name,
         )

-        if messages:
+        if messages:  # pragma: no branch
             last_message = messages[-1]
             assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'

pydantic_ai/providers/__init__.py
CHANGED
@@ -52,6 +52,10 @@ def infer_provider(provider: str) -> Provider[Any]:
         from .deepseek import DeepSeekProvider

         return DeepSeekProvider()
+    elif provider == 'openrouter':
+        from .openrouter import OpenRouterProvider
+
+        return OpenRouterProvider()
     elif provider == 'azure':
         from .azure import AzureProvider

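With this branch in place, the provider string resolves to the new OpenRouter provider; a minimal sketch (model name illustrative, requires `OPENROUTER_API_KEY` in the environment):

    from pydantic_ai.models.openai import OpenAIModel

    # infer_provider('openrouter') above supplies an OpenRouterProvider.
    model = OpenAIModel('anthropic/claude-3.5-sonnet', provider='openrouter')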
pydantic_ai/providers/google.py
ADDED
@@ -0,0 +1,143 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import Literal, overload
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import get_user_agent
+from pydantic_ai.providers import Provider
+
+try:
+    from google import genai
+    from google.auth.credentials import Credentials
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install the `google-genai` package to use the Google provider, '
+        'you can use the `google` optional group — `pip install "pydantic-ai-slim[google]"`'
+    ) from _import_error
+
+
+class GoogleProvider(Provider[genai.Client]):
+    """Provider for Google."""
+
+    @property
+    def name(self) -> str:
+        return 'google-vertex' if self._client._api_client.vertexai else 'google-gla'  # type: ignore[reportPrivateUsage]
+
+    @property
+    def base_url(self) -> str:
+        return str(self._client._api_client._http_options.base_url)  # type: ignore[reportPrivateUsage]
+
+    @property
+    def client(self) -> genai.Client:
+        return self._client
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
+        credentials: Credentials | None = None,
+        project: str | None = None,
+        location: VertexAILocation | Literal['global'] | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(self, *, client: genai.Client) -> None: ...
+
+    @overload
+    def __init__(self, *, vertexai: bool = False) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        credentials: Credentials | None = None,
+        project: str | None = None,
+        location: VertexAILocation | Literal['global'] | None = None,
+        client: genai.Client | None = None,
+        vertexai: bool | None = None,
+    ) -> None:
+        """Create a new Google provider.
+
+        Args:
+            api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
+                use for authentication. It can also be set via the `GOOGLE_API_KEY` environment variable.
+                Applies to the Gemini Developer API only.
+            credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be
+                obtained from environment variables and default credentials. For more information, see Set up
+                Application Default Credentials. Applies to the Vertex AI API only.
+            project: The Google Cloud project ID to use for quota. Can be obtained from environment variables
+                (for example, GOOGLE_CLOUD_PROJECT). Applies to the Vertex AI API only.
+            location: The location to send API requests to (for example, us-central1). Can be obtained from environment variables.
+                Applies to the Vertex AI API only.
+            client: A pre-initialized client to use.
+            vertexai: Force the use of the Vertex AI API. If `False`, the Google Generative Language API will be used.
+                Defaults to `False`.
+        """
+        if client is None:
+            # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
+            api_key = api_key or os.environ.get('GOOGLE_API_KEY')
+
+            if vertexai is None:  # pragma: lax no cover
+                vertexai = bool(location or project or credentials)
+
+            if not vertexai:
+                if api_key is None:
+                    raise UserError(  # pragma: no cover
+                        'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`'
+                        'to use the Google Generative Language API.'
+                    )
+                self._client = genai.Client(
+                    vertexai=vertexai,
+                    api_key=api_key,
+                    http_options={'headers': {'User-Agent': get_user_agent()}},
+                )
+            else:
+                self._client = genai.Client(
+                    vertexai=vertexai,
+                    project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'),
+                    location=location or os.environ.get('GOOGLE_CLOUD_LOCATION'),
+                    credentials=credentials,
+                    http_options={'headers': {'User-Agent': get_user_agent()}},
+                )
+        else:
+            self._client = client  # pragma: lax no cover
+
+
+VertexAILocation = Literal[
+    'asia-east1',
+    'asia-east2',
+    'asia-northeast1',
+    'asia-northeast3',
+    'asia-south1',
+    'asia-southeast1',
+    'australia-southeast1',
+    'europe-central2',
+    'europe-north1',
+    'europe-southwest1',
+    'europe-west1',
+    'europe-west2',
+    'europe-west3',
+    'europe-west4',
+    'europe-west6',
+    'europe-west8',
+    'europe-west9',
+    'me-central1',
+    'me-central2',
+    'me-west1',
+    'northamerica-northeast1',
+    'southamerica-east1',
+    'us-central1',
+    'us-east1',
+    'us-east4',
+    'us-east5',
+    'us-south1',
+    'us-west1',
+    'us-west4',
+]
+"""Regions available for Vertex AI.
+More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
+"""
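A usage sketch for the new provider (model name illustrative; assumes the GoogleModel added in pydantic_ai/models/google.py in this same release):

    from pydantic_ai.models.google import GoogleModel
    from pydantic_ai.providers.google import GoogleProvider

    # Gemini Developer API: authenticate with an API key (or GOOGLE_API_KEY).
    provider = GoogleProvider(api_key='your-api-key')
    model = GoogleModel('gemini-1.5-flash', provider=provider)

    # Vertex AI: passing project/location/credentials implies vertexai=True.
    vertex_provider = GoogleProvider(project='my-project', location='us-central1')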
pydantic_ai/providers/google_vertex.py
CHANGED
@@ -128,7 +128,7 @@ class _VertexAIAuth(httpx.Auth):
         self.credentials = None

     async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
-        if self.credentials is None:
+        if self.credentials is None:  # pragma: no branch
             self.credentials = await self._get_credentials()
         if self.credentials.token is None:  # type: ignore[reportUnknownMemberType]
             await self._refresh_token()
@@ -157,9 +157,9 @@ class _VertexAIAuth(httpx.Auth):
         creds, creds_project_id = await _async_google_auth()
         creds_source = '`google.auth.default()`'

-        if self.project_id is None:
+        if self.project_id is None:  # pragma: no branch
             if creds_project_id is None:
-                raise UserError(f'No project_id provided and none found in {creds_source}')
+                raise UserError(f'No project_id provided and none found in {creds_source}')  # pragma: no cover
             self.project_id = creds_project_id
         return creds

pydantic_ai/providers/openrouter.py
ADDED
@@ -0,0 +1,69 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the OpenRouter provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class OpenRouterProvider(Provider[AsyncOpenAI]):
+    """Provider for OpenRouter API."""
+
+    @property
+    def name(self) -> str:
+        return 'openrouter'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://openrouter.ai/api/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('OPENROUTER_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `OPENROUTER_API_KEY` environment variable or pass it via `OpenRouterProvider(api_key=...)`'
+                'to use the OpenRouter provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='openrouter')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)