pydantic-ai 0.0.22__tar.gz → 0.0.24__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/PKG-INFO +3 -3
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/pyproject.toml +3 -3
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_anthropic.py +9 -8
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_cohere.py +2 -1
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_gemini.py +133 -30
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_groq.py +9 -8
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_mistral.py +19 -20
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model.py +31 -15
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model_function.py +3 -3
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model_names.py +1 -1
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_openai.py +15 -9
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_vertexai.py +15 -37
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_agent.py +18 -16
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_examples.py +3 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_logfire.py +6 -8
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/.gitignore +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/LICENSE +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/Makefile +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/README.md +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/__init__.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/conftest.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/README.md +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/bank_database.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/fake_database.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/weather_service.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/__init__.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_graph.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_history.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_mermaid.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_state.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/import_examples.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/__init__.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/mock_async_stream.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model_test.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_deps.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_format_as_xml.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_live.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_parts_manager.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_streaming.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_tools.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_usage_limits.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_utils.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/typed_agent.py +0 -0
- {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/typed_graph.py +0 -0
--- pydantic_ai-0.0.22/PKG-INFO
+++ pydantic_ai-0.0.24/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai
-Version: 0.0.22
+Version: 0.0.24
 Summary: Agent Framework / shim to use Pydantic with LLMs
 Project-URL: Homepage, https://ai.pydantic.dev
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -32,9 +32,9 @@ Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: pydantic-ai-slim[anthropic,cohere,
+Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.24
 Provides-Extra: examples
-Requires-Dist: pydantic-ai-examples==0.0.
+Requires-Dist: pydantic-ai-examples==0.0.24; extra == 'examples'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Description-Content-Type: text/markdown
--- pydantic_ai-0.0.22/pyproject.toml
+++ pydantic_ai-0.0.24/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai"
-version = "0.0.22"
+version = "0.0.24"
 description = "Agent Framework / shim to use Pydantic with LLMs"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -37,7 +37,7 @@ classifiers = [
 ]
 requires-python = ">=3.9"
 
-dependencies = ["pydantic-ai-slim[
+dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.24"]
 
 [project.urls]
 Homepage = "https://ai.pydantic.dev"
@@ -46,7 +46,7 @@ Documentation = "https://ai.pydantic.dev"
 Changelog = "https://github.com/pydantic/pydantic-ai/releases"
 
 [project.optional-dependencies]
-examples = ["pydantic-ai-examples==0.0.
+examples = ["pydantic-ai-examples==0.0.24"]
 logfire = ["logfire>=2.3"]
 
 [tool.uv.sources]
--- pydantic_ai-0.0.22/tests/models/test_anthropic.py
+++ pydantic_ai-0.0.24/tests/models/test_anthropic.py
@@ -60,7 +60,8 @@ T = TypeVar('T')
 def test_init():
     m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
     assert m.client.api_key == 'foobar'
-    assert m.
+    assert m.model_name == 'claude-3-5-haiku-latest'
+    assert m.system == 'anthropic'
 
 
 @dataclass
@@ -111,7 +112,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
     return AnthropicMessage(
         id='123',
         content=content,
-        model='claude-3-5-haiku-
+        model='claude-3-5-haiku-123',
         role='assistant',
         stop_reason='end_turn',
         type='message',
@@ -140,13 +141,13 @@ async def test_sync_request_text_response(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='claude-3-5-haiku-
+                model_name='claude-3-5-haiku-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='claude-3-5-haiku-
+                model_name='claude-3-5-haiku-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
         ]
@@ -189,7 +190,7 @@ async def test_request_structured_response(allow_model_requests: None):
                     tool_call_id='123',
                 )
             ],
-            model_name='claude-3-5-haiku-
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(
@@ -251,7 +252,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='1',
                 )
             ],
-            model_name='claude-3-5-haiku-
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(
@@ -272,7 +273,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='2',
                 )
             ],
-            model_name='claude-3-5-haiku-
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(
@@ -287,7 +288,7 @@ async def test_request_tool_call(allow_model_requests: None):
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='claude-3-5-haiku-
+            model_name='claude-3-5-haiku-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
     ]
--- pydantic_ai-0.0.22/tests/models/test_gemini.py
+++ pydantic_ai-0.0.24/tests/models/test_gemini.py
@@ -24,9 +24,11 @@ from pydantic_ai.messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from pydantic_ai.models import ModelRequestParameters
 from pydantic_ai.models.gemini import (
     ApiKeyAuth,
     GeminiModel,
+    GeminiModelSettings,
     _content_model_response,
     _function_call_part_from_call,
     _gemini_response_ta,
@@ -36,6 +38,7 @@ from pydantic_ai.models.gemini import (
     _GeminiFunction,
     _GeminiFunctionCallingConfig,
     _GeminiResponse,
+    _GeminiSafetyRating,
     _GeminiTextPart,
     _GeminiToolConfig,
     _GeminiTools,
@@ -75,18 +78,21 @@ def test_api_key_empty(env: TestEnv):
         GeminiModel('gemini-1.5-flash')
 
 
-async def
+async def test_model_simple(allow_model_requests: None):
     m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
-
-    assert
-    assert
-    assert
-    assert agent_model.auth.api_key == 'via-arg'
-    assert agent_model.tools is None
-    assert agent_model.tool_config is None
+    assert isinstance(m.http_client, httpx.AsyncClient)
+    assert m.model_name == 'gemini-1.5-flash'
+    assert isinstance(m.auth, ApiKeyAuth)
+    assert m.auth.api_key == 'via-arg'
 
+    arc = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools is None
+    assert tool_config is None
 
-
+
+async def test_model_tools(allow_model_requests: None):
     m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
     tools = [
         ToolDefinition(
@@ -110,8 +116,11 @@ async def test_agent_model_tools(allow_model_requests: None):
         'This is the tool for the final Result',
         {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
     )
-
-
+
+    arc = ModelRequestParameters(function_tools=tools, allow_text_result=True, result_tools=[result_tool])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -139,7 +148,7 @@ async def test_agent_model_tools(allow_model_requests: None):
             ]
         )
     )
-    assert
+    assert tool_config is None
 
 
 async def test_require_response_tool(allow_model_requests: None):
@@ -149,8 +158,10 @@ async def test_require_response_tool(allow_model_requests: None):
         'This is the tool for the final Result',
         {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
     )
-
-
+    arc = ModelRequestParameters(function_tools=[], allow_text_result=False, result_tools=[result_tool])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -164,7 +175,7 @@ async def test_require_response_tool(allow_model_requests: None):
             ]
         )
     )
