pydantic-ai 0.0.32__tar.gz → 0.0.33__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/PKG-INFO +3 -3
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/pyproject.toml +3 -3
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/conftest.py +1 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/json_body_serializer.py +19 -13
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_gemini/test_image_as_binary_content_input.yaml +3 -3
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_gemini/test_image_url_input.yaml +9 -9
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_gemini.py +29 -29
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_model.py +2 -2
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_model_names.py +11 -2
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_openai.py +24 -21
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_vertexai.py +3 -0
- pydantic_ai-0.0.33/tests/providers/__init__.py +0 -0
- pydantic_ai-0.0.33/tests/providers/test_deepseek.py +48 -0
- pydantic_ai-0.0.33/tests/providers/test_google_gla.py +19 -0
- pydantic_ai-0.0.33/tests/providers/test_google_vertex.py +110 -0
- pydantic_ai-0.0.33/tests/providers/test_provider_names.py +44 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_examples.py +5 -3
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_live.py +14 -5
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/.gitignore +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/LICENSE +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/Makefile +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/README.md +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/__init__.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/assets/kiwi.png +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/assets/marcelo.mp3 +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/example_modules/README.md +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/example_modules/bank_database.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/example_modules/fake_database.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/example_modules/weather_service.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/graph/__init__.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/graph/test_graph.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/graph/test_history.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/graph/test_mermaid.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/graph/test_state.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/graph/test_utils.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/import_examples.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/__init__.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_anthropic/test_image_url_input.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_anthropic/test_image_url_input_invalid_mime_type.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_anthropic/test_multiple_parallel_tool_calls.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_groq/test_image_as_binary_content_input.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_groq/test_image_url_input.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_openai/test_audio_as_binary_content_input.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_openai/test_image_as_binary_content_input.yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_openai/test_openai_o1_mini_system_role[developer].yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_openai/test_openai_o1_mini_system_role[system].yaml +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/mock_async_stream.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_anthropic.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_cohere.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_fallback.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_groq.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_instrumented.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_mistral.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_model_function.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_model_test.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_agent.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_deps.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_format_as_xml.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_json_body_serializer.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_logfire.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_parts_manager.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_streaming.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_tools.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_usage_limits.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_utils.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/typed_agent.py +0 -0
- {pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/typed_graph.py +0 -0

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai
-Version: 0.0.32
+Version: 0.0.33
 Summary: Agent Framework / shim to use Pydantic with LLMs
 Project-URL: Homepage, https://ai.pydantic.dev
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -28,9 +28,9 @@ Classifier: Topic :: Internet
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.32
+Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.33
 Provides-Extra: examples
-Requires-Dist: pydantic-ai-examples==0.0.32; extra == 'examples'
+Requires-Dist: pydantic-ai-examples==0.0.33; extra == 'examples'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Description-Content-Type: text/markdown

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai"
-version = "0.0.32"
+version = "0.0.33"
 description = "Agent Framework / shim to use Pydantic with LLMs"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -37,7 +37,7 @@ classifiers = [
 requires-python = ">=3.9"
 
 dependencies = [
-    "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.32",
+    "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.33",
 ]
 
 [project.urls]
@@ -47,7 +47,7 @@ Documentation = "https://ai.pydantic.dev"
 Changelog = "https://github.com/pydantic/pydantic-ai/releases"
 
 [project.optional-dependencies]
-examples = ["pydantic-ai-examples==0.0.32"]
+examples = ["pydantic-ai-examples==0.0.33"]
 logfire = ["logfire>=2.3"]
 
 [tool.uv.sources]

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/conftest.py

@@ -200,6 +200,7 @@ def pytest_recording_configure(config: Any, vcr: VCR):
 @pytest.fixture(scope='module')
 def vcr_config():
     return {
+        'ignore_localhost': True,
         # Note: additional header filtering is done inside the serializer
         'filter_headers': ['authorization', 'x-api-key'],
         'decode_compressed_response': True,
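The new `ignore_localhost` flag tells VCR to let requests to localhost pass through without being recorded into cassettes, which plausibly supports the localhost-based `ollama` fixture in tests/test_live.py further down. A minimal sketch of the updated fixture on its own:

```python
import pytest


@pytest.fixture(scope='module')
def vcr_config():
    # 'ignore_localhost' makes VCR pass localhost traffic through unrecorded;
    # the headers named here are stripped from any cassette that is recorded.
    return {
        'ignore_localhost': True,
        'filter_headers': ['authorization', 'x-api-key'],
        'decode_compressed_response': True,
    }
```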

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/json_body_serializer.py

@@ -1,5 +1,6 @@
 # pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
 import json
+import urllib.parse
 from typing import TYPE_CHECKING, Any
 
 import yaml
@@ -59,19 +60,24 @@ def serialize(cassette_dict: Any):
     # update headers on source object
     data['headers'] = headers
 
-    content_type = headers.get('content-type',
-    if
-
-
-
-
-
-
-
-
-
-
+    content_type = headers.get('content-type', [])
+    if any(header.startswith('application/json') for header in content_type):
+        # Parse the body as JSON
+        body: Any = data.get('body', None)
+        assert body is not None, data
+        if isinstance(body, dict):
+            # Responses will have the body under a field called 'string'
+            body = body.get('string')
+        if body is not None:
+            data['parsed_body'] = json.loads(body)
+            if 'access_token' in data['parsed_body']:
+                data['parsed_body']['access_token'] = 'scrubbed'
+            del data['body']
+    if content_type == ['application/x-www-form-urlencoded']:
+        query_params = urllib.parse.parse_qs(data['body'])
+        if 'client_secret' in query_params:
+            query_params['client_secret'] = ['scrubbed']
+        data['body'] = urllib.parse.urlencode(query_params)
 
     # Use our custom dumper
     return yaml.dump(cassette_dict, Dumper=LiteralDumper, allow_unicode=True, width=120)
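The serializer now scrubs OAuth material out of recorded cassettes: `access_token` values in JSON bodies and `client_secret` values in form-encoded bodies are replaced with 'scrubbed'. A standalone sketch of the form-encoded step (the sample body is invented; note that `parse_qs` returns lists of values, so `doseq=True` is needed to re-encode them flat):

```python
import urllib.parse

# parse_qs yields {'grant_type': ['refresh_token'], 'client_secret': ['s3cr3t']}
body = 'grant_type=refresh_token&client_secret=s3cr3t'
query_params = urllib.parse.parse_qs(body)
if 'client_secret' in query_params:
    query_params['client_secret'] = ['scrubbed']
# doseq=True encodes each list element as its own key=value pair
print(urllib.parse.urlencode(query_params, doseq=True))
# -> grant_type=refresh_token&client_secret=scrubbed
```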

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_gemini/test_image_as_binary_content_input.yaml

@@ -38,7 +38,7 @@ interactions:
             "role": "model"
           },
           "finishReason": "STOP",
-          "avgLogprobs": -0.
+          "avgLogprobs": -0.031536102294921875
         }
       ],
       "usageMetadata": {
@@ -68,11 +68,11 @@ interactions:
       alt-svc:
       - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
       content-length:
-      - '
+      - '711'
       content-type:
       - application/json; charset=UTF-8
       server-timing:
-      - gfet4t7; dur=
+      - gfet4t7; dur=2657
       transfer-encoding:
       - chunked
       vary:

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/cassettes/test_gemini/test_image_url_input.yaml

@@ -6793,7 +6793,7 @@ interactions:
       access-control-allow-origin:
       - '*'
       age:
-      - '
+      - '88'
       alt-svc:
       - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
       cache-control:
@@ -6809,7 +6809,7 @@ interactions:
       cross-origin-resource-policy:
       - cross-origin
       expires:
-      - Wed,
+      - Wed, 04 Mar 2026 10:32:03 GMT
       last-modified:
       - Mon, 30 Aug 2123 17:01:05 GMT
       report-to:
@@ -6850,19 +6850,19 @@ interactions:
           "content": {
             "parts": [
               {
-                "text": "This is not a fruit
+                "text": "This is not a fruit; it's a pipe organ console."
               }
             ],
             "role": "model"
           },
           "finishReason": "STOP",
-          "avgLogprobs": -0.
+          "avgLogprobs": -0.31288215092250277
         }
       ],
       "usageMetadata": {
         "promptTokenCount": 1814,
-        "candidatesTokenCount":
-        "totalTokenCount":
+        "candidatesTokenCount": 14,
+        "totalTokenCount": 1828,
         "promptTokensDetails": [
           {
             "modality": "TEXT",
@@ -6876,7 +6876,7 @@ interactions:
         "candidatesTokensDetails": [
           {
             "modality": "TEXT",
-            "tokenCount":
+            "tokenCount": 14
           }
         ]
       },
@@ -6886,11 +6886,11 @@ interactions:
       alt-svc:
       - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
       content-length:
-      - '
+      - '730'
       content-type:
       - application/json; charset=UTF-8
       server-timing:
-      - gfet4t7; dur=
+      - gfet4t7; dur=1749
       transfer-encoding:
       - chunked
       vary:

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_gemini.py

@@ -45,6 +45,7 @@ from pydantic_ai.models.gemini import (
     _GeminiTools,
     _GeminiUsageMetaData,
 )
+from pydantic_ai.providers.google_gla import GoogleGLAProvider
 from pydantic_ai.result import Usage
 from pydantic_ai.tools import ToolDefinition
 
@@ -55,9 +56,8 @@ pytestmark = pytest.mark.anyio
 
 def test_api_key_arg(env: TestEnv):
     env.set('GEMINI_API_KEY', 'via-env-var')
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
-    assert
-    assert m.auth.api_key == 'via-arg'
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
+    assert m.client.headers['x-goog-api-key'] == 'via-arg'
 
 
 def test_api_key_env_var(env: TestEnv):
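The pattern running through this file: model constructors no longer take `api_key`/`http_client` directly; credentials and HTTP clients move onto a provider object. A minimal migration sketch based on the test changes (the key is a placeholder):

```python
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

# 0.0.32 style (removed in these tests):
# m = GeminiModel('gemini-1.5-flash', api_key='your-key')

# 0.0.33 style: the provider owns the API key and the httpx client.
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='your-key'))
```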
@@ -80,11 +80,10 @@ def test_api_key_empty(env: TestEnv):
 
 
 async def test_model_simple(allow_model_requests: None):
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
-    assert isinstance(m.
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
+    assert isinstance(m.client, httpx.AsyncClient)
     assert m.model_name == 'gemini-1.5-flash'
-    assert
-    assert m.auth.api_key == 'via-arg'
+    assert 'x-goog-api-key' in m.client.headers
 
     arc = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
     tools = m._get_tools(arc)
@@ -94,7 +93,7 @@ async def test_model_simple(allow_model_requests: None):
 
 
 async def test_model_tools(allow_model_requests: None):
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
     tools = [
         ToolDefinition(
             'foo',
@@ -153,7 +152,7 @@ async def test_model_tools(allow_model_requests: None):
 
 
 async def test_require_response_tool(allow_model_requests: None):
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
     result_tool = ToolDefinition(
         'result',
         'This is the tool for the final Result',
@@ -212,7 +211,7 @@ async def test_json_def_replaced(allow_model_requests: None):
         }
     )
 
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
     result_tool = ToolDefinition(
         'result',
         'This is the tool for the final Result',
@@ -259,7 +258,7 @@ async def test_json_def_replaced_any_of(allow_model_requests: None):
 
     json_schema = Locations.model_json_schema()
 
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
     result_tool = ToolDefinition(
         'result',
         'This is the tool for the final Result',
@@ -322,7 +321,7 @@ async def test_json_def_recursive(allow_model_requests: None):
         }
     )
 
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
     result_tool = ToolDefinition(
         'result',
         'This is the tool for the final Result',
@@ -354,7 +353,7 @@ async def test_json_def_date(allow_model_requests: None):
         }
     )
 
-    m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
     result_tool = ToolDefinition(
         'result',
         'This is the tool for the final Result',
@@ -451,7 +450,7 @@ def example_usage() -> _GeminiUsageMetaData:
 async def test_text_success(get_gemini_client: GetGeminiClient):
     response = gemini_response(_content_model_response(ModelResponse(parts=[TextPart('Hello world')])))
     gemini_client = get_gemini_client(response)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m)
 
     result = await agent.run('Hello')
@@ -493,7 +492,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
         _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2, 123]})]))
     )
     gemini_client = get_gemini_client(response)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m, result_type=list[int])
 
     result = await agent.run('Hello')
@@ -540,7 +539,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
         gemini_response(_content_model_response(ModelResponse(parts=[TextPart('final response')]))),
     ]
     gemini_client = get_gemini_client(responses)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m, system_prompt='this is the system prompt')
 
     @agent.tool_plain
@@ -622,7 +621,7 @@ async def test_unexpected_response(client_with_handler: ClientWithHandler, env:
         return httpx.Response(401, content='invalid request')
 
     gemini_client = client_with_handler(handler)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m, system_prompt='this is the system prompt')
 
     with pytest.raises(ModelHTTPError) as exc_info:
@@ -639,7 +638,7 @@ async def test_stream_text(get_gemini_client: GetGeminiClient):
     json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
     stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]])
     gemini_client = get_gemini_client(stream)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m)
 
     async with agent.run_stream('Hello') as result:
@@ -684,7 +683,7 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
 
     stream = AsyncByteStreamList(parts)
     gemini_client = get_gemini_client(stream)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m)
 
     async with agent.run_stream('Hello') as result:
@@ -698,7 +697,7 @@ async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
     json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
     stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]])
     gemini_client = get_gemini_client(stream)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m)
     with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'):
         async with agent.run_stream('Hello'):
@@ -714,7 +713,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient):
     json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
     stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]])
     gemini_client = get_gemini_client(stream)
-    model = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(model, result_type=tuple[int, int])
 
     async with agent.run_stream('Hello') as result:
@@ -744,7 +743,7 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
     second_stream = AsyncByteStreamList([d2[:100], d2[100:]])
 
     gemini_client = get_gemini_client([first_stream, second_stream])
-    model = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    model = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(model, result_type=tuple[int, int])
     tool_calls: list[str] = []
 
@@ -817,7 +816,7 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
     json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
     stream = AsyncByteStreamList([json_data[:100], json_data[100:200], json_data[200:]])
     gemini_client = get_gemini_client(stream)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
     agent = Agent(m)
 
     @agent.tool_plain()
@@ -887,7 +886,7 @@ async def test_model_settings(client_with_handler: ClientWithHandler, env: TestE
     )
 
     gemini_client = client_with_handler(handler)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client, api_key='mock'))
     agent = Agent(m)
 
     result = await agent.run(
@@ -939,7 +938,8 @@ async def test_safety_settings_unsafe(
     )
 
     gemini_client = client_with_handler(handler)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client, api_key='mock'))
     agent = Agent(m)
 
     await agent.run(
@@ -975,7 +975,7 @@ async def test_safety_settings_safe(
     )
 
     gemini_client = client_with_handler(handler)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+    m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client, api_key='mock'))
     agent = Agent(m)
 
     result = await agent.run(
@@ -994,7 +994,7 @@ async def test_safety_settings_safe(
 async def test_image_as_binary_content_input(
     allow_model_requests: None, gemini_api_key: str, image_content: BinaryContent
 ) -> None:
-    m = GeminiModel('gemini-2.0-flash', api_key=gemini_api_key)
+    m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key))
     agent = Agent(m)
 
     result = await agent.run(['What is the name of this fruit?', image_content])
@@ -1003,10 +1003,10 @@ async def test_image_as_binary_content_input(
 
 @pytest.mark.vcr()
 async def test_image_url_input(allow_model_requests: None, gemini_api_key: str) -> None:
-    m = GeminiModel('gemini-2.0-flash-exp', api_key=gemini_api_key)
+    m = GeminiModel('gemini-2.0-flash-exp', provider=GoogleGLAProvider(api_key=gemini_api_key))
     agent = Agent(m)
 
     image_url = ImageUrl(url='https://goo.gle/instrument-img')
 
     result = await agent.run(['What is the name of this fruit?', image_url])
-    assert result.data == snapshot(
+    assert result.data == snapshot("This is not a fruit; it's a pipe organ console.")

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_model.py

@@ -19,7 +19,7 @@ TEST_CASES = [
         'gemini-1.5-flash',
         'google-vertex',
         'vertexai',
-        '
+        'GeminiModel',
     ),
     (
         'GEMINI_API_KEY',
@@ -27,7 +27,7 @@ TEST_CASES = [
         'gemini-1.5-flash',
         'google-vertex',
         'vertexai',
-        '
+        'GeminiModel',
     ),
     (
         'ANTHROPIC_API_KEY',

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_model_names.py

@@ -1,7 +1,8 @@
 from collections.abc import Iterator
-from typing import Any
+from typing import Any
 
 import pytest
+from typing_extensions import get_args
 
 from pydantic_ai.models import KnownModelName
 
@@ -40,10 +41,18 @@ def test_known_model_names():
     openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
         n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
     ]
+    deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner']
     extra_names = ['test']
 
     generated_names = sorted(
-        anthropic_names
+        anthropic_names
+        + cohere_names
+        + google_names
+        + groq_names
+        + mistral_names
+        + openai_names
+        + deepseek_names
+        + extra_names
     )
 
     known_model_names = sorted(get_args(KnownModelName))

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_openai.py

@@ -46,6 +46,7 @@ with try_import() as imports_successful:
     from openai.types.completion_usage import CompletionUsage, PromptTokensDetails
 
     from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
+    from pydantic_ai.providers.openai import OpenAIProvider
 
     # note: we use Union here so that casting works with Python 3.9
     MockChatCompletion = Union[chat.ChatCompletion, Exception]
@@ -58,26 +59,26 @@ pytestmark = [
 
 
 def test_init():
-    m = OpenAIModel('gpt-4o', api_key='foobar')
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key='foobar'))
     assert str(m.client.base_url) == 'https://api.openai.com/v1/'
     assert m.client.api_key == 'foobar'
     assert m.model_name == 'gpt-4o'
 
 
 def test_init_with_base_url():
-    m = OpenAIModel('gpt-4o', base_url='https://example.com/v1', api_key='foobar')
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(base_url='https://example.com/v1', api_key='foobar'))
     assert str(m.client.base_url) == 'https://example.com/v1/'
     assert m.client.api_key == 'foobar'
     assert m.model_name == 'gpt-4o'
 
 
 def test_init_with_no_api_key_will_still_setup_client():
-    m = OpenAIModel('llama3.2', base_url='http://localhost:19434/v1')
+    m = OpenAIModel('llama3.2', provider=OpenAIProvider(base_url='http://localhost:19434/v1'))
     assert str(m.client.base_url) == 'http://localhost:19434/v1/'
 
 
 def test_init_with_non_openai_model():
-    m = OpenAIModel('llama3.2-vision:latest', base_url='https://example.com/v1/')
+    m = OpenAIModel('llama3.2-vision:latest', provider=OpenAIProvider(base_url='https://example.com/v1/'))
     assert m.model_name == 'llama3.2-vision:latest'
 
 
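The same provider migration applies on the OpenAI side, and `OpenAIProvider` keeps the `base_url` escape hatch for OpenAI-compatible servers. A sketch mirroring the tests above (URL and key are placeholders):

```python
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

# Hosted OpenAI: the provider carries the API key.
m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key='your-key'))

# OpenAI-compatible server (e.g. a local Ollama instance): point the
# provider at its base URL; no API key is required, per the tests above.
local = OpenAIModel('llama3.2', provider=OpenAIProvider(base_url='http://localhost:11434/v1'))
```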
@@ -157,7 +158,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
 async def test_request_simple_success(allow_model_requests: None):
     c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
     mock_client = MockOpenAI.create_mock(c)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
 
     result = await agent.run('hello')
@@ -206,7 +207,7 @@ async def test_request_simple_usage(allow_model_requests: None):
         usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
     )
     mock_client = MockOpenAI.create_mock(c)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
 
     result = await agent.run('Hello')
@@ -229,7 +230,7 @@ async def test_request_structured_response(allow_model_requests: None):
         )
     )
     mock_client = MockOpenAI.create_mock(c)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m, result_type=list[int])
 
     result = await agent.run('Hello')
@@ -305,7 +306,7 @@ async def test_request_tool_call(allow_model_requests: None):
         completion_message(ChatCompletionMessage(content='final response', role='assistant')),
     ]
     mock_client = MockOpenAI.create_mock(responses)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m, system_prompt='this is the system prompt')
 
     @agent.tool_plain
@@ -408,7 +409,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.Cha
 async def test_stream_text(allow_model_requests: None):
     stream = [text_chunk('hello '), text_chunk('world'), chunk([])]
     mock_client = MockOpenAI.create_mock_stream(stream)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
 
     async with agent.run_stream('') as result:
@@ -425,7 +426,7 @@ async def test_stream_text_finish_reason(allow_model_requests: None):
         text_chunk('.', finish_reason='stop'),
     ]
     mock_client = MockOpenAI.create_mock_stream(stream)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
 
     async with agent.run_stream('') as result:
@@ -472,7 +473,7 @@ async def test_stream_structured(allow_model_requests: None):
         chunk([]),
     ]
     mock_client = MockOpenAI.create_mock_stream(stream)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m, result_type=MyTypedDict)
 
     async with agent.run_stream('') as result:
@@ -500,7 +501,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
         struc_chunk(None, None, finish_reason='stop'),
     ]
     mock_client = MockOpenAI.create_mock_stream(stream)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m, result_type=MyTypedDict)
 
     async with agent.run_stream('') as result:
@@ -520,7 +521,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
 async def test_no_content(allow_model_requests: None):
     stream = [chunk([ChoiceDelta()]), chunk([ChoiceDelta()])]
     mock_client = MockOpenAI.create_mock_stream(stream)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m, result_type=MyTypedDict)
 
     with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
@@ -535,7 +536,7 @@ async def test_no_delta(allow_model_requests: None):
         text_chunk('world'),
     ]
     mock_client = MockOpenAI.create_mock_stream(stream)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
 
     async with agent.run_stream('') as result:
@@ -553,7 +554,7 @@ async def test_system_prompt_role(
 
     c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
     mock_client = MockOpenAI.create_mock(c)
-    m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, provider=OpenAIProvider(openai_client=mock_client))
     assert m.system_prompt_role == system_prompt_role
 
     agent = Agent(m, system_prompt='some instructions')
@@ -579,7 +580,9 @@ async def test_openai_o1_mini_system_role(
     system_prompt_role: Literal['system', 'developer'],
     openai_api_key: str,
 ) -> None:
-    model = OpenAIModel(
+    model = OpenAIModel(
+        'o1-mini', provider=OpenAIProvider(api_key=openai_api_key), system_prompt_role=system_prompt_role
+    )
     agent = Agent(model=model, system_prompt='You are a helpful assistant.')
 
     with pytest.raises(ModelHTTPError, match=r".*Unsupported value: 'messages\[0\]\.role' does not support.*"):
@@ -602,7 +605,7 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
         )
     )
     mock_client = MockOpenAI.create_mock(c)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m, result_type=list[int], model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
 
     await agent.run('Hello')
@@ -612,7 +615,7 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
 async def test_image_url_input(allow_model_requests: None):
     c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
     mock_client = MockOpenAI.create_mock(c)
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
 
     result = await agent.run(
@@ -650,7 +653,7 @@ async def test_image_url_input(allow_model_requests: None):
 async def test_image_as_binary_content_input(
     allow_model_requests: None, image_content: BinaryContent, openai_api_key: str
 ):
-    m = OpenAIModel('gpt-4o', api_key=openai_api_key)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
     agent = Agent(m)
 
     result = await agent.run(['What fruit is in the image?', image_content])
@@ -661,7 +664,7 @@ async def test_image_as_binary_content_input(
 async def test_audio_as_binary_content_input(
     allow_model_requests: None, audio_content: BinaryContent, openai_api_key: str
 ):
-    m = OpenAIModel('gpt-4o-audio-preview', api_key=openai_api_key)
+    m = OpenAIModel('gpt-4o-audio-preview', provider=OpenAIProvider(api_key=openai_api_key))
     agent = Agent(m)
 
     result = await agent.run(['Whose name is mentioned in the audio?', audio_content])
@@ -676,7 +679,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
             body={'error': 'test error'},
         )
     )
-    m = OpenAIModel('gpt-4o', openai_client=mock_client)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
     agent = Agent(m)
     with pytest.raises(ModelHTTPError) as exc_info:
         agent.run_sync('hello')

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/models/test_vertexai.py

@@ -1,3 +1,4 @@
+# pyright: reportDeprecated=false
 from __future__ import annotations as _annotations
 
 import json
@@ -22,6 +23,8 @@ with try_import() as imports_successful:
 pytestmark = [
     pytest.mark.skipif(not imports_successful(), reason='google-auth not installed'),
     pytest.mark.anyio,
+    # This ignore is added because we should just remove the `VertexAIModel` class.
+    pytest.mark.filterwarnings('ignore::DeprecationWarning'),
 ]
 
 
pydantic_ai-0.0.33/tests/providers/test_deepseek.py

@@ -0,0 +1,48 @@
+import os
+import re
+from unittest.mock import patch
+
+import httpx
+import pytest
+
+from ..conftest import try_import
+
+with try_import() as imports_successful:
+    import openai
+
+    from pydantic_ai.providers.deepseek import DeepSeekProvider
+
+
+pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')
+
+
+def test_deep_seek_provider():
+    provider = DeepSeekProvider(api_key='api-key')
+    assert provider.name == 'deepseek'
+    assert provider.base_url == 'https://api.deepseek.com'
+    assert isinstance(provider.client, openai.AsyncOpenAI)
+    assert provider.client.api_key == 'api-key'
+
+
+def test_deep_seek_provider_need_api_key() -> None:
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(
+            ValueError,
+            match=re.escape(
+                'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)`'
+                'to use the DeepSeek provider.'
+            ),
+        ):
+            DeepSeekProvider()
+
+
+def test_deep_seek_provider_pass_http_client() -> None:
+    http_client = httpx.AsyncClient()
+    provider = DeepSeekProvider(http_client=http_client, api_key='api-key')
+    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]
+
+
+def test_deep_seek_pass_openai_client() -> None:
+    openai_client = openai.AsyncOpenAI(api_key='api-key')
+    provider = DeepSeekProvider(openai_client=openai_client)
+    assert provider.client == openai_client
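The new `DeepSeekProvider` wraps an `openai.AsyncOpenAI` client pointed at https://api.deepseek.com. A hedged usage sketch: pairing it with `OpenAIModel` is an inference from the provider exposing an OpenAI client, and the model name comes from the new `deepseek:deepseek-chat` entry in test_model_names.py above; the key is a placeholder.

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.deepseek import DeepSeekProvider

# Assumption: DeepSeek is driven through OpenAIModel, since DeepSeekProvider
# exposes an openai.AsyncOpenAI client; 'deepseek-chat' matches the new
# KnownModelName entries.
model = OpenAIModel('deepseek-chat', provider=DeepSeekProvider(api_key='your-deepseek-key'))
agent = Agent(model)
```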
pydantic_ai-0.0.33/tests/providers/test_google_gla.py

@@ -0,0 +1,19 @@
+import os
+import re
+from unittest.mock import patch
+
+import pytest
+
+from pydantic_ai.providers.google_gla import GoogleGLAProvider
+
+
+def test_google_gla_provider_need_api_key() -> None:
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(
+            ValueError,
+            match=re.escape(
+                'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)`'
+                'to use the Google GLA provider.'
+            ),
+        ):
+            GoogleGLAProvider()
pydantic_ai-0.0.33/tests/providers/test_google_vertex.py

@@ -0,0 +1,110 @@
+from __future__ import annotations as _annotations
+
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from unittest.mock import patch
+
+import httpx
+import pytest
+from inline_snapshot import snapshot
+
+from ..conftest import try_import
+
+with try_import() as imports_successful:
+    from google.auth.transport.requests import Request
+
+    from pydantic_ai.providers.google_vertex import GoogleVertexProvider
+
+pytestmark = [
+    pytest.mark.skipif(not imports_successful(), reason='google-genai not installed'),
+    pytest.mark.anyio(),
+]
+
+
+@pytest.fixture()
+def http_client():
+    async def handler(request: httpx.Request):
+        if (
+            request.url.path
+            == '/v1/projects/my-project-id/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent'
+        ):
+            return httpx.Response(200, json={'content': 'success'})
+        raise NotImplementedError(f'Unexpected request: {request.url!r}')  # pragma: no cover
+
+    return httpx.AsyncClient(transport=httpx.MockTransport(handler=handler))
+
+
+def test_google_vertex_provider(allow_model_requests: None) -> None:
+    provider = GoogleVertexProvider()
+    assert provider.name == 'google-vertex'
+    assert provider.base_url == snapshot(
+        'https://us-central1-aiplatform.googleapis.com/v1/projects/None/locations/us-central1/publishers/google/models/'
+    )
+    assert isinstance(provider.client, httpx.AsyncClient)
+
+
+@dataclass
+class NoOpCredentials:
+    token = 'my-token'
+
+    def refresh(self, request: Request): ...
+
+
+@patch('pydantic_ai.providers.google_vertex.google.auth.default', return_value=(NoOpCredentials(), 'my-project-id'))
+async def test_google_vertex_provider_auth(allow_model_requests: None, http_client: httpx.AsyncClient):
+    provider = GoogleVertexProvider(http_client=http_client)
+    await provider.client.post('/gemini-1.0-pro:generateContent')
+    assert provider.region == 'us-central1'
+    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
+    assert getattr(provider.client.auth, 'token_created') is not None
+
+
+async def test_google_vertex_provider_service_account_file(
+    monkeypatch: pytest.MonkeyPatch, tmp_path: Path, allow_model_requests: None
+):
+    service_account_path = tmp_path / 'service_account.json'
+    save_service_account(service_account_path, 'my-project-id')
+
+    provider = GoogleVertexProvider(service_account_file=service_account_path)
+    monkeypatch.setattr(provider.client.auth, '_refresh_token', lambda: 'my-token')
+    await provider.client.post('/gemini-1.0-pro:generateContent')
+    assert provider.region == 'us-central1'
+    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
+    assert getattr(provider.client.auth, 'token_created') is not None
+
+
+def save_service_account(service_account_path: Path, project_id: str) -> None:
+    service_account = {
+        'type': 'service_account',
+        'project_id': project_id,
+        'private_key_id': 'abc',
+        # this is just a random private key I created with `openssl genpke ...`, it doesn't do anything
+        'private_key': (
+            '-----BEGIN PRIVATE KEY-----\n'
+            'MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMFrZYX4gZ20qv88\n'
+            'jD0QCswXgcxgP7Ta06G47QEFprDVcv4WMUBDJVAKofzVcYyhsasWsOSxcpA8LIi9\n'
+            '/VS2Otf8CmIK6nPBCD17Qgt8/IQYXOS4U2EBh0yjo0HQ4vFpkqium4lLWxrAZohA\n'
+            '8r82clV08iLRUW3J+xvN23iPHyVDAgMBAAECgYBScRJe3iNxMvbHv+kOhe30O/jJ\n'
+            'QiUlUzhtcEMk8mGwceqHvrHTcEtRKJcPC3NQvALcp9lSQQhRzjQ1PLXkC6BcfKFd\n'
+            '03q5tVPmJiqsHbSyUyHWzdlHP42xWpl/RmX/DfRKGhPOvufZpSTzkmKWtN+7osHu\n'
+            '7eiMpg2EDswCvOgf0QJBAPXLYwHbZLaM2KEMDgJSse5ZTE/0VMf+5vSTGUmHkr9c\n'
+            'Wx2G1i258kc/JgsXInPbq4BnK9hd0Xj2T5cmEmQtm4UCQQDJc02DFnPnjPnnDUwg\n'
+            'BPhrCyW+rnBGUVjehveu4XgbGx7l3wsbORTaKdCX3HIKUupgfFwFcDlMUzUy6fPO\n'
+            'IuQnAkA8FhVE/fIX4kSO0hiWnsqafr/2B7+2CG1DOraC0B6ioxwvEqhHE17T5e8R\n'
+            '5PzqH7hEMnR4dy7fCC+avpbeYHvVAkA5W58iR+5Qa49r/hlCtKeWsuHYXQqSuu62\n'
+            'zW8QWBo+fYZapRsgcSxCwc0msBm4XstlFYON+NoXpUlsabiFZOHZAkEA8Ffq3xoU\n'
+            'y0eYGy3MEzxx96F+tkl59lfkwHKWchWZJ95vAKWJaHx9WFxSWiJofbRna8Iim6pY\n'
+            'BootYWyTCfjjwA==\n'
+            '-----END PRIVATE KEY-----\n'
+        ),
+        'client_email': 'testing-pydantic-ai@pydantic-ai.iam.gserviceaccount.com',
+        'client_id': '123',
+        'auth_uri': 'https://accounts.google.com/o/oauth2/auth',
+        'token_uri': 'https://oauth2.googleapis.com/token',
+        'auth_provider_x509_cert_url': 'https://www.googleapis.com/oauth2/v1/certs',
+        'client_x509_cert_url': 'https://www.googleapis.com/...',
+        'universe_domain': 'googleapis.com',
+    }
+
+    service_account_path.write_text(json.dumps(service_account, indent=2))
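Tying the new provider to a model: the `vertexai` fixture in tests/test_live.py below constructs a `GeminiModel` with a `GoogleVertexProvider` built from a service-account file. A sketch of that pattern (the path is a placeholder, and passing it as a plain string rather than a `pathlib.Path` is an assumption, since the tests only pass a `Path`):

```python
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_vertex import GoogleVertexProvider

# Placeholder path; the provider authenticates with these service-account
# credentials and builds the Vertex AI base URL from the project/region.
model = GeminiModel(
    'gemini-1.5-flash',
    provider=GoogleVertexProvider(service_account_file='service_account.json'),
)
```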
pydantic_ai-0.0.33/tests/providers/test_provider_names.py

@@ -0,0 +1,44 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+
+from pydantic_ai.providers import Provider, infer_provider
+
+from ..conftest import try_import
+
+with try_import() as imports_successful:
+    from pydantic_ai.providers.deepseek import DeepSeekProvider
+    from pydantic_ai.providers.google_gla import GoogleGLAProvider
+    from pydantic_ai.providers.google_vertex import GoogleVertexProvider
+    from pydantic_ai.providers.openai import OpenAIProvider
+
+    test_infer_provider_params = [
+        ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'),
+        ('openai', OpenAIProvider, None),
+        ('google-vertex', GoogleVertexProvider, None),
+        ('google-gla', GoogleGLAProvider, 'GEMINI_API_KEY'),
+    ]
+
+if not imports_successful():
+    test_infer_provider_params = []
+
+pytestmark = pytest.mark.skipif(not imports_successful(), reason='need to install all extra packages')
+
+
+@pytest.fixture(scope='module', autouse=True)
+def empty_env():
+    with patch.dict(os.environ, {}, clear=True):
+        yield
+
+
+@pytest.mark.parametrize(('provider', 'provider_cls', 'exception_has'), test_infer_provider_params)
+def test_infer_provider(provider: str, provider_cls: type[Provider[Any]], exception_has: str | None):
+    if exception_has is not None:
+        with pytest.raises(ValueError, match=rf'.*{exception_has}.*'):
+            infer_provider(provider)
+    else:
+        assert isinstance(infer_provider(provider), provider_cls)
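`infer_provider` resolves a provider-name string to a provider instance, pulling credentials from the environment; for 'google-gla' it raises `ValueError` when `GEMINI_API_KEY` is unset, which is exactly what the parametrized test asserts. A small sketch (the key value is a dummy):

```python
import os

from pydantic_ai.providers import infer_provider
from pydantic_ai.providers.google_gla import GoogleGLAProvider

os.environ.setdefault('GEMINI_API_KEY', 'dummy-key')  # dummy value, for illustration only
provider = infer_provider('google-gla')
assert isinstance(provider, GoogleGLAProvider)
```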

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_examples.py

@@ -36,9 +36,9 @@ from pydantic_ai.models.test import TestModel
 from .conftest import ClientWithHandler, TestEnv
 
 try:
-    from pydantic_ai.
+    from pydantic_ai.providers.google_vertex import GoogleVertexProvider
 except ImportError:
-
+    GoogleVertexProvider = None
 
 
 try:
@@ -47,7 +47,9 @@ except ImportError:
     logfire = None
 
 
-pytestmark = pytest.mark.skipif(
+pytestmark = pytest.mark.skipif(
+    GoogleVertexProvider is None or logfire is None, reason='google-auth or logfire not installed'
+)
 
 
 def find_filter_examples() -> Iterable[CodeExample]:

{pydantic_ai-0.0.32 → pydantic_ai-0.0.33}/tests/test_live.py

@@ -23,23 +23,29 @@ pytestmark = [
 
 def openai(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
     from pydantic_ai.models.openai import OpenAIModel
+    from pydantic_ai.providers.openai import OpenAIProvider
 
-    return OpenAIModel('gpt-4o-mini', http_client=http_client)
+    return OpenAIModel('gpt-4o-mini', provider=OpenAIProvider(http_client=http_client))
 
 
 def gemini(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
     from pydantic_ai.models.gemini import GeminiModel
+    from pydantic_ai.providers.google_gla import GoogleGLAProvider
 
-    return GeminiModel('gemini-1.5-pro', http_client=http_client)
+    return GeminiModel('gemini-1.5-pro', provider=GoogleGLAProvider(http_client=http_client))
 
 
 def vertexai(http_client: httpx.AsyncClient, tmp_path: Path) -> Model:
-    from pydantic_ai.models.
+    from pydantic_ai.models.gemini import GeminiModel
+    from pydantic_ai.providers.google_vertex import GoogleVertexProvider
 
     service_account_content = os.environ['GOOGLE_SERVICE_ACCOUNT_CONTENT']
     service_account_path = tmp_path / 'service_account.json'
     service_account_path.write_text(service_account_content)
-    return
+    return GeminiModel(
+        'gemini-1.5-flash',
+        provider=GoogleVertexProvider(service_account_file=service_account_path, http_client=http_client),
+    )
 
 
 def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
@@ -56,8 +62,11 @@ def anthropic(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
 
 def ollama(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
     from pydantic_ai.models.openai import OpenAIModel
+    from pydantic_ai.providers.openai import OpenAIProvider
 
-    return OpenAIModel(
+    return OpenAIModel(
+        'qwen2:0.5b', provider=OpenAIProvider(base_url='http://localhost:11434/v1/', http_client=http_client)
+    )
 
 
 def mistral(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
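The updated `ollama` fixture shows the 0.0.33 pattern for OpenAI-compatible local servers. A standalone sketch following the fixture (model name and port match the fixture's defaults; adjust for your local Ollama setup, and note this performs a real request):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

# Ollama serves an OpenAI-compatible API at /v1/ on port 11434 by default.
model = OpenAIModel(
    'qwen2:0.5b',
    provider=OpenAIProvider(base_url='http://localhost:11434/v1/'),
)
agent = Agent(model)
result = agent.run_sync('Say hello')  # requires a running Ollama instance
print(result.data)
```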