pydantic-ai 0.0.22.tar.gz → 0.0.24.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/PKG-INFO +3 -3
  2. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/pyproject.toml +3 -3
  3. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_anthropic.py +9 -8
  4. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_cohere.py +2 -1
  5. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_gemini.py +133 -30
  6. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_groq.py +9 -8
  7. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_mistral.py +19 -20
  8. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model.py +31 -15
  9. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model_function.py +3 -3
  10. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model_names.py +1 -1
  11. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_openai.py +15 -9
  12. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_vertexai.py +15 -37
  13. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_agent.py +18 -16
  14. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_examples.py +3 -0
  15. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_logfire.py +6 -8
  16. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/.gitignore +0 -0
  17. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/LICENSE +0 -0
  18. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/Makefile +0 -0
  19. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/README.md +0 -0
  20. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/__init__.py +0 -0
  21. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/conftest.py +0 -0
  22. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/README.md +0 -0
  23. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/bank_database.py +0 -0
  24. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/fake_database.py +0 -0
  25. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/example_modules/weather_service.py +0 -0
  26. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/__init__.py +0 -0
  27. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_graph.py +0 -0
  28. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_history.py +0 -0
  29. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_mermaid.py +0 -0
  30. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/graph/test_state.py +0 -0
  31. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/import_examples.py +0 -0
  32. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/__init__.py +0 -0
  33. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/mock_async_stream.py +0 -0
  34. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/models/test_model_test.py +0 -0
  35. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_deps.py +0 -0
  36. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_format_as_xml.py +0 -0
  37. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_live.py +0 -0
  38. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_parts_manager.py +0 -0
  39. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_streaming.py +0 -0
  40. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_tools.py +0 -0
  41. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_usage_limits.py +0 -0
  42. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/test_utils.py +0 -0
  43. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/typed_agent.py +0 -0
  44. {pydantic_ai-0.0.22 → pydantic_ai-0.0.24}/tests/typed_graph.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai
- Version: 0.0.22
+ Version: 0.0.24
  Summary: Agent Framework / shim to use Pydantic with LLMs
  Project-URL: Homepage, https://ai.pydantic.dev
  Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -32,9 +32,9 @@ Classifier: Programming Language :: Python :: 3.13
  Classifier: Topic :: Internet
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
  Requires-Python: >=3.9
- Requires-Dist: pydantic-ai-slim[anthropic,cohere,graph,groq,mistral,openai,vertexai]==0.0.22
+ Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.24
  Provides-Extra: examples
- Requires-Dist: pydantic-ai-examples==0.0.22; extra == 'examples'
+ Requires-Dist: pydantic-ai-examples==0.0.24; extra == 'examples'
  Provides-Extra: logfire
  Requires-Dist: logfire>=2.3; extra == 'logfire'
  Description-Content-Type: text/markdown
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "pydantic-ai"
- version = "0.0.22"
+ version = "0.0.24"
  description = "Agent Framework / shim to use Pydantic with LLMs"
  authors = [
  { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -37,7 +37,7 @@ classifiers = [
  ]
  requires-python = ">=3.9"

- dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral,cohere]==0.0.22"]
+ dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.24"]

  [project.urls]
  Homepage = "https://ai.pydantic.dev"
@@ -46,7 +46,7 @@ Documentation = "https://ai.pydantic.dev"
  Changelog = "https://github.com/pydantic/pydantic-ai/releases"

  [project.optional-dependencies]
- examples = ["pydantic-ai-examples==0.0.22"]
+ examples = ["pydantic-ai-examples==0.0.24"]
  logfire = ["logfire>=2.3"]

  [tool.uv.sources]
@@ -60,7 +60,8 @@ T = TypeVar('T')
  def test_init():
  m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
  assert m.client.api_key == 'foobar'
- assert m.name() == 'anthropic:claude-3-5-haiku-latest'
+ assert m.model_name == 'claude-3-5-haiku-latest'
+ assert m.system == 'anthropic'


  @dataclass
@@ -111,7 +112,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
  return AnthropicMessage(
  id='123',
  content=content,
- model='claude-3-5-haiku-latest',
+ model='claude-3-5-haiku-123',
  role='assistant',
  stop_reason='end_turn',
  type='message',
@@ -140,13 +141,13 @@ async def test_sync_request_text_response(allow_model_requests: None):
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='claude-3-5-haiku-latest',
+ model_name='claude-3-5-haiku-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='claude-3-5-haiku-latest',
+ model_name='claude-3-5-haiku-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ]
@@ -189,7 +190,7 @@ async def test_request_structured_response(allow_model_requests: None):
  tool_call_id='123',
  )
  ],
- model_name='claude-3-5-haiku-latest',
+ model_name='claude-3-5-haiku-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(
@@ -251,7 +252,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='1',
  )
  ],
- model_name='claude-3-5-haiku-latest',
+ model_name='claude-3-5-haiku-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(
@@ -272,7 +273,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='2',
  )
  ],
- model_name='claude-3-5-haiku-latest',
+ model_name='claude-3-5-haiku-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(
@@ -287,7 +288,7 @@ async def test_request_tool_call(allow_model_requests: None):
  ),
  ModelResponse(
  parts=[TextPart(content='final response')],
- model_name='claude-3-5-haiku-latest',
+ model_name='claude-3-5-haiku-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ]
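The test_anthropic.py hunks above all follow from one 0.0.24 API change: the old Model.name() method (which returned prefixed strings like 'anthropic:claude-3-5-haiku-latest') is split into separate model_name and system properties, and the recorded ModelResponse.model_name now echoes whatever model id the provider response carries. A minimal sketch of the new attribute access, distilled from the test above (the api_key is a placeholder):

from pydantic_ai.models.anthropic import AnthropicModel

m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
assert m.model_name == 'claude-3-5-haiku-latest'  # previously: m.name() == 'anthropic:claude-3-5-haiku-latest'
assert m.system == 'anthropic'  # the provider/system name, split out of the old prefixed string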
@@ -44,7 +44,8 @@ pytestmark = [

  def test_init():
  m = CohereModel('command-r7b-12-2024', api_key='foobar')
- assert m.name() == 'cohere:command-r7b-12-2024'
+ assert m.model_name == 'command-r7b-12-2024'
+ assert m.system == 'cohere'


  @dataclass
@@ -24,9 +24,11 @@ from pydantic_ai.messages import (
  ToolReturnPart,
  UserPromptPart,
  )
+ from pydantic_ai.models import ModelRequestParameters
  from pydantic_ai.models.gemini import (
  ApiKeyAuth,
  GeminiModel,
+ GeminiModelSettings,
  _content_model_response,
  _function_call_part_from_call,
  _gemini_response_ta,
@@ -36,6 +38,7 @@ from pydantic_ai.models.gemini import (
  _GeminiFunction,
  _GeminiFunctionCallingConfig,
  _GeminiResponse,
+ _GeminiSafetyRating,
  _GeminiTextPart,
  _GeminiToolConfig,
  _GeminiTools,
@@ -75,18 +78,21 @@ def test_api_key_empty(env: TestEnv):
  GeminiModel('gemini-1.5-flash')


- async def test_agent_model_simple(allow_model_requests: None):
+ async def test_model_simple(allow_model_requests: None):
  m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
- agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
- assert isinstance(agent_model.http_client, httpx.AsyncClient)
- assert agent_model.model_name == 'gemini-1.5-flash'
- assert isinstance(agent_model.auth, ApiKeyAuth)
- assert agent_model.auth.api_key == 'via-arg'
- assert agent_model.tools is None
- assert agent_model.tool_config is None
+ assert isinstance(m.http_client, httpx.AsyncClient)
+ assert m.model_name == 'gemini-1.5-flash'
+ assert isinstance(m.auth, ApiKeyAuth)
+ assert m.auth.api_key == 'via-arg'

+ arc = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
+ tools = m._get_tools(arc)
+ tool_config = m._get_tool_config(arc, tools)
+ assert tools is None
+ assert tool_config is None

- async def test_agent_model_tools(allow_model_requests: None):
+
+ async def test_model_tools(allow_model_requests: None):
  m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
  tools = [
  ToolDefinition(
@@ -110,8 +116,11 @@ async def test_agent_model_tools(allow_model_requests: None):
  'This is the tool for the final Result',
  {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
  )
- agent_model = await m.agent_model(function_tools=tools, allow_text_result=True, result_tools=[result_tool])
- assert agent_model.tools == snapshot(
+
+ arc = ModelRequestParameters(function_tools=tools, allow_text_result=True, result_tools=[result_tool])
+ tools = m._get_tools(arc)
+ tool_config = m._get_tool_config(arc, tools)
+ assert tools == snapshot(
  _GeminiTools(
  function_declarations=[
  _GeminiFunction(
@@ -139,7 +148,7 @@ async def test_agent_model_tools(allow_model_requests: None):
  ]
  )
  )
- assert agent_model.tool_config is None
+ assert tool_config is None


  async def test_require_response_tool(allow_model_requests: None):
@@ -149,8 +158,10 @@ async def test_require_response_tool(allow_model_requests: None):
  'This is the tool for the final Result',
  {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
  )
- agent_model = await m.agent_model(function_tools=[], allow_text_result=False, result_tools=[result_tool])
- assert agent_model.tools == snapshot(
+ arc = ModelRequestParameters(function_tools=[], allow_text_result=False, result_tools=[result_tool])
+ tools = m._get_tools(arc)
+ tool_config = m._get_tool_config(arc, tools)
+ assert tools == snapshot(
  _GeminiTools(
  function_declarations=[
  _GeminiFunction(
@@ -164,7 +175,7 @@ async def test_require_response_tool(allow_model_requests: None):
  ]
  )
  )
- assert agent_model.tool_config == snapshot(
+ assert tool_config == snapshot(
  _GeminiToolConfig(
  function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=['result'])
  )
@@ -206,8 +217,9 @@ async def test_json_def_replaced(allow_model_requests: None):
  'This is the tool for the final Result',
  json_schema,
  )
- agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
- assert agent_model.tools == snapshot(
+ assert m._get_tools(
+ ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+ ) == snapshot(
  _GeminiTools(
  function_declarations=[
  _GeminiFunction(
@@ -252,8 +264,9 @@ async def test_json_def_replaced_any_of(allow_model_requests: None):
  'This is the tool for the final Result',
  json_schema,
  )
- agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
- assert agent_model.tools == snapshot(
+ assert m._get_tools(
+ ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+ ) == snapshot(
  _GeminiTools(
  function_declarations=[
  _GeminiFunction(
@@ -315,7 +328,7 @@ async def test_json_def_recursive(allow_model_requests: None):
  json_schema,
  )
  with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
- await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+ m._get_tools(ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool]))


  async def test_json_def_date(allow_model_requests: None):
@@ -346,8 +359,9 @@ async def test_json_def_date(allow_model_requests: None):
  'This is the tool for the final Result',
  json_schema,
  )
- agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
- assert agent_model.tools == snapshot(
+ assert m._get_tools(
+ ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+ ) == snapshot(
  _GeminiTools(
  function_declarations=[
  _GeminiFunction(
@@ -426,7 +440,7 @@ def gemini_response(content: _GeminiContent, finish_reason: Literal['STOP'] | No
  candidate = _GeminiCandidates(content=content, index=0, safety_ratings=[])
  if finish_reason: # pragma: no cover
  candidate['finish_reason'] = finish_reason
- return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
+ return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage(), model_version='gemini-1.5-flash-123')


  def example_usage() -> _GeminiUsageMetaData:
@@ -445,7 +459,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
  [
  ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
- parts=[TextPart(content='Hello world')], model_name='gemini-1.5-flash', timestamp=IsNow(tz=timezone.utc)
+ parts=[TextPart(content='Hello world')],
+ model_name='gemini-1.5-flash-123',
+ timestamp=IsNow(tz=timezone.utc),
  ),
  ]
  )
@@ -458,13 +474,13 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
  ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='Hello world')],
- model_name='gemini-1.5-flash',
+ model_name='gemini-1.5-flash-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='Hello world')],
- model_name='gemini-1.5-flash',
+ model_name='gemini-1.5-flash-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ]
@@ -491,7 +507,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
  args={'response': [1, 2, 123]},
  )
  ],
- model_name='gemini-1.5-flash',
+ model_name='gemini-1.5-flash-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(
@@ -552,7 +568,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
  args={'loc_name': 'San Fransisco'},
  )
  ],
- model_name='gemini-1.5-flash',
+ model_name='gemini-1.5-flash-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(
@@ -575,7 +591,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
  args={'loc_name': 'New York'},
  ),
  ],
- model_name='gemini-1.5-flash',
+ model_name='gemini-1.5-flash-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(
@@ -590,7 +606,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
  ),
  ModelResponse(
  parts=[TextPart(content='final response')],
- model_name='gemini-1.5-flash',
+ model_name='gemini-1.5-flash-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ]
@@ -853,3 +869,90 @@ async def test_model_settings(client_with_handler: ClientWithHandler, env: TestE
  },
  )
  assert result.data == 'world'
+
+
+ def gemini_no_content_response(
+ safety_ratings: list[_GeminiSafetyRating], finish_reason: Literal['SAFETY'] | None = 'SAFETY'
+ ) -> _GeminiResponse:
+ candidate = _GeminiCandidates(safety_ratings=safety_ratings)
+ if finish_reason:
+ candidate['finish_reason'] = finish_reason
+ return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
+
+
+ async def test_safety_settings_unsafe(
+ client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
+ ) -> None:
+ try:
+
+ def handler(request: httpx.Request) -> httpx.Response:
+ safety_settings = json.loads(request.content)['safety_settings']
+ assert safety_settings == [
+ {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ ]
+
+ return httpx.Response(
+ 200,
+ content=_gemini_response_ta.dump_json(
+ gemini_no_content_response(
+ finish_reason='SAFETY',
+ safety_ratings=[
+ {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'MEDIUM', 'blocked': True}
+ ],
+ ),
+ by_alias=True,
+ ),
+ headers={'Content-Type': 'application/json'},
+ )
+
+ gemini_client = client_with_handler(handler)
+ m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+ agent = Agent(m)
+
+ await agent.run(
+ 'a request for something rude',
+ model_settings=GeminiModelSettings(
+ gemini_safety_settings=[
+ {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ ]
+ ),
+ )
+ except UnexpectedModelBehavior as e:
+ assert repr(e) == "UnexpectedModelBehavior('Safety settings triggered')"
+
+
+ async def test_safety_settings_safe(
+ client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
+ ) -> None:
+ def handler(request: httpx.Request) -> httpx.Response:
+ safety_settings = json.loads(request.content)['safety_settings']
+ assert safety_settings == [
+ {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ ]
+
+ return httpx.Response(
+ 200,
+ content=_gemini_response_ta.dump_json(
+ gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
+ by_alias=True,
+ ),
+ headers={'Content-Type': 'application/json'},
+ )
+
+ gemini_client = client_with_handler(handler)
+ m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+ agent = Agent(m)
+
+ result = await agent.run(
+ 'hello',
+ model_settings=GeminiModelSettings(
+ gemini_safety_settings=[
+ {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+ ]
+ ),
+ )
+ assert result.data == 'world'
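The two new test_safety_settings_* tests above document the GeminiModelSettings.gemini_safety_settings option added in this release: the settings are forwarded in the request body, and a SAFETY-blocked candidate surfaces as UnexpectedModelBehavior('Safety settings triggered'). A rough usage sketch distilled from those tests (the api_key, prompt, and the single category used here are placeholders; the await must run inside an async function):

from pydantic_ai import Agent
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings

agent = Agent(GeminiModel('gemini-1.5-flash', api_key='mock'))
settings = GeminiModelSettings(
    gemini_safety_settings=[{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'}]
)
try:
    result = await agent.run('hello', model_settings=settings)  # inside `async def`
except UnexpectedModelBehavior:
    ...  # raised when Gemini returns a SAFETY finish_reason with no content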
@@ -52,7 +52,8 @@ pytestmark = [
  def test_init():
  m = GroqModel('llama-3.3-70b-versatile', api_key='foobar')
  assert m.client.api_key == 'foobar'
- assert m.name() == 'groq:llama-3.3-70b-versatile'
+ assert m.model_name == 'llama-3.3-70b-versatile'
+ assert m.system == 'groq'


  @dataclass
@@ -102,7 +103,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
  id='123',
  choices=[Choice(finish_reason='stop', index=0, message=message)],
  created=1704067200, # 2024-01-01
- model='llama-3.3-70b-versatile',
+ model='llama-3.3-70b-versatile-123',
  object='chat.completion',
  usage=usage,
  )
@@ -129,13 +130,13 @@ async def test_request_simple_success(allow_model_requests: None):
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='llama-3.3-70b-versatile',
+ model_name='llama-3.3-70b-versatile-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='llama-3.3-70b-versatile',
+ model_name='llama-3.3-70b-versatile-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ]
@@ -186,7 +187,7 @@ async def test_request_structured_response(allow_model_requests: None):
  tool_call_id='123',
  )
  ],
- model_name='llama-3.3-70b-versatile',
+ model_name='llama-3.3-70b-versatile-123',
  timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -272,7 +273,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='1',
  )
  ],
- model_name='llama-3.3-70b-versatile',
+ model_name='llama-3.3-70b-versatile-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -293,7 +294,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='2',
  )
  ],
- model_name='llama-3.3-70b-versatile',
+ model_name='llama-3.3-70b-versatile-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -308,7 +309,7 @@ async def test_request_tool_call(allow_model_requests: None):
  ),
  ModelResponse(
  parts=[TextPart(content='final response')],
- model_name='llama-3.3-70b-versatile',
+ model_name='llama-3.3-70b-versatile-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ]
@@ -48,7 +48,6 @@ with try_import() as imports_successful:
  from mistralai.types.basemodel import Unset as MistralUnset

  from pydantic_ai.models.mistral import (
- MistralAgentModel,
  MistralModel,
  MistralStreamedResponse,
  )
@@ -124,7 +123,7 @@ def completion_message(
  id='123',
  choices=[MistralChatCompletionChoice(finish_reason='stop', index=0, message=message)],
  created=1704067200 if with_created else None, # 2024-01-01
- model='mistral-large-latest',
+ model='mistral-large-123',
  object='chat.completion',
  usage=usage or MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
  )
@@ -218,13 +217,13 @@ async def test_multiple_completions(allow_model_requests: None):
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=IsNow(tz=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='hello again')],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ]
@@ -270,19 +269,19 @@ async def test_three_completions(allow_model_requests: None):
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='hello again')],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='final message')],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ]
@@ -397,7 +396,7 @@ async def test_request_model_structured_with_arguments_dict_response(allow_model
  tool_call_id='123',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -459,7 +458,7 @@ async def test_request_model_structured_with_arguments_str_response(allow_model_
  tool_call_id='123',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -520,7 +519,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
  tool_call_id='123',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -1105,7 +1104,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='1',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -1126,7 +1125,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='2',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -1141,7 +1140,7 @@ async def test_request_tool_call(allow_model_requests: None):
  ),
  ModelResponse(
  parts=[TextPart(content='final response')],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ]
@@ -1245,7 +1244,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
  tool_call_id='1',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -1266,7 +1265,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
  tool_call_id='2',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -1287,7 +1286,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
  tool_call_id='1',
  )
  ],
- model_name='mistral-large-latest',
+ model_name='mistral-large-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -1668,8 +1667,8 @@ def test_generate_user_output_format_complex():
  'prop_unrecognized_type': {'type': 'customSomething'},
  }
  }
- mam = MistralAgentModel(Mistral(api_key=''), '', False, [], [], '{schema}')
- result = mam._generate_user_output_format([schema]) # pyright: ignore[reportPrivateUsage]
+ m = MistralModel('', json_mode_schema_prompt='{schema}')
+ result = m._generate_user_output_format([schema]) # pyright: ignore[reportPrivateUsage]
  assert result.content == (
  "{'prop_anyOf': 'Optional[str]', "
  "'prop_no_type': 'Any', "
@@ -1685,8 +1684,8 @@ def test_generate_user_output_format_complex():

  def test_generate_user_output_format_multiple():
  schema = {'properties': {'prop_anyOf': {'anyOf': [{'type': 'string'}, {'type': 'integer'}]}}}
- mam = MistralAgentModel(Mistral(api_key=''), '', False, [], [], '{schema}')
- result = mam._generate_user_output_format([schema, schema]) # pyright: ignore[reportPrivateUsage]
+ m = MistralModel('', json_mode_schema_prompt='{schema}')
+ result = m._generate_user_output_format([schema, schema]) # pyright: ignore[reportPrivateUsage]
  assert result.content == "[{'prop_anyOf': 'Optional[str]'}, {'prop_anyOf': 'Optional[str]'}]"

@@ -8,66 +8,81 @@ from pydantic_ai.models import infer_model
  from ..conftest import TestEnv

  TEST_CASES = [
- ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
- ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
- ('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'),
- ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
- ('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
+ ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
+ ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
+ ('OPENAI_API_KEY', 'o1', 'o1', 'openai', 'openai', 'OpenAIModel'),
+ ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
+ ('GEMINI_API_KEY', 'gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
  (
  'GEMINI_API_KEY',
  'google-vertex:gemini-1.5-flash',
- 'google-vertex:gemini-1.5-flash',
+ 'gemini-1.5-flash',
+ 'google-vertex',
  'vertexai',
  'VertexAIModel',
  ),
  (
  'GEMINI_API_KEY',
  'vertexai:gemini-1.5-flash',
- 'google-vertex:gemini-1.5-flash',
+ 'gemini-1.5-flash',
+ 'google-vertex',
  'vertexai',
  'VertexAIModel',
  ),
  (
  'ANTHROPIC_API_KEY',
  'anthropic:claude-3-5-haiku-latest',
- 'anthropic:claude-3-5-haiku-latest',
+ 'claude-3-5-haiku-latest',
+ 'anthropic',
  'anthropic',
  'AnthropicModel',
  ),
  (
  'ANTHROPIC_API_KEY',
  'claude-3-5-haiku-latest',
- 'anthropic:claude-3-5-haiku-latest',
+ 'claude-3-5-haiku-latest',
+ 'anthropic',
  'anthropic',
  'AnthropicModel',
  ),
  (
  'GROQ_API_KEY',
  'groq:llama-3.3-70b-versatile',
- 'groq:llama-3.3-70b-versatile',
+ 'llama-3.3-70b-versatile',
+ 'groq',
  'groq',
  'GroqModel',
  ),
  (
  'MISTRAL_API_KEY',
  'mistral:mistral-small-latest',
- 'mistral:mistral-small-latest',
+ 'mistral-small-latest',
+ 'mistral',
  'mistral',
  'MistralModel',
  ),
  (
  'CO_API_KEY',
  'cohere:command',
- 'cohere:command',
+ 'command',
+ 'cohere',
  'cohere',
  'CohereModel',
  ),
  ]


- @pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES)
+ @pytest.mark.parametrize(
+ 'mock_api_key, model_name, expected_model_name, expected_system, module_name, model_class_name', TEST_CASES
+ )
  def test_infer_model(
- env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str
+ env: TestEnv,
+ mock_api_key: str,
+ model_name: str,
+ expected_model_name: str,
+ expected_system: str,
+ module_name: str,
+ model_class_name: str,
  ):
  try:
  model_module = import_module(f'pydantic_ai.models.{module_name}')
@@ -79,7 +94,8 @@ def test_infer_model(

  m = infer_model(model_name) # pyright: ignore[reportArgumentType]
  assert isinstance(m, expected_model)
- assert m.name() == expected_model_name
+ assert m.model_name == expected_model_name
+ assert m.system == expected_system

  m2 = infer_model(m)
  assert m2 is m
@@ -41,13 +41,13 @@ async def stream_hello(_messages: list[ModelMessage], _agent_info: AgentInfo) ->

  def test_init() -> None:
  m = FunctionModel(function=hello)
- assert m.name() == 'function:hello:'
+ assert m.model_name == 'function:hello:'

  m1 = FunctionModel(stream_function=stream_hello)
- assert m1.name() == 'function::stream_hello'
+ assert m1.model_name == 'function::stream_hello'

  m2 = FunctionModel(function=hello, stream_function=stream_hello)
- assert m2.name() == 'function:hello:stream_hello'
+ assert m2.model_name == 'function:hello:stream_hello'


  async def return_last(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:
@@ -38,7 +38,7 @@ def test_known_model_names():
  groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
  mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
  openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
- n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt')
+ n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
  ]
  extra_names = ['test']

@@ -25,7 +25,7 @@ from pydantic_ai.messages import (
  from pydantic_ai.result import Usage
  from pydantic_ai.settings import ModelSettings

- from ..conftest import IsNow, try_import
+ from ..conftest import IsNow, TestEnv, try_import
  from .mock_async_stream import MockAsyncStream

  with try_import() as imports_successful:
@@ -65,12 +65,18 @@ def test_init_with_base_url():
  m.name()


+ def test_init_with_no_api_key_will_still_setup_client():
+ m = OpenAIModel('llama3.2', base_url='http://localhost:19434/v1')
+ assert str(m.client.base_url) == 'http://localhost:19434/v1/'
+
+
  def test_init_with_non_openai_model():
  m = OpenAIModel('llama3.2-vision:latest', base_url='https://example.com/v1/')
  m.name()


- def test_init_of_openai_without_api_key_raises_error():
+ def test_init_of_openai_without_api_key_raises_error(env: TestEnv):
+ env.remove('OPENAI_API_KEY')
  with pytest.raises(
  OpenAIError,
  match='^The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable$',
@@ -135,7 +141,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
  id='123',
  choices=[Choice(finish_reason='stop', index=0, message=message)],
  created=1704067200, # 2024-01-01
- model='gpt-4o',
+ model='gpt-4o-123',
  object='chat.completion',
  usage=usage,
  )
@@ -162,13 +168,13 @@ async def test_request_simple_success(allow_model_requests: None):
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='gpt-4o',
+ model_name='gpt-4o-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
  ModelResponse(
  parts=[TextPart(content='world')],
- model_name='gpt-4o',
+ model_name='gpt-4o-123',
  timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
  ),
  ]
@@ -232,7 +238,7 @@ async def test_request_structured_response(allow_model_requests: None):
  tool_call_id='123',
  )
  ],
- model_name='gpt-4o',
+ model_name='gpt-4o-123',
  timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -320,7 +326,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='1',
  )
  ],
- model_name='gpt-4o',
+ model_name='gpt-4o-123',
  timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -341,7 +347,7 @@ async def test_request_tool_call(allow_model_requests: None):
  tool_call_id='2',
  )
  ],
- model_name='gpt-4o',
+ model_name='gpt-4o-123',
  timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
  ),
  ModelRequest(
@@ -356,7 +362,7 @@ async def test_request_tool_call(allow_model_requests: None):
  ),
  ModelResponse(
  parts=[TextPart(content='final response')],
- model_name='gpt-4o',
+ model_name='gpt-4o-123',
  timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
  ),
  ]
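The new test_init_with_no_api_key_will_still_setup_client above shows that, as of 0.0.24, OpenAIModel builds its client even when no API key is available, which matters when base_url points at an OpenAI-compatible local server; only the stock OpenAI endpoint still requires OPENAI_API_KEY, now exercised via the env fixture. A minimal sketch reusing the same placeholder port and model name as the test:

from pydantic_ai.models.openai import OpenAIModel

m = OpenAIModel('llama3.2', base_url='http://localhost:19434/v1')  # no api_key and no OPENAI_API_KEY needed
assert str(m.client.base_url) == 'http://localhost:19434/v1/'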
@@ -30,17 +30,18 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
  save_service_account(service_account_path, 'my-project-id')

  model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path)
- assert model.url is None
- assert model.auth is None
+ assert model._url is None
+ assert model._auth is None

- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+ await model.ainit()

  assert model.url == snapshot(
  'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
  'publishers/google/models/gemini-1.5-flash:'
  )
  assert model.auth is not None
- assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
+ assert model.model_name == snapshot('gemini-1.5-flash')
+ assert model.system == snapshot('google-vertex')


  class NoOpCredentials:
@@ -53,12 +54,12 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
  return_value=(NoOpCredentials(), 'my-project-id'),
  )
  model = VertexAIModel('gemini-1.5-flash')
- assert model.url is None
- assert model.auth is None
+ assert model._url is None
+ assert model._auth is None

  assert patch.call_count == 0

- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+ await model.ainit()

  assert patch.call_count == 1

@@ -67,9 +68,10 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
  'publishers/google/models/gemini-1.5-flash:'
  )
  assert model.auth is not None
- assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
+ assert model.model_name == snapshot('gemini-1.5-flash')
+ assert model.system == snapshot('google-vertex')

- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+ await model.ainit()
  assert model.url is not None
  assert model.auth is not None
  assert patch.call_count == 1
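These test_vertexai.py hunks replace the old warm-up call model.agent_model(...) with an explicit ainit() coroutine and keep the lazily resolved url/auth in private _url/_auth until it runs. A hedged sketch of the new initialisation flow, assuming a placeholder service-account path and an async context:

from pydantic_ai.models.vertexai import VertexAIModel

model = VertexAIModel('gemini-1.5-flash', service_account_file='service_account.json')  # placeholder path
await model.ainit()  # inside `async def`; resolves project_id and builds the request URL (was agent_model(...))
assert model.model_name == 'gemini-1.5-flash'
assert model.system == 'google-vertex'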
@@ -80,10 +82,10 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
  save_service_account(service_account_path, 'my-project-id')

  model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='my-project-id')
- assert model.url is None
- assert model.auth is None
+ assert model._url is None
+ assert model._auth is None

- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+ await model.ainit()

  assert model.url == snapshot(
  'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -92,30 +94,6 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
  assert model.auth is not None


- async def test_init_service_account_wrong_project_id(tmp_path: Path, allow_model_requests: None):
- service_account_path = tmp_path / 'service_account.json'
- save_service_account(service_account_path, 'my-project-id')
-
- model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='different')
-
- with pytest.raises(UserError) as exc_info:
- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
- assert str(exc_info.value) == snapshot(
- "The project_id you provided does not match the one from service account file: 'different' != 'my-project-id'"
- )
-
-
- async def test_init_env_wrong_project_id(mocker: MockerFixture, allow_model_requests: None):
- mocker.patch('pydantic_ai.models.vertexai.google.auth.default', return_value=(NoOpCredentials(), 'my-project-id'))
- model = VertexAIModel('gemini-1.5-flash', project_id='different')
-
- with pytest.raises(UserError) as exc_info:
- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
- assert str(exc_info.value) == snapshot(
- "The project_id you provided does not match the one from `google.auth.default()`: 'different' != 'my-project-id'"
- )
-
-
  async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_requests: None):
  mocker.patch(
  'pydantic_ai.models.vertexai.google.auth.default',
@@ -124,7 +102,7 @@ async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_request
  model = VertexAIModel('gemini-1.5-flash')

  with pytest.raises(UserError) as exc_info:
- await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+ await model.ainit()
  assert str(exc_info.value) == snapshot('No project_id provided and none found in `google.auth.default()`')

@@ -326,13 +326,13 @@ def test_response_tuple():
  result = agent.run_sync('Hello')
  assert result.data == snapshot(('a', 'a'))

- assert m.agent_model_function_tools == snapshot([])
- assert m.agent_model_allow_text_result is False
+ assert m.last_model_request_parameters is not None
+ assert m.last_model_request_parameters.function_tools == snapshot([])
+ assert m.last_model_request_parameters.allow_text_result is False

- assert m.agent_model_result_tools is not None
- assert len(m.agent_model_result_tools) == 1
-
- assert m.agent_model_result_tools == snapshot(
+ assert m.last_model_request_parameters.result_tools is not None
+ assert len(m.last_model_request_parameters.result_tools) == 1
+ assert m.last_model_request_parameters.result_tools == snapshot(
  [
  ToolDefinition(
  name='final_result',
@@ -384,13 +384,14 @@ def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
  assert result.data == snapshot('success (no tool calls)')
  assert got_tool_call_name == snapshot(None)

- assert m.agent_model_function_tools == snapshot([])
- assert m.agent_model_allow_text_result is True
+ assert m.last_model_request_parameters is not None
+ assert m.last_model_request_parameters.function_tools == snapshot([])
+ assert m.last_model_request_parameters.allow_text_result is True

- assert m.agent_model_result_tools is not None
- assert len(m.agent_model_result_tools) == 1
+ assert m.last_model_request_parameters.result_tools is not None
+ assert len(m.last_model_request_parameters.result_tools) == 1

- assert m.agent_model_result_tools == snapshot(
+ assert m.last_model_request_parameters.result_tools == snapshot(
  [
  ToolDefinition(
  name='final_result',
@@ -459,13 +460,14 @@ class Bar(BaseModel):
  assert result.data == mod.Foo(a=0, b='a')
  assert got_tool_call_name == snapshot('final_result_Foo')

- assert m.agent_model_function_tools == snapshot([])
- assert m.agent_model_allow_text_result is False
+ assert m.last_model_request_parameters is not None
+ assert m.last_model_request_parameters.function_tools == snapshot([])
+ assert m.last_model_request_parameters.allow_text_result is False

- assert m.agent_model_result_tools is not None
- assert len(m.agent_model_result_tools) == 2
+ assert m.last_model_request_parameters.result_tools is not None
+ assert len(m.last_model_request_parameters.result_tools) == 2

- assert m.agent_model_result_tools == snapshot(
+ assert m.last_model_request_parameters.result_tools == snapshot(
  [
  ToolDefinition(
  name='final_result_Foo',
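The test_agent.py hunks above replace TestModel's separate agent_model_function_tools / agent_model_allow_text_result / agent_model_result_tools capture attributes with a single last_model_request_parameters object. A short sketch of the new introspection style, assuming a TestModel-backed agent like the one these tests use (the result_type here is a stand-in chosen to force a result tool):

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

m = TestModel()
agent = Agent(m, result_type=tuple[str, str])
agent.run_sync('Hello')
params = m.last_model_request_parameters  # one object instead of three agent_model_* attributes
assert params is not None and params.allow_text_result is False
assert len(params.result_tools) == 1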
@@ -17,6 +17,7 @@ from pytest_examples import CodeExample, EvalExample, find_examples
  from pytest_mock import MockerFixture

  from pydantic_ai._utils import group_by_temporal
+ from pydantic_ai.exceptions import UnexpectedModelBehavior
  from pydantic_ai.messages import (
  ModelMessage,
  ModelResponse,
@@ -288,6 +289,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
  )
  ]
  )
+ elif m.content.startswith('Write a list of 5 very rude things that I might say'):
+ raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
  elif m.content.startswith('<examples>\n <user>'):
  return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args={})])
  elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2:
@@ -75,14 +75,14 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
  'id': 0,
  'message': 'my_agent run prompt=Hello',
  'children': [
- {'id': 1, 'message': 'preparing model and tools run_step=1'},
+ {'id': 1, 'message': 'preparing model request params run_step=1'},
  {'id': 2, 'message': 'model request'},
  {
  'id': 3,
  'message': 'handle model response -> tool-return',
  'children': [{'id': 4, 'message': "running tools=['my_ret']"}],
  },
- {'id': 5, 'message': 'preparing model and tools run_step=2'},
+ {'id': 5, 'message': 'preparing model request params run_step=2'},
  {'id': 6, 'message': 'model request'},
  {'id': 7, 'message': 'handle model response -> final result'},
  ],
@@ -102,16 +102,14 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
  'custom_result_text': None,
  'custom_result_args': None,
  'seed': 0,
- 'agent_model_function_tools': None,
- 'agent_model_allow_text_result': None,
- 'agent_model_result_tools': None,
+ 'last_model_request_parameters': None,
  },
  'name': 'my_agent',
  'end_strategy': 'early',
  'model_settings': None,
  }
  ),
- 'model_name': 'test-model',
+ 'model_name': 'test',
  'agent_name': 'my_agent',
  'logfire.msg_template': '{agent_name} run {prompt=}',
  'logfire.msg': 'my_agent run prompt=Hello',
@@ -255,9 +253,9 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
  'code.function': 'test_logfire',
  'code.lineno': IsInt(),
  'run_step': 1,
- 'logfire.msg_template': 'preparing model and tools {run_step=}',
+ 'logfire.msg_template': 'preparing model request params {run_step=}',
  'logfire.span_type': 'span',
- 'logfire.msg': 'preparing model and tools run_step=1',
+ 'logfire.msg': 'preparing model request params run_step=1',
  'logfire.json_schema': '{"type":"object","properties":{"run_step":{}}}',
  }
  )