-    assert
+    assert tool_config == snapshot(
         _GeminiToolConfig(
             function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=['result'])
         )
@@ -206,8 +217,9 @@ async def test_json_def_replaced(allow_model_requests: None):
         'This is the tool for the final Result',
         json_schema,
     )
-
-
+    assert m._get_tools(
+        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -252,8 +264,9 @@ async def test_json_def_replaced_any_of(allow_model_requests: None):
         'This is the tool for the final Result',
         json_schema,
     )
-
-
+    assert m._get_tools(
+        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -315,7 +328,7 @@ async def test_json_def_recursive(allow_model_requests: None):
         json_schema,
     )
     with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
-
+        m._get_tools(ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool]))
 
 
 async def test_json_def_date(allow_model_requests: None):
@@ -346,8 +359,9 @@ async def test_json_def_date(allow_model_requests: None):
         'This is the tool for the final Result',
         json_schema,
     )
-
-
+    assert m._get_tools(
+        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -426,7 +440,7 @@ def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | No
     candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
     if finish_reason:  # pragma: no cover
         candidate['finish_reason'] = finish_reason
-    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
+    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage(), model_version='gemini-1.5-flash-123')
 
 
 def example_usage() -> _GeminiUsageMetaData:
@@ -445,7 +459,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
         [
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
-                parts=[TextPart(content='Hello world')],
+                parts=[TextPart(content='Hello world')],
+                model_name='gemini-1.5-flash-123',
+                timestamp=IsNow(tz=timezone.utc),
             ),
         ]
     )
@@ -458,13 +474,13 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='Hello world')],
-                model_name='gemini-1.5-flash',
+                model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='Hello world')],
-                model_name='gemini-1.5-flash',
+                model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
             ),
         ]
@@ -491,7 +507,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
                     args={'response': [1, 2, 123]},
                 )
             ],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(
@@ -552,7 +568,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
                     args={'loc_name': 'San Fransisco'},
                 )
             ],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(
@@ -575,7 +591,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
                     args={'loc_name': 'New York'},
                 ),
             ],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(
@@ -590,7 +606,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='gemini-1.5-flash',
+            model_name='gemini-1.5-flash-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
     ]
