pydantic-ai 0.0.21__tar.gz → 0.0.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (44)
  1. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/.gitignore +3 -1
  2. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/PKG-INFO +3 -3
  3. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/pyproject.toml +7 -4
  4. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_anthropic.py +2 -1
  5. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_cohere.py +2 -1
  6. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_gemini.py +123 -22
  7. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_groq.py +2 -1
  8. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_mistral.py +4 -5
  9. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model.py +31 -15
  10. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model_function.py +3 -3
  11. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model_names.py +1 -1
  12. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_openai.py +16 -2
  13. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_vertexai.py +15 -37
  14. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_agent.py +18 -16
  15. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_examples.py +4 -1
  16. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_logfire.py +6 -8
  17. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/LICENSE +0 -0
  18. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/Makefile +0 -0
  19. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/README.md +0 -0
  20. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/__init__.py +0 -0
  21. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/conftest.py +0 -0
  22. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/example_modules/README.md +0 -0
  23. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/example_modules/bank_database.py +0 -0
  24. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/example_modules/fake_database.py +0 -0
  25. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/example_modules/weather_service.py +0 -0
  26. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/graph/__init__.py +0 -0
  27. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/graph/test_graph.py +0 -0
  28. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/graph/test_history.py +0 -0
  29. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/graph/test_mermaid.py +0 -0
  30. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/graph/test_state.py +0 -0
  31. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/import_examples.py +0 -0
  32. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/__init__.py +0 -0
  33. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/mock_async_stream.py +0 -0
  34. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model_test.py +0 -0
  35. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_deps.py +0 -0
  36. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_format_as_xml.py +0 -0
  37. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_live.py +0 -0
  38. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_parts_manager.py +0 -0
  39. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_streaming.py +0 -0
  40. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_tools.py +0 -0
  41. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_usage_limits.py +0 -0
  42. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_utils.py +0 -0
  43. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/typed_agent.py +0 -0
  44. {pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/typed_graph.py +0 -0

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/.gitignore

@@ -3,7 +3,8 @@ site
 .venv
 dist
 __pycache__
-*.env
+.env
+.dev.vars
 /scratch/
 /.coverage
 env*/
@@ -14,3 +15,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
 .cache/
 .vscode/
 /question_graph_history.json
+/docs-site/.wrangler/

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai
-Version: 0.0.21
+Version: 0.0.23
 Summary: Agent Framework / shim to use Pydantic with LLMs
 Project-URL: Homepage, https://ai.pydantic.dev
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -32,9 +32,9 @@ Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: pydantic-ai-slim[anthropic,cohere,graph,groq,mistral,openai,vertexai]==0.0.21
+Requires-Dist: pydantic-ai-slim[anthropic,cohere,groq,mistral,openai,vertexai]==0.0.23
 Provides-Extra: examples
-Requires-Dist: pydantic-ai-examples==0.0.21; extra == 'examples'
+Requires-Dist: pydantic-ai-examples==0.0.23; extra == 'examples'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Description-Content-Type: text/markdown

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai"
-version = "0.0.21"
+version = "0.0.23"
 description = "Agent Framework / shim to use Pydantic with LLMs"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -37,7 +37,7 @@ classifiers = [
 ]
 requires-python = ">=3.9"
 
-dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral,cohere]==0.0.21"]
+dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.23"]
 
 [project.urls]
 Homepage = "https://ai.pydantic.dev"
@@ -46,7 +46,7 @@ Documentation = "https://ai.pydantic.dev"
 Changelog = "https://github.com/pydantic/pydantic-ai/releases"
 
 [project.optional-dependencies]
-examples = ["pydantic-ai-examples==0.0.21"]
+examples = ["pydantic-ai-examples==0.0.23"]
 logfire = ["logfire>=2.3"]
 
 [tool.uv.sources]
@@ -65,7 +65,6 @@ lint = [
     "ruff>=0.6.9",
 ]
 docs = [
-    "algoliasearch>=4.12.0",
     "black>=24.10.0",
     "bs4>=0.0.2",
     "markdownify>=0.14.1",
@@ -74,6 +73,10 @@ docs = [
     "mkdocs-material[imaging]>=9.5.45",
     "mkdocstrings-python>=1.12.2",
 ]
+docs-upload = [
+    "algoliasearch>=4.12.0",
+    "pydantic>=2.10.1",
+]
 
 [tool.hatch.build.targets.wheel]
 only-include = ["/README.md"]

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_anthropic.py

@@ -60,7 +60,8 @@ T = TypeVar('T')
 def test_init():
     m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
     assert m.client.api_key == 'foobar'
-    assert m.name() == 'anthropic:claude-3-5-haiku-latest'
+    assert m.model_name == 'claude-3-5-haiku-latest'
+    assert m.system == 'anthropic'
 
 
 @dataclass
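
Note: the same `name()` → `model_name` / `system` split recurs throughout the provider test suites below. A minimal sketch of what the updated assertions imply about the 0.0.23 model surface (the property definitions themselves live in pydantic-ai-slim, which is not part of this diff):

    from pydantic_ai.models.anthropic import AnthropicModel

    m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
    # 0.0.21 exposed a single combined identifier: m.name() == 'anthropic:claude-3-5-haiku-latest'
    # 0.0.23 splits it into two properties:
    assert m.model_name == 'claude-3-5-haiku-latest'  # bare model name
    assert m.system == 'anthropic'                    # provider identifier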

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_cohere.py

@@ -44,7 +44,8 @@ pytestmark = [
 
 def test_init():
     m = CohereModel('command-r7b-12-2024', api_key='foobar')
-    assert m.name() == 'cohere:command-r7b-12-2024'
+    assert m.model_name == 'command-r7b-12-2024'
+    assert m.system == 'cohere'
 
 
 @dataclass

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_gemini.py

@@ -24,9 +24,11 @@ from pydantic_ai.messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from pydantic_ai.models import ModelRequestParameters
 from pydantic_ai.models.gemini import (
     ApiKeyAuth,
     GeminiModel,
+    GeminiModelSettings,
     _content_model_response,
     _function_call_part_from_call,
     _gemini_response_ta,
@@ -36,6 +38,7 @@ from pydantic_ai.models.gemini import (
     _GeminiFunction,
     _GeminiFunctionCallingConfig,
     _GeminiResponse,
+    _GeminiSafetyRating,
     _GeminiTextPart,
     _GeminiToolConfig,
     _GeminiTools,
@@ -75,18 +78,21 @@ def test_api_key_empty(env: TestEnv):
         GeminiModel('gemini-1.5-flash')
 
 
-async def test_agent_model_simple(allow_model_requests: None):
+async def test_model_simple(allow_model_requests: None):
     m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
-    assert isinstance(agent_model.http_client, httpx.AsyncClient)
-    assert agent_model.model_name == 'gemini-1.5-flash'
-    assert isinstance(agent_model.auth, ApiKeyAuth)
-    assert agent_model.auth.api_key == 'via-arg'
-    assert agent_model.tools is None
-    assert agent_model.tool_config is None
+    assert isinstance(m.http_client, httpx.AsyncClient)
+    assert m.model_name == 'gemini-1.5-flash'
+    assert isinstance(m.auth, ApiKeyAuth)
+    assert m.auth.api_key == 'via-arg'
 
+    arc = ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools is None
+    assert tool_config is None
 
-async def test_agent_model_tools(allow_model_requests: None):
+
+async def test_model_tools(allow_model_requests: None):
     m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
     tools = [
         ToolDefinition(
@@ -110,8 +116,11 @@ async def test_agent_model_tools(allow_model_requests: None):
         'This is the tool for the final Result',
         {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
     )
-    agent_model = await m.agent_model(function_tools=tools, allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+
+    arc = ModelRequestParameters(function_tools=tools, allow_text_result=True, result_tools=[result_tool])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -139,7 +148,7 @@ async def test_agent_model_tools(allow_model_requests: None):
             ]
         )
     )
-    assert agent_model.tool_config is None
+    assert tool_config is None
 
 
 async def test_require_response_tool(allow_model_requests: None):
@@ -149,8 +158,10 @@ async def test_require_response_tool(allow_model_requests: None):
         'This is the tool for the final Result',
         {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=False, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    arc = ModelRequestParameters(function_tools=[], allow_text_result=False, result_tools=[result_tool])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -164,7 +175,7 @@ async def test_require_response_tool(allow_model_requests: None):
             ]
         )
     )
-    assert agent_model.tool_config == snapshot(
+    assert tool_config == snapshot(
         _GeminiToolConfig(
             function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=['result'])
         )
@@ -206,8 +217,9 @@ async def test_json_def_replaced(allow_model_requests: None):
         'This is the tool for the final Result',
         json_schema,
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    assert m._get_tools(
+        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -252,8 +264,9 @@ async def test_json_def_replaced_any_of(allow_model_requests: None):
         'This is the tool for the final Result',
         json_schema,
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    assert m._get_tools(
+        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -315,7 +328,7 @@ async def test_json_def_recursive(allow_model_requests: None):
         json_schema,
     )
     with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
-        await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+        m._get_tools(ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool]))
 
 
 async def test_json_def_date(allow_model_requests: None):
@@ -346,8 +359,9 @@ async def test_json_def_date(allow_model_requests: None):
         'This is the tool for the final Result',
         json_schema,
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    assert m._get_tools(
+        ModelRequestParameters(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -853,3 +867,90 @@ async def test_model_settings(client_with_handler: ClientWithHandler, env: TestE
         },
     )
     assert result.data == 'world'
+
+
+def gemini_no_content_response(
+    safety_ratings: list[_GeminiSafetyRating], finish_reason: Literal['SAFETY'] | None = 'SAFETY'
+) -> _GeminiResponse:
+    candidate = _GeminiCandidates(safety_ratings=safety_ratings)
+    if finish_reason:
+        candidate['finish_reason'] = finish_reason
+    return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
+
+
+async def test_safety_settings_unsafe(
+    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
+) -> None:
+    try:
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            safety_settings = json.loads(request.content)['safety_settings']
+            assert safety_settings == [
+                {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+            ]
+
+            return httpx.Response(
+                200,
+                content=_gemini_response_ta.dump_json(
+                    gemini_no_content_response(
+                        finish_reason='SAFETY',
+                        safety_ratings=[
+                            {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'MEDIUM', 'blocked': True}
+                        ],
+                    ),
+                    by_alias=True,
+                ),
+                headers={'Content-Type': 'application/json'},
+            )
+
+        gemini_client = client_with_handler(handler)
+        m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+        agent = Agent(m)
+
+        await agent.run(
+            'a request for something rude',
+            model_settings=GeminiModelSettings(
+                gemini_safety_settings=[
+                    {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                    {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                ]
+            ),
+        )
+    except UnexpectedModelBehavior as e:
+        assert repr(e) == "UnexpectedModelBehavior('Safety settings triggered')"
+
+
+async def test_safety_settings_safe(
+    client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
+) -> None:
+    def handler(request: httpx.Request) -> httpx.Response:
+        safety_settings = json.loads(request.content)['safety_settings']
+        assert safety_settings == [
+            {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+            {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+        ]
+
+        return httpx.Response(
+            200,
+            content=_gemini_response_ta.dump_json(
+                gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
+                by_alias=True,
+            ),
+            headers={'Content-Type': 'application/json'},
+        )
+
+    gemini_client = client_with_handler(handler)
+    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+    agent = Agent(m)
+
+    result = await agent.run(
+        'hello',
+        model_settings=GeminiModelSettings(
+            gemini_safety_settings=[
+                {'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+                {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
+            ]
+        ),
+    )
+    assert result.data == 'world'
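
Note: the two new tests above pin down the `GeminiModelSettings` / `gemini_safety_settings` behaviour: the settings are forwarded as `safety_settings` in the request body, and a response blocked with `finish_reason='SAFETY'` surfaces as `UnexpectedModelBehavior('Safety settings triggered')`. A rough usage sketch, assuming a real API key rather than the mocked HTTP client used in the tests:

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.exceptions import UnexpectedModelBehavior
    from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings


    async def main() -> None:
        agent = Agent(GeminiModel('gemini-1.5-flash', api_key='...'))  # placeholder key
        try:
            result = await agent.run(
                'hello',
                model_settings=GeminiModelSettings(
                    gemini_safety_settings=[
                        {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
                    ]
                ),
            )
            print(result.data)
        except UnexpectedModelBehavior:
            print('response was blocked by the safety settings')


    asyncio.run(main())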

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_groq.py

@@ -52,7 +52,8 @@ pytestmark = [
 def test_init():
     m = GroqModel('llama-3.3-70b-versatile', api_key='foobar')
     assert m.client.api_key == 'foobar'
-    assert m.name() == 'groq:llama-3.3-70b-versatile'
+    assert m.model_name == 'llama-3.3-70b-versatile'
+    assert m.system == 'groq'
 
 
 @dataclass

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_mistral.py

@@ -48,7 +48,6 @@ with try_import() as imports_successful:
     from mistralai.types.basemodel import Unset as MistralUnset
 
     from pydantic_ai.models.mistral import (
-        MistralAgentModel,
         MistralModel,
         MistralStreamedResponse,
     )
@@ -1668,8 +1667,8 @@ def test_generate_user_output_format_complex():
             'prop_unrecognized_type': {'type': 'customSomething'},
         }
     }
-    mam = MistralAgentModel(Mistral(api_key=''), '', False, [], [], '{schema}')
-    result = mam._generate_user_output_format([schema])  # pyright: ignore[reportPrivateUsage]
+    m = MistralModel('', json_mode_schema_prompt='{schema}')
+    result = m._generate_user_output_format([schema])  # pyright: ignore[reportPrivateUsage]
     assert result.content == (
         "{'prop_anyOf': 'Optional[str]', "
         "'prop_no_type': 'Any', "
@@ -1685,8 +1684,8 @@ def test_generate_user_output_format_complex():
 
 def test_generate_user_output_format_multiple():
     schema = {'properties': {'prop_anyOf': {'anyOf': [{'type': 'string'}, {'type': 'integer'}]}}}
-    mam = MistralAgentModel(Mistral(api_key=''), '', False, [], [], '{schema}')
-    result = mam._generate_user_output_format([schema, schema])  # pyright: ignore[reportPrivateUsage]
+    m = MistralModel('', json_mode_schema_prompt='{schema}')
+    result = m._generate_user_output_format([schema, schema])  # pyright: ignore[reportPrivateUsage]
     assert result.content == "[{'prop_anyOf': 'Optional[str]'}, {'prop_anyOf': 'Optional[str]'}]"
 
 

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model.py

@@ -8,66 +8,81 @@ from pydantic_ai.models import infer_model
 from ..conftest import TestEnv
 
 TEST_CASES = [
-    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
-    ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
-    ('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'),
-    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
-    ('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
+    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
+    ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
+    ('OPENAI_API_KEY', 'o1', 'o1', 'openai', 'openai', 'OpenAIModel'),
+    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
+    ('GEMINI_API_KEY', 'gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
     (
         'GEMINI_API_KEY',
         'google-vertex:gemini-1.5-flash',
-        'google-vertex:gemini-1.5-flash',
+        'gemini-1.5-flash',
+        'google-vertex',
         'vertexai',
         'VertexAIModel',
     ),
     (
         'GEMINI_API_KEY',
         'vertexai:gemini-1.5-flash',
-        'google-vertex:gemini-1.5-flash',
+        'gemini-1.5-flash',
+        'google-vertex',
         'vertexai',
         'VertexAIModel',
     ),
     (
         'ANTHROPIC_API_KEY',
         'anthropic:claude-3-5-haiku-latest',
-        'anthropic:claude-3-5-haiku-latest',
+        'claude-3-5-haiku-latest',
+        'anthropic',
         'anthropic',
         'AnthropicModel',
     ),
     (
         'ANTHROPIC_API_KEY',
         'claude-3-5-haiku-latest',
-        'anthropic:claude-3-5-haiku-latest',
+        'claude-3-5-haiku-latest',
+        'anthropic',
         'anthropic',
         'AnthropicModel',
     ),
     (
         'GROQ_API_KEY',
         'groq:llama-3.3-70b-versatile',
-        'groq:llama-3.3-70b-versatile',
+        'llama-3.3-70b-versatile',
+        'groq',
         'groq',
         'GroqModel',
     ),
     (
         'MISTRAL_API_KEY',
         'mistral:mistral-small-latest',
-        'mistral:mistral-small-latest',
+        'mistral-small-latest',
+        'mistral',
         'mistral',
         'MistralModel',
     ),
     (
         'CO_API_KEY',
         'cohere:command',
-        'cohere:command',
+        'command',
+        'cohere',
         'cohere',
         'CohereModel',
     ),
 ]
 
 
-@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES)
+@pytest.mark.parametrize(
+    'mock_api_key, model_name, expected_model_name, expected_system, module_name, model_class_name', TEST_CASES
+)
 def test_infer_model(
-    env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str
+    env: TestEnv,
+    mock_api_key: str,
+    model_name: str,
+    expected_model_name: str,
+    expected_system: str,
+    module_name: str,
+    model_class_name: str,
 ):
     try:
         model_module = import_module(f'pydantic_ai.models.{module_name}')
@@ -79,7 +94,8 @@ def test_infer_model(
 
     m = infer_model(model_name)  # pyright: ignore[reportArgumentType]
     assert isinstance(m, expected_model)
-    assert m.name() == expected_model_name
+    assert m.model_name == expected_model_name
+    assert m.system == expected_system
 
     m2 = infer_model(m)
     assert m2 is m
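
Note: the widened TEST_CASES tuples encode the same split for `infer_model`: the provider prefix of a `provider:model` string now ends up in `system`, and `model_name` keeps only the bare model id. A small sketch (it assumes the matching API key is present in the environment, which the test arranges via the `env` fixture):

    from pydantic_ai.models import infer_model

    m = infer_model('openai:gpt-3.5-turbo')  # needs OPENAI_API_KEY set
    assert type(m).__name__ == 'OpenAIModel'
    assert m.model_name == 'gpt-3.5-turbo'   # 0.0.21: m.name() == 'openai:gpt-3.5-turbo'
    assert m.system == 'openai'

    # passing an existing model instance through infer_model is still a no-op
    assert infer_model(m) is m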

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model_function.py

@@ -41,13 +41,13 @@ async def stream_hello(_messages: list[ModelMessage], _agent_info: AgentInfo) ->
 
 def test_init() -> None:
     m = FunctionModel(function=hello)
-    assert m.name() == 'function:hello:'
+    assert m.model_name == 'function:hello:'
 
     m1 = FunctionModel(stream_function=stream_hello)
-    assert m1.name() == 'function::stream_hello'
+    assert m1.model_name == 'function::stream_hello'
 
     m2 = FunctionModel(function=hello, stream_function=stream_hello)
-    assert m2.name() == 'function:hello:stream_hello'
+    assert m2.model_name == 'function:hello:stream_hello'
 
 
 async def return_last(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_model_names.py

@@ -38,7 +38,7 @@ def test_known_model_names():
     groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)]
     mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)]
     openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [
-        n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt')
+        n for n in get_model_names(OpenAIModelName) if n.startswith('o1') or n.startswith('gpt') or n.startswith('o3')
     ]
     extra_names = ['test']
 

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_openai.py

@@ -25,11 +25,11 @@ from pydantic_ai.messages import (
 from pydantic_ai.result import Usage
 from pydantic_ai.settings import ModelSettings
 
-from ..conftest import IsNow, try_import
+from ..conftest import IsNow, TestEnv, try_import
 from .mock_async_stream import MockAsyncStream
 
 with try_import() as imports_successful:
-    from openai import NOT_GIVEN, AsyncOpenAI
+    from openai import NOT_GIVEN, AsyncOpenAI, OpenAIError
     from openai.types import chat
     from openai.types.chat.chat_completion import Choice
     from openai.types.chat.chat_completion_chunk import (
@@ -65,6 +65,20 @@ def test_init_with_base_url():
     m.name()
 
 
+def test_init_with_non_openai_model():
+    m = OpenAIModel('llama3.2-vision:latest', base_url='https://example.com/v1/')
+    m.name()
+
+
+def test_init_of_openai_without_api_key_raises_error(env: TestEnv):
+    env.remove('OPENAI_API_KEY')
+    with pytest.raises(
+        OpenAIError,
+        match='^The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable$',
+    ):
+        OpenAIModel('gpt-4o')
+
+
 @dataclass
 class MockOpenAI:
     completions: chat.ChatCompletion | list[chat.ChatCompletion] | None = None
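
Note: the two added tests document OpenAIModel construction behaviour: a non-OpenAI model name is accepted when a custom `base_url` is supplied (the test constructs it without an explicit `api_key`), and constructing with no key available at all re-raises the client's own `OpenAIError`. Sketch under those assumptions:

    from openai import OpenAIError

    from pydantic_ai.models.openai import OpenAIModel

    # point at any OpenAI-compatible endpoint
    m = OpenAIModel('llama3.2-vision:latest', base_url='https://example.com/v1/')

    # with neither an api_key argument nor OPENAI_API_KEY in the environment,
    # the underlying AsyncOpenAI client refuses to initialise
    try:
        OpenAIModel('gpt-4o')
    except OpenAIError as exc:
        print(exc)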

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/models/test_vertexai.py

@@ -30,17 +30,18 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
     save_service_account(service_account_path, 'my-project-id')
 
     model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path)
-    assert model.url is None
-    assert model.auth is None
+    assert model._url is None
+    assert model._auth is None
 
-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()
 
     assert model.url == snapshot(
         'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
         'publishers/google/models/gemini-1.5-flash:'
     )
     assert model.auth is not None
-    assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
+    assert model.model_name == snapshot('gemini-1.5-flash')
+    assert model.system == snapshot('google-vertex')
 
 
 class NoOpCredentials:
@@ -53,12 +54,12 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
         return_value=(NoOpCredentials(), 'my-project-id'),
     )
     model = VertexAIModel('gemini-1.5-flash')
-    assert model.url is None
-    assert model.auth is None
+    assert model._url is None
+    assert model._auth is None
 
     assert patch.call_count == 0
 
-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()
 
     assert patch.call_count == 1
 
@@ -67,9 +68,10 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
         'publishers/google/models/gemini-1.5-flash:'
     )
     assert model.auth is not None
-    assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
+    assert model.model_name == snapshot('gemini-1.5-flash')
+    assert model.system == snapshot('google-vertex')
 
-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()
     assert model.url is not None
     assert model.auth is not None
     assert patch.call_count == 1
@@ -80,10 +82,10 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
     save_service_account(service_account_path, 'my-project-id')
 
     model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='my-project-id')
-    assert model.url is None
-    assert model.auth is None
+    assert model._url is None
+    assert model._auth is None
 
-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()
 
     assert model.url == snapshot(
         'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -92,30 +94,6 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
     )
     assert model.auth is not None
 
 
-async def test_init_service_account_wrong_project_id(tmp_path: Path, allow_model_requests: None):
-    service_account_path = tmp_path / 'service_account.json'
-    save_service_account(service_account_path, 'my-project-id')
-
-    model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='different')
-
-    with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
-    assert str(exc_info.value) == snapshot(
-        "The project_id you provided does not match the one from service account file: 'different' != 'my-project-id'"
-    )
-
-
-async def test_init_env_wrong_project_id(mocker: MockerFixture, allow_model_requests: None):
-    mocker.patch('pydantic_ai.models.vertexai.google.auth.default', return_value=(NoOpCredentials(), 'my-project-id'))
-    model = VertexAIModel('gemini-1.5-flash', project_id='different')
-
-    with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
-    assert str(exc_info.value) == snapshot(
-        "The project_id you provided does not match the one from `google.auth.default()`: 'different' != 'my-project-id'"
-    )
-
-
 async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_requests: None):
     mocker.patch(
         'pydantic_ai.models.vertexai.google.auth.default',
@@ -124,7 +102,7 @@ async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_request
     model = VertexAIModel('gemini-1.5-flash')
 
     with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+        await model.ainit()
     assert str(exc_info.value) == snapshot('No project_id provided and none found in `google.auth.default()`')
 
 
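
Note: the Vertex AI tests now drive lazy initialisation through an explicit `ainit()` call instead of `agent_model(...)`, with the pre-init state held in private `_url` / `_auth` attributes, and the wrong-project-id checks have been dropped. A sketch, assuming a valid service account file at a hypothetical path:

    import asyncio

    from pydantic_ai.models.vertexai import VertexAIModel


    async def main() -> None:
        model = VertexAIModel('gemini-1.5-flash', service_account_file='service_account.json')
        assert model._url is None and model._auth is None  # nothing resolved yet
        await model.ainit()  # resolves project id, endpoint URL and auth
        print(model.url)
        print(model.model_name)  # 'gemini-1.5-flash'
        print(model.system)      # 'google-vertex'


    asyncio.run(main())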

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_agent.py

@@ -326,13 +326,13 @@ def test_response_tuple():
     result = agent.run_sync('Hello')
     assert result.data == snapshot(('a', 'a'))
 
-    assert m.agent_model_function_tools == snapshot([])
-    assert m.agent_model_allow_text_result is False
+    assert m.last_model_request_parameters is not None
+    assert m.last_model_request_parameters.function_tools == snapshot([])
+    assert m.last_model_request_parameters.allow_text_result is False
 
-    assert m.agent_model_result_tools is not None
-    assert len(m.agent_model_result_tools) == 1
-
-    assert m.agent_model_result_tools == snapshot(
+    assert m.last_model_request_parameters.result_tools is not None
+    assert len(m.last_model_request_parameters.result_tools) == 1
+    assert m.last_model_request_parameters.result_tools == snapshot(
         [
             ToolDefinition(
                 name='final_result',
@@ -384,13 +384,14 @@ def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
     assert result.data == snapshot('success (no tool calls)')
     assert got_tool_call_name == snapshot(None)
 
-    assert m.agent_model_function_tools == snapshot([])
-    assert m.agent_model_allow_text_result is True
+    assert m.last_model_request_parameters is not None
+    assert m.last_model_request_parameters.function_tools == snapshot([])
+    assert m.last_model_request_parameters.allow_text_result is True
 
-    assert m.agent_model_result_tools is not None
-    assert len(m.agent_model_result_tools) == 1
+    assert m.last_model_request_parameters.result_tools is not None
+    assert len(m.last_model_request_parameters.result_tools) == 1
 
-    assert m.agent_model_result_tools == snapshot(
+    assert m.last_model_request_parameters.result_tools == snapshot(
         [
             ToolDefinition(
                 name='final_result',
@@ -459,13 +460,14 @@ class Bar(BaseModel):
     assert result.data == mod.Foo(a=0, b='a')
     assert got_tool_call_name == snapshot('final_result_Foo')
 
-    assert m.agent_model_function_tools == snapshot([])
-    assert m.agent_model_allow_text_result is False
+    assert m.last_model_request_parameters is not None
+    assert m.last_model_request_parameters.function_tools == snapshot([])
+    assert m.last_model_request_parameters.allow_text_result is False
 
-    assert m.agent_model_result_tools is not None
-    assert len(m.agent_model_result_tools) == 2
+    assert m.last_model_request_parameters.result_tools is not None
+    assert len(m.last_model_request_parameters.result_tools) == 2
 
-    assert m.agent_model_result_tools == snapshot(
+    assert m.last_model_request_parameters.result_tools == snapshot(
         [
             ToolDefinition(
                 name='final_result_Foo',
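
Note: the `agent_model_function_tools` / `agent_model_allow_text_result` / `agent_model_result_tools` attributes are folded into a single `last_model_request_parameters` object with `function_tools`, `allow_text_result` and `result_tools` fields. A sketch of inspecting it after a run (assuming `m` is a `TestModel` from `pydantic_ai.models.test`, as elsewhere in this test suite):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    m = TestModel()
    agent = Agent(m)
    agent.run_sync('Hello')

    params = m.last_model_request_parameters  # replaces the old agent_model_* attributes
    assert params is not None
    print(params.allow_text_result)
    print(params.function_tools)
    print(params.result_tools)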

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_examples.py

@@ -17,6 +17,7 @@ from pytest_examples import CodeExample, EvalExample, find_examples
 from pytest_mock import MockerFixture
 
 from pydantic_ai._utils import group_by_temporal
+from pydantic_ai.exceptions import UnexpectedModelBehavior
 from pydantic_ai.messages import (
     ModelMessage,
     ModelResponse,
@@ -179,7 +180,7 @@ text_responses: dict[str, str | ToolCallPart] = {
         'The weather in West London is raining, while in Wiltshire it is sunny.'
     ),
     'Tell me a joke.': 'Did you hear about the toothpaste scandal? They called it Colgate.',
-    'Explain?': 'This is an excellent joke invent by Samuel Colvin, it needs no explanation.',
+    'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.',
     'What is the capital of France?': 'Paris',
     'What is the capital of Italy?': 'Rome',
     'What is the capital of the UK?': 'London',
@@ -288,6 +289,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
             )
         ]
     )
+    elif m.content.startswith('Write a list of 5 very rude things that I might say'):
+        raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
     elif m.content.startswith('<examples>\n <user>'):
         return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args={})])
     elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2:

{pydantic_ai-0.0.21 → pydantic_ai-0.0.23}/tests/test_logfire.py

@@ -75,14 +75,14 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
             'id': 0,
             'message': 'my_agent run prompt=Hello',
             'children': [
-                {'id': 1, 'message': 'preparing model and tools run_step=1'},
+                {'id': 1, 'message': 'preparing model request params run_step=1'},
                 {'id': 2, 'message': 'model request'},
                 {
                     'id': 3,
                     'message': 'handle model response -> tool-return',
                     'children': [{'id': 4, 'message': "running tools=['my_ret']"}],
                 },
-                {'id': 5, 'message': 'preparing model and tools run_step=2'},
+                {'id': 5, 'message': 'preparing model request params run_step=2'},
                 {'id': 6, 'message': 'model request'},
                 {'id': 7, 'message': 'handle model response -> final result'},
             ],
@@ -102,16 +102,14 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
                 'custom_result_text': None,
                 'custom_result_args': None,
                 'seed': 0,
-                'agent_model_function_tools': None,
-                'agent_model_allow_text_result': None,
-                'agent_model_result_tools': None,
+                'last_model_request_parameters': None,
             },
             'name': 'my_agent',
             'end_strategy': 'early',
             'model_settings': None,
         }
     ),
-    'model_name': 'test-model',
+    'model_name': 'test',
     'agent_name': 'my_agent',
     'logfire.msg_template': '{agent_name} run {prompt=}',
     'logfire.msg': 'my_agent run prompt=Hello',
@@ -255,9 +253,9 @@ def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
            'code.function': 'test_logfire',
            'code.lineno': IsInt(),
            'run_step': 1,
-           'logfire.msg_template': 'preparing model and tools {run_step=}',
+           'logfire.msg_template': 'preparing model request params {run_step=}',
            'logfire.span_type': 'span',
-           'logfire.msg': 'preparing model and tools run_step=1',
+           'logfire.msg': 'preparing model request params run_step=1',
            'logfire.json_schema': '{"type":"object","properties":{"run_step":{}}}',
        }
    )