pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +16 -19
- pydantic_ai/_parts_manager.py +3 -1
- pydantic_ai/_tool_manager.py +29 -6
- pydantic_ai/ag_ui.py +75 -43
- pydantic_ai/agent/__init__.py +7 -7
- pydantic_ai/durable_exec/temporal/_agent.py +71 -10
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/mcp.py +13 -25
- pydantic_ai/messages.py +78 -19
- pydantic_ai/models/__init__.py +1 -0
- pydantic_ai/models/anthropic.py +4 -11
- pydantic_ai/models/bedrock.py +6 -14
- pydantic_ai/models/gemini.py +3 -1
- pydantic_ai/models/google.py +15 -1
- pydantic_ai/models/groq.py +122 -34
- pydantic_ai/models/instrumented.py +5 -0
- pydantic_ai/models/openai.py +17 -13
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/retries.py +42 -2
- pydantic_ai/tools.py +7 -7
- pydantic_ai/toolsets/combined.py +2 -2
- pydantic_ai/toolsets/function.py +47 -19
- pydantic_ai/usage.py +37 -3
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.1.dist-info}/METADATA +6 -7
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.1.dist-info}/RECORD +32 -31
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -7,8 +7,11 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, cast, overload

+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never

+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking

@@ -48,7 +51,7 @@ from . import (
 )

 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:

@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-        response = await self._completions_create(
-            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
-        )
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response

@@ -228,6 +246,18 @@ class GroqModel(Model):

         groq_messages = self._map_messages(messages)

+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())

@@ -240,6 +270,7 @@ class GroqModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),

@@ -385,6 +416,19 @@ class GroqModel(Model):
             },
         }

+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
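
Together, the new `response_format` handling and `_map_json_schema` let Groq requests use native JSON-schema output instead of tool-call emulation. A minimal sketch of how this surfaces at the agent level; the output schema and model name below are illustrative, not taken from this diff:

from pydantic import BaseModel

from pydantic_ai import Agent, NativeOutput


class CityInfo(BaseModel):  # illustrative output schema
    city: str
    country: str


# Any Groq-served model whose profile sets supports_json_schema_output can use
# NativeOutput after this change; this model name is just an example.
agent = Agent('groq:meta-llama/llama-4-scout-17b-16e-instruct', output_type=NativeOutput(CityInfo))
result = agent.run_sync('Where were the 2012 Olympics held?')
print(result.output)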
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        async for chunk in self._response:
-            self._usage += _map_usage(chunk)
-
-            try:
-                choice = chunk.choices[0]
-            except IndexError:
-                continue
-
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    yield maybe_event
-
-            # Handle the tool calls
-            for dtc in choice.delta.tool_calls or []:
-                maybe_event = self._parts_manager.handle_tool_call_delta(
-                    vendor_part_id=dtc.index,
-                    tool_name=dtc.function and dtc.function.name,
-                    args=dtc.function and dtc.function.arguments,
-                    tool_call_id=dtc.id,
-                )
-                if maybe_event is not None:
-                    yield maybe_event
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover

     @property
     def model_name(self) -> GroqModelName:

@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
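
Note that `failed_generation` arrives as a JSON string inside the error body, which is why it is declared as `Json[_GroqToolUseFailedGeneration]`. A self-contained sketch of the same parsing pattern, using stand-in class names rather than the private ones above:

from pydantic import BaseModel, Json


class FailedGeneration(BaseModel):  # stand-in for _GroqToolUseFailedGeneration
    name: str
    arguments: dict[str, object]


class InnerError(BaseModel):  # stand-in for _GroqToolUseFailedInnerError
    message: str
    code: str
    failed_generation: Json[FailedGeneration]  # parses the embedded JSON string


body = {
    'message': 'Tool call validation failed: ...',
    'code': 'tool_use_failed',
    'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
}
error = InnerError.model_validate(body)
print(error.failed_generation.name)       # get_something_by_name
print(error.failed_generation.arguments)  # {'foo': 'bar'}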
pydantic_ai/models/instrumented.py
CHANGED

@@ -420,10 +420,15 @@ class InstrumentedModel(WrapperModel):
                 return

             self.instrumentation_settings.handle_messages(messages, response, system, span)
+            try:
+                cost_attributes = {'operation.cost': float(response.cost().total_price)}
+            except LookupError:
+                cost_attributes = {}
             span.set_attributes(
                 {
                     **response.usage.opentelemetry_attributes(),
                     'gen_ai.response.model': response_model,
+                    **cost_attributes,
                 }
             )
             span.update_name(f'{operation} {request_model}')
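
With this change, spans emitted by `InstrumentedModel` also carry an `operation.cost` attribute when a price can be computed for the response, and omit it (via the `LookupError` fallback) when pricing data is unavailable. A rough sketch of enabling instrumentation so the attribute shows up; the exporter setup and model name are assumed, not part of this diff:

from pydantic_ai import Agent

# Instruments all agents; model-request spans now include 'operation.cost'
# alongside the usage attributes whenever pricing for the model is known.
Agent.instrument_all()

agent = Agent('openai:gpt-4o')  # illustrative model name
result = agent.run_sync('Hello!')
print(result.usage())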
pydantic_ai/models/openai.py
CHANGED
@@ -225,6 +225,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -252,6 +253,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -278,6 +280,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -606,7 +609,7 @@ class OpenAIChatModel(Model):
     def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
         response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = {  # pyright: ignore[reportPrivateImportUsage]
             'type': 'json_schema',
-            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema
+            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
         }
         if o.description:
             response_format_param['json_schema']['description'] = o.description

@@ -1171,6 +1174,10 @@ class OpenAIStreamedResponse(StreamedResponse):
             except IndexError:
                 continue

+            # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
+            if choice.delta is None:  # pyright: ignore[reportUnnecessaryComparison]
+                continue
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:

@@ -1270,12 +1277,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     tool_call_id=chunk.item.call_id,
                 )
             elif isinstance(chunk.item, responses.ResponseReasoningItem):
-
-                yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item.id,
-                    content=content,
-                    signature=chunk.item.id,
-                )
+                pass
             elif isinstance(chunk.item, responses.ResponseOutputMessage):
                 pass
             elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):

@@ -1291,7 +1293,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
-
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
+                    content=chunk.part.text,
+                    id=chunk.item_id,
+                )

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
                 pass  # there's nothing we need to do here

@@ -1301,9 +1307,9 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):

             elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
                 yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item_id,
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
                     content=chunk.delta,
-
+                    id=chunk.item_id,
                 )

             # TODO(Marcelo): We should support annotations in the future.

@@ -1311,9 +1317,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here

             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id=chunk.content_index, content=chunk.delta
-                )
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=chunk.item_id, content=chunk.delta)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event

pydantic_ai/providers/__init__.py
CHANGED

@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .github import GitHubProvider

         return GitHubProvider
+    elif provider == 'litellm':
+        from .litellm import LiteLLMProvider
+
+        return LiteLLMProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')

pydantic_ai/providers/google_vertex.py
CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations

 import functools
+from asyncio import Lock
 from collections.abc import AsyncGenerator, Mapping
 from pathlib import Path
 from typing import Literal, overload

@@ -118,7 +119,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
 class _VertexAIAuth(httpx.Auth):
     """Auth class for Vertex AI API."""

-    _refresh_lock:
+    _refresh_lock: Lock = Lock()

     credentials: BaseCredentials | ServiceAccountCredentials | None

pydantic_ai/providers/groq.py
CHANGED
@@ -14,6 +14,7 @@ from pydantic_ai.profiles.groq import groq_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider

@@ -26,6 +27,23 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error


+def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for an MoonshotAI model used with the Groq provider."""
+    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+        moonshotai_model_profile(model_name)
+    )
+
+
+def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a Meta model used with the Groq provider."""
+    if model_name in {'llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct'}:
+        return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+            meta_model_profile(model_name)
+        )
+    else:
+        return meta_model_profile(model_name)
+
+
 class GroqProvider(Provider[AsyncGroq]):
     """Provider for Groq API."""

@@ -44,13 +62,14 @@ class GroqProvider(Provider[AsyncGroq]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         prefix_to_profile = {
             'llama': meta_model_profile,
-            'meta-llama/':
+            'meta-llama/': meta_groq_model_profile,
             'gemma': google_model_profile,
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
-            'moonshotai/':
+            'moonshotai/': groq_moonshotai_model_profile,
             'compound-': groq_model_profile,
+            'openai/': openai_model_profile,
         }

         for prefix, profile_func in prefix_to_profile.items():
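
`GroqProvider.model_profile` resolves profiles by prefix, so 'openai/'-prefixed models served by Groq now pick up the OpenAI profile, and the listed Llama 4 and MoonshotAI models advertise JSON schema support. A small hedged sketch; the model name is illustrative and `GROQ_API_KEY` is assumed to be set:

from pydantic_ai.providers.groq import GroqProvider

provider = GroqProvider()  # assumes GROQ_API_KEY is set in the environment
profile = provider.model_profile('openai/gpt-oss-120b')  # illustrative Groq-hosted model name
print(profile)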
pydantic_ai/providers/litellm.py
ADDED

@@ -0,0 +1,134 @@
+from __future__ import annotations as _annotations
+
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.amazon import amazon_model_profile
+from pydantic_ai.profiles.anthropic import anthropic_model_profile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.grok import grok_model_profile
+from pydantic_ai.profiles.groq import groq_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the LiteLLM provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class LiteLLMProvider(Provider[AsyncOpenAI]):
+    """Provider for LiteLLM API."""
+
+    @property
+    def name(self) -> str:
+        return 'litellm'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        # Map provider prefixes to their profile functions
+        provider_to_profile = {
+            'anthropic': anthropic_model_profile,
+            'openai': openai_model_profile,
+            'google': google_model_profile,
+            'mistralai': mistral_model_profile,
+            'mistral': mistral_model_profile,
+            'cohere': cohere_model_profile,
+            'amazon': amazon_model_profile,
+            'bedrock': amazon_model_profile,
+            'meta-llama': meta_model_profile,
+            'meta': meta_model_profile,
+            'groq': groq_model_profile,
+            'deepseek': deepseek_model_profile,
+            'moonshotai': moonshotai_model_profile,
+            'x-ai': grok_model_profile,
+            'qwen': qwen_model_profile,
+        }
+
+        profile = None
+
+        # Check if model name contains a provider prefix (e.g., "anthropic/claude-3")
+        if '/' in model_name:
+            provider_prefix, model_suffix = model_name.split('/', 1)
+            if provider_prefix in provider_to_profile:
+                profile = provider_to_profile[provider_prefix](model_suffix)
+
+        # If no profile found, default to OpenAI profile
+        if profile is None:
+            profile = openai_model_profile(model_name)
+
+        # As LiteLLMProvider is used with OpenAIModel, which uses OpenAIJsonSchemaTransformer,
+        # we maintain that behavior
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    @overload
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        http_client: AsyncHTTPClient,
+    ) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        """Initialize a LiteLLM provider.
+
+        Args:
+            api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables.
+            api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models.
+            openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored.
+            http_client: Custom HTTP client to use.
+        """
+        if openai_client is not None:
+            self._client = openai_client
+            return
+
+        # Create OpenAI client that will be used with LiteLLM's completion function
+        # The actual API calls will be intercepted and routed through LiteLLM
+        if http_client is not None:
+            self._client = AsyncOpenAI(
+                base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
+            )
+        else:
+            http_client = cached_async_http_client(provider='litellm')
+            self._client = AsyncOpenAI(
+                base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
+            )
pydantic_ai/retries.py
CHANGED
@@ -13,6 +13,8 @@ The module includes:

 from __future__ import annotations

+from types import TracebackType
+
 from httpx import (
     AsyncBaseTransport,
     AsyncHTTPTransport,

@@ -185,11 +187,30 @@ class TenacityTransport(BaseTransport):
                 response.request = req

                 if self.validate_response:
-                    self.validate_response(response)
+                    try:
+                        self.validate_response(response)
+                    except Exception:
+                        response.close()
+                        raise
                 return response

         return handle_request(request)

+    def __enter__(self) -> TenacityTransport:
+        self.wrapped.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
+    ) -> None:
+        self.wrapped.__exit__(exc_type, exc_value, traceback)
+
+    def close(self) -> None:
+        self.wrapped.close()  # pragma: no cover
+

 class AsyncTenacityTransport(AsyncBaseTransport):
     """Asynchronous HTTP transport with tenacity-based retry functionality.

@@ -263,11 +284,30 @@ class AsyncTenacityTransport(AsyncBaseTransport):
                 response.request = req

                 if self.validate_response:
-                    self.validate_response(response)
+                    try:
+                        self.validate_response(response)
+                    except Exception:
+                        await response.aclose()
+                        raise
                 return response

         return await handle_async_request(request)

+    async def __aenter__(self) -> AsyncTenacityTransport:
+        await self.wrapped.__aenter__()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
+    ) -> None:
+        await self.wrapped.__aexit__(exc_type, exc_value, traceback)
+
+    async def aclose(self) -> None:
+        await self.wrapped.aclose()
+

 def wait_retry_after(
     fallback_strategy: Callable[[RetryCallState], float] | None = None, max_wait: float = 300
pydantic_ai/tools.py
CHANGED
@@ -70,7 +70,7 @@ Usage `ToolFuncEither[AgentDepsT, ToolParams]`.
 ToolPrepareFunc: TypeAlias = Callable[[RunContext[AgentDepsT], 'ToolDefinition'], Awaitable['ToolDefinition | None']]
 """Definition of a function that can prepare a tool definition at call time.

-See [tool docs](../tools.md#tool-prepare) for more information.
+See [tool docs](../tools-advanced.md#tool-prepare) for more information.

 Example — here `only_if_42` is valid as a `ToolPrepareFunc`:

@@ -140,7 +140,7 @@ class DeferredToolRequests:

     Results can be passed to the next agent run using a [`DeferredToolResults`][pydantic_ai.tools.DeferredToolResults] object with the same tool call IDs.

-    See [deferred tools docs](../tools.md#deferred-tools) for more information.
+    See [deferred tools docs](../deferred-tools.md#deferred-tools) for more information.
     """

     calls: list[ToolCallPart] = field(default_factory=list)

@@ -204,7 +204,7 @@ class DeferredToolResults:

     The tool call IDs need to match those from the [`DeferredToolRequests`][pydantic_ai.output.DeferredToolRequests] output object from the previous run.

-    See [deferred tools docs](../tools.md#deferred-tools) for more information.
+    See [deferred tools docs](../deferred-tools.md#deferred-tools) for more information.
     """

     calls: dict[str, DeferredToolCallResult | Any] = field(default_factory=dict)

@@ -328,7 +328,7 @@ class Tool(Generic[AgentDepsT]):
             strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                 See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
             requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
-                See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
+                See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
             function_schema: The function schema to use for the tool. If not provided, it will be generated.
         """
         self.function = function

@@ -472,16 +472,16 @@ class ToolDefinition:
     - `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model
     - `'output'`: a tool that passes through an output value that ends the run
     - `'external'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running.
-      See the [tools documentation](../tools.md#deferred-tools) for more info.
+      See the [tools documentation](../deferred-tools.md#deferred-tools) for more info.
     - `'unapproved'`: a tool that requires human-in-the-loop approval.
-      See the [tools documentation](../tools.md#human-in-the-loop-tool-approval) for more info.
+      See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
     """

     @property
     def defer(self) -> bool:
         """Whether calls to this tool will be deferred.

-        See the [tools documentation](../tools.md#deferred-tools) for more info.
+        See the [tools documentation](../deferred-tools.md#deferred-tools) for more info.
         """
         return self.kind in ('external', 'unapproved')