@@ -853,3 +869,90 @@ async def test_model_settings(client_with_handler: ClientWithHandler, env: TestE
         },
     )
     assert result.data == 'world'
+
+
+def gemini_no_content_response(
+    safety_ratings: list[_GeminiSafetyRating], finish_reason: Literal['SAFETY'] | None = 'SAFETY'
+) -> _GeminiResponse:
+    candidate = _GeminiCandidates(safety_ratings=safety_ratings)
+    if finish_reason:
+        candidate['finish_reason'] = finish_reason
+    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
+
+
+async def test_safety_settings_unsafe(
+    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
+) -> None:
+    try:
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            safety_settings = json.loads(request.content)['safety_settings']
+            assert safety_settings == [
+                {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+            ]
+
+            return httpx.Response(
+                200,
+                content=_gemini_response_ta.dump_json(
+                    gemini_no_content_response(
+                        finish_reason='SAFETY',
+                        safety_ratings=[
+                            {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'MEDIUM', 'blocked': True}
+                        ],
+                    ),
+                    by_alias=True,
+                ),
+                headers={'Content-Type': 'application/json'},
+            )
+
+        gemini_client = client_with_handler(handler)
+        m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+        agent = Agent(m)
+
+        await agent.run(
+            'a request for something rude',
+            model_settings=GeminiModelSettings(
+                gemini_safety_settings=[
+                    {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                    {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                ]
+            ),
+        )
+    except UnexpectedModelBehavior as e:
+        assert repr(e) == "UnexpectedModelBehavior('Safety settings triggered')"
+
+
+async def test_safety_settings_safe(
+    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
+) -> None:
+    def handler(request: httpx.Request) -> httpx.Response:
+        safety_settings = json.loads(request.content)['safety_settings']
+        assert safety_settings == [
+            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+        ]
+
+        return httpx.Response(
+            200,
+            content=_gemini_response_ta.dump_json(
+                gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
+                by_alias=True,
+            ),
+            headers={'Content-Type': 'application/json'},
+        )
+
+    gemini_client = client_with_handler(handler)
+    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+    agent = Agent(m)
+
+    result = await agent.run(
+        'hello',
+        model_settings=GeminiModelSettings(
+            gemini_safety_settings=[
+                {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+            ]
+        ),
+    )
+    assert result.data == 'world'
--- pydantic_ai-0.0.22/tests/models/test_groq.py
+++ pydantic_ai-0.0.24/tests/models/test_groq.py
@@ -52,7 +52,8 @@ pytestmark = [
 def test_init():
     m = GroqModel('llama-3.3-70b-versatile', api_key='foobar')
     assert m.client.api_key == 'foobar'
-    assert m.
+    assert m.model_name == 'llama-3.3-70b-versatile'
+    assert m.system == 'groq'
 
 
 @dataclass
@@ -102,7 +103,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
         id='123',
         choices=[Choice(finish_reason='stop', index=0, message=message)],
         created=1704067200,  # 2024-01-01
-        model='llama-3.3-70b-versatile',
+        model='llama-3.3-70b-versatile-123',
         object='chat.completion',
         usage=usage,
     )
@@ -129,13 +130,13 @@ async def test_request_simple_success(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='llama-3.3-70b-versatile',
+                model_name='llama-3.3-70b-versatile-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='llama-3.3-70b-versatile',
+                model_name='llama-3.3-70b-versatile-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
         ]
@@ -186,7 +187,7 @@ async def test_request_structured_response(allow_model_requests: None):
                     tool_call_id='123',
                 )
             ],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -272,7 +273,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='1',
                 )
             ],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -293,7 +294,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='2',
                 )
             ],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -308,7 +309,7 @@ async def test_request_tool_call(allow_model_requests: None):
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='llama-3.3-70b-versatile',
+            model_name='llama-3.3-70b-versatile-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]
--- pydantic_ai-0.0.22/tests/models/test_mistral.py
+++ pydantic_ai-0.0.24/tests/models/test_mistral.py
@@ -48,7 +48,6 @@ with try_import() as imports_successful:
     from mistralai.types.basemodel import Unset as MistralUnset
 
     from pydantic_ai.models.mistral import (
-        MistralAgentModel,
         MistralModel,
         MistralStreamedResponse,
     )
@@ -124,7 +123,7 @@ def completion_message(
         id='123',
         choices=[MistralChatCompletionChoice(finish_reason='stop', index=0, message=message)],
         created=1704067200 if with_created else None,  # 2024-01-01
-        model='mistral-large-
+        model='mistral-large-123',
         object='chat.completion',
         usage=usage or MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
     )
@@ -218,13 +217,13 @@ async def test_multiple_completions(allow_model_requests: None):
         ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='world')],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=IsNow(tz=timezone.utc),
         ),
         ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='hello again')],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]
@@ -270,19 +269,19 @@ async def test_three_completions(allow_model_requests: None):
         ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='world')],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='hello again')],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='final message')],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]
@@ -397,7 +396,7 @@ async def test_request_model_structured_with_arguments_dict_response(allow_model
                     tool_call_id='123',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -459,7 +458,7 @@ async def test_request_model_structured_with_arguments_str_response(allow_model_
                     tool_call_id='123',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -520,7 +519,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
                     tool_call_id='123',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -1105,7 +1104,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='1',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -1126,7 +1125,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='2',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -1141,7 +1140,7 @@ async def test_request_tool_call(allow_model_requests: None):
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]
@@ -1245,7 +1244,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
                     tool_call_id='1',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -1266,7 +1265,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
                     tool_call_id='2',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -1287,7 +1286,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
                     tool_call_id='1',
                 )
             ],
-            model_name='mistral-large-
+            model_name='mistral-large-123',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -1668,8 +1667,8 @@ def test_generate_user_output_format_complex():
             'prop_unrecognized_type': {'type': 'customSomething'},
         }
     }
-
-    result =
+    m = MistralModel('', json_mode_schema_prompt='{schema}')
+    result = m._generate_user_output_format([schema])  # pyright: ignore[reportPrivateUsage]
     assert result.content == (
         "{'prop_anyOf': 'Optional[str]', "
         "'prop_no_type': 'Any', "
@@ -1685,8 +1684,8 @@ def test_generate_user_output_format_complex():
 
 def test_generate_user_output_format_multiple():
     schema = {'properties': {'prop_anyOf': {'anyOf': [{'type': 'string'}, {'type': 'integer'}]}}}
-
-    result =
+    m = MistralModel('', json_mode_schema_prompt='{schema}')
+    result = m._generate_user_output_format([schema, schema])  # pyright: ignore[reportPrivateUsage]
     assert result.content == "[{'prop_anyOf': 'Optional[str]'}, {'prop_anyOf': 'Optional[str]'}]"
 
 
--- pydantic_ai-0.0.22/tests/models/test_model.py
+++ pydantic_ai-0.0.24/tests/models/test_model.py
@@ -8,66 +8,81 @@ from pydantic_ai.models import infer_model
 from ..conftest import TestEnv
 
 TEST_CASES = [
-    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', '
-    ('OPENAI_API_KEY', 'gpt-3.5-turbo', '
-    ('OPENAI_API_KEY', 'o1', '
-    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', '
-    ('GEMINI_API_KEY', 'gemini-1.5-flash', '
+    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
+    ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
+    ('OPENAI_API_KEY', 'o1', 'o1', 'openai', 'openai', 'OpenAIModel'),
+    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
+    ('GEMINI_API_KEY', 'gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
     (
         'GEMINI_API_KEY',
         'google-vertex:gemini-1.5-flash',
-        '
+        'gemini-1.5-flash',
+        'google-vertex',
         'vertexai',
         'VertexAIModel',
     ),
     (
         'GEMINI_API_KEY',
         'vertexai:gemini-1.5-flash',
-        '
+        'gemini-1.5-flash',
+        'google-vertex',
         'vertexai',
         'VertexAIModel',
     ),
     (
         'ANTHROPIC_API_KEY',
         'anthropic:claude-3-5-haiku-latest',
-        '
+        'claude-3-5-haiku-latest',
+        'anthropic',
         'anthropic',
         'AnthropicModel',
     ),
     (
         'ANTHROPIC_API_KEY',
         'claude-3-5-haiku-latest',
-        '
+        'claude-3-5-haiku-latest',
+        'anthropic',
         'anthropic',
         'AnthropicModel',
     ),
     (
         'GROQ_API_KEY',
         'groq:llama-3.3-70b-versatile',
-        '
+        'llama-3.3-70b-versatile',
+        'groq',
         'groq',
         'GroqModel',
     ),
     (
         'MISTRAL_API_KEY',
         'mistral:mistral-small-latest',
-        'mistral
+        'mistral-small-latest',
+        'mistral',
         'mistral',
         'MistralModel',
     ),
     (
         'CO_API_KEY',
         'cohere:command',
-        '
+        'command',
+        'cohere',
         'cohere',
         'CohereModel',
     ),
 ]
 
 
-@pytest.mark.parametrize(
+@pytest.mark.parametrize(
+    'mock_api_key, model_name, expected_model_name, expected_system, module_name, model_class_name', TEST_CASES
+)
 def test_infer_model(
-    env: TestEnv,
+    env: TestEnv,
+    mock_api_key: str,
+    model_name: str,
+    expected_model_name: str,
+    expected_system: str,
+    module_name: str,
+    model_class_name: str,
 ):
     try:
         model_module = import_module(f'pydantic_ai.models.{module_name}')
@@ -79,7 +94,8 @@ def test_infer_model(
 
     m = infer_model(model_name)  # pyright: ignore[reportArgumentType]
     assert isinstance(m, expected_model)
-    assert m.
+    assert m.model_name == expected_model_name
+    assert m.system == expected_system
 
     m2 = infer_model(m)
     assert m2 is m
--- pydantic_ai-0.0.22/tests/models/test_model_function.py
+++ pydantic_ai-0.0.24/tests/models/test_model_function.py
@@ -41,13 +41,13 @@ async def stream_hello(_messages: list[ModelMessage], _agent_info: AgentInfo) ->
 
 def test_init() -> None:
     m = FunctionModel(function=hello)
-    assert m.
+    assert m.model_name == 'function:hello:'
 
     m1 = FunctionModel(stream_function=stream_hello)
-    assert m1.
+    assert m1.model_name == 'function::stream_hello'
 
     m2 = FunctionModel(function=hello, stream_function=stream_hello)
-    assert m2.
+    assert m2.model_name == 'function:hello:stream_hello'
 
 
 async def return_last(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:
--- pydantic_ai-0.0.22/tests/models/test_model_names.py
+++ pydantic_ai-0.0.24/tests/models/test_model_names.py
@@ -38,7 +38,7 @@ def test_known_model_names():
     groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
     mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
     openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
-        n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt')
+        n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
     ]
     extra_names = ['test']
 
--- pydantic_ai-0.0.22/tests/models/test_openai.py
+++ pydantic_ai-0.0.24/tests/models/test_openai.py
@@ -25,7 +25,7 @@ from pydantic_ai.messages import (
 from pydantic_ai.result import Usage
 from pydantic_ai.settings import ModelSettings
 
-from ..conftest import IsNow, try_import
+from ..conftest import IsNow, TestEnv, try_import
 from .mock_async_stream import MockAsyncStream
 
 with try_import() as imports_successful:
@@ -65,12 +65,18 @@ def test_init_with_base_url():
     m.name()
 
 
+def test_init_with_no_api_key_will_still_setup_client():
+    m = OpenAIModel('llama3.2', base_url='http://localhost:19434/v1')
+    assert str(m.client.base_url) == 'http://localhost:19434/v1/'
+
+
 def test_init_with_non_openai_model():
     m = OpenAIModel('llama3.2-vision:latest', base_url='https://example.com/v1/')
     m.name()
 
 
-def test_init_of_openai_without_api_key_raises_error():
+def test_init_of_openai_without_api_key_raises_error(env: TestEnv):
+    env.remove('OPENAI_API_KEY')
     with pytest.raises(
         OpenAIError,
         match='^The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable$',
@@ -135,7 +141,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
         id='123',
         choices=[Choice(finish_reason='stop', index=0, message=message)],
         created=1704067200,  # 2024-01-01
-        model='gpt-4o',
+        model='gpt-4o-123',
         object='chat.completion',
         usage=usage,
     )
@@ -162,13 +168,13 @@ async def test_request_simple_success(allow_model_requests: None):
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='gpt-4o',
+                model_name='gpt-4o-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='world')],
-                model_name='gpt-4o',
+                model_name='gpt-4o-123',
                 timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
             ),
         ]
@@ -232,7 +238,7 @@ async def test_request_structured_response(allow_model_requests: None):
                     tool_call_id='123',
                 )
             ],
-            model_name='gpt-4o',
+            model_name='gpt-4o-123',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -320,7 +326,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='1',
                 )
             ],
-            model_name='gpt-4o',
+            model_name='gpt-4o-123',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -341,7 +347,7 @@ async def test_request_tool_call(allow_model_requests: None):
                     tool_call_id='2',
                 )
             ],
-            model_name='gpt-4o',
+            model_name='gpt-4o-123',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -356,7 +362,7 @@ async def test_request_tool_call(allow_model_requests: None):
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='gpt-4o',
+            model_name='gpt-4o-123',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
     ]
--- pydantic_ai-0.0.22/tests/models/test_vertexai.py
+++ pydantic_ai-0.0.24/tests/models/test_vertexai.py
@@ -30,17 +30,18 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
     save_service_account(service_account_path, 'my-project-id')
 
     model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path)
-    assert model.
-    assert model.
+    assert model._url is None
+    assert model._auth is None
 
-    await model.
+    await model.ainit()
 
     assert model.url == snapshot(
         'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
         'publishers/google/models/gemini-1.5-flash:'
     )
     assert model.auth is not None
-    assert model.
+    assert model.model_name == snapshot('gemini-1.5-flash')
+    assert model.system == snapshot('google-vertex')
 
 
 class NoOpCredentials:
@@ -53,12 +54,12 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
         return_value=(NoOpCredentials(), 'my-project-id'),
     )
     model = VertexAIModel('gemini-1.5-flash')
-    assert model.
-    assert model.
+    assert model._url is None
+    assert model._auth is None
 
     assert patch.call_count == 0
 
-    await model.
+    await model.ainit()
 
     assert patch.call_count == 1
 
@@ -67,9 +68,10 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
         'publishers/google/models/gemini-1.5-flash:'
     )
     assert model.auth is not None
-    assert model.
+    assert model.model_name == snapshot('gemini-1.5-flash')
+    assert model.system == snapshot('google-vertex')
 
-    await model.
+    await model.ainit()
     assert model.url is not None
     assert model.auth is not None
     assert patch.call_count == 1
@@ -80,10 +82,10 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
     save_service_account(service_account_path, 'my-project-id')
 
     model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='my-project-id')
-    assert model.
-    assert model.
+    assert model._url is None
+    assert model._auth is None
 
-    await model.
+    await model.ainit()
 
     assert model.url == snapshot(
         'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -92,30 +94,6 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
     assert model.auth is not None
 
 
-async def test_init_service_account_wrong_project_id(tmp_path: Path, allow_model_requests: None):
-    service_account_path = tmp_path / 'service_account.json'
-    save_service_account(service_account_path, 'my-project-id')
-
-    model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='different')
-
-    with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
-    assert str(exc_info.value) == snapshot(
-        "The project_id you provided does not match the one from service account file: 'different' != 'my-project-id'"
-    )
-
-
-async def test_init_env_wrong_project_id(mocker: MockerFixture, allow_model_requests: None):
-    mocker.patch('pydantic_ai.models.vertexai.google.auth.default', return_value=(NoOpCredentials(), 'my-project-id'))
-    model = VertexAIModel('gemini-1.5-flash', project_id='different')
-
-    with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
-    assert str(exc_info.value) == snapshot(
-        "The project_id you provided does not match the one from `google.auth.default()`: 'different' != 'my-project-id'"
-    )
-
-
 async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_requests: None):
     mocker.patch(
         'pydantic_ai.models.vertexai.google.auth.default',
@@ -124,7 +102,7 @@ async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_request
     model = VertexAIModel('gemini-1.5-flash')
 
     with pytest.raises(UserError) as exc_info:
-        await model.
+        await model.ainit()
     assert str(exc_info.value) == snapshot('No project_id provided and none found in `google.auth.default()`')
 
 
--- pydantic_ai-0.0.22/tests/test_agent.py
+++ pydantic_ai-0.0.24/tests/test_agent.py
@@ -326,13 +326,13 @@ def test_response_tuple():
     result = agent.run_sync('Hello')
     assert result.data == snapshot(('a', 'a'))
 
-    assert m.
-    assert m.
+    assert m.last_model_request_parameters is not None
+    assert m.last_model_request_parameters.function_tools == snapshot([])
+    assert m.last_model_request_parameters.allow_text_result is False
 
-    assert m.
-    assert len(m.
-
-    assert m.agent_model_result_tools == snapshot(
+    assert m.last_model_request_parameters.result_tools is not None
+    assert len(m.last_model_request_parameters.result_tools) == 1
+    assert m.last_model_request_parameters.result_tools == snapshot(
         [
             ToolDefinition(
                 name='final_result',
@@ -384,13 +384,14 @@ def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
     assert result.data == snapshot('success (no tool calls)')
     assert got_tool_call_name == snapshot(None)
 
-    assert m.
-    assert m.
+    assert m.last_model_request_parameters is not None
+    assert m.last_model_request_parameters.function_tools == snapshot([])
+    assert m.last_model_request_parameters.allow_text_result is True
 
-    assert m.
-    assert len(m.
+    assert m.last_model_request_parameters.result_tools is not None
+    assert len(m.last_model_request_parameters.result_tools) == 1
 
-    assert m.
+    assert m.last_model_request_parameters.result_tools == snapshot(
         [
             ToolDefinition(
                 name='final_result',
@@ -459,13 +460,14 @@ class Bar(BaseModel):
     assert result.data == mod.Foo(a=0, b='a')
     assert got_tool_call_name == snapshot('final_result_Foo')
 
-    assert m.
-    assert m.
+    assert m.last_model_request_parameters is not None
+    assert m.last_model_request_parameters.function_tools == snapshot([])
+    assert m.last_model_request_parameters.allow_text_result is False
 
-    assert m.
-    assert len(m.
+    assert m.last_model_request_parameters.result_tools is not None
+    assert len(m.last_model_request_parameters.result_tools) == 2
 
-    assert m.
+    assert m.last_model_request_parameters.result_tools == snapshot(
         [
             ToolDefinition(
                 name='final_result_Foo',
--- pydantic_ai-0.0.22/tests/test_examples.py
+++ pydantic_ai-0.0.24/tests/test_examples.py
@@ -17,6 +17,7 @@ from pytest_examples import CodeExample, EvalExample, find_examples
 from pytest_mock import MockerFixture
 
 from pydantic_ai._utils import group_by_temporal
+from pydantic_ai.exceptions import UnexpectedModelBehavior
 from pydantic_ai.messages import (
     ModelMessage,
     ModelResponse,
@@ -288,6 +289,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
             )
         ]
     )
+    elif m.content.startswith('Write a list of 5 very rude things that I might say'):
+        raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
     elif m.content.startswith('<examples>\n  <user>'):
         return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args={})])
     elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2:
--- pydantic_ai-0.0.22/tests/test_logfire.py
+++ pydantic_ai-0.0.24/tests/test_logfire.py
@@ -75,14 +75,14 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
             'id': 0,
             'message': 'my_agent run prompt=Hello',
             'children': [
-                {'id': 1, 'message': 'preparing model
+                {'id': 1, 'message': 'preparing model request params run_step=1'},
                 {'id': 2, 'message': 'model request'},
                 {
                     'id': 3,
                     'message': 'handle model response -> tool-return',
                     'children': [{'id': 4, 'message': "running tools=['my_ret']"}],
                 },
-                {'id': 5, 'message': 'preparing model
+                {'id': 5, 'message': 'preparing model request params run_step=2'},
                 {'id': 6, 'message': 'model request'},
                 {'id': 7, 'message': 'handle model response -> final result'},
             ],
@@ -102,16 +102,14 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
                 'custom_result_text': None,
                 'custom_result_args': None,
                 'seed': 0,
-                '
-                'agent_model_allow_text_result': None,
-                'agent_model_result_tools': None,
+                'last_model_request_parameters': None,
             },
             'name': 'my_agent',
             'end_strategy': 'early',
             'model_settings': None,
         }
     ),
-    'model_name': 'test
+    'model_name': 'test',
     'agent_name': 'my_agent',
     'logfire.msg_template': '{agent_name} run {prompt=}',
     'logfire.msg': 'my_agent run prompt=Hello',
@@ -255,9 +253,9 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
         'code.function': 'test_logfire',
         'code.lineno': IsInt(),
         'run_step': 1,
-        'logfire.msg_template': 'preparing model
+        'logfire.msg_template': 'preparing model request params {run_step=}',
         'logfire.span_type': 'span',
-        'logfire.msg': 'preparing model
+        'logfire.msg': 'preparing model request params run_step=1',
        'logfire.json_schema': '{"type":"object","properties":{"run_step":{}}}',
     }
 )