pydantic-ai 0.0.40__tar.gz → 0.0.41__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai might be problematic. Click here for more details.

Files changed (96)
  1. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/PKG-INFO +3 -3
  2. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/pyproject.toml +3 -3
  3. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_anthropic.py +33 -16
  4. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_fallback.py +76 -0
  5. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_mistral.py +33 -26
  6. pydantic_ai-0.0.41/tests/providers/test_anthropic.py +56 -0
  7. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_groq.py +0 -9
  8. pydantic_ai-0.0.41/tests/providers/test_mistral.py +58 -0
  9. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_provider_names.py +5 -1
  10. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_live.py +4 -2
  11. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/.gitignore +0 -0
  12. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/LICENSE +0 -0
  13. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/Makefile +0 -0
  14. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/README.md +0 -0
  15. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/__init__.py +0 -0
  16. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/assets/dummy.pdf +0 -0
  17. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/assets/kiwi.png +0 -0
  18. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/assets/marcelo.mp3 +0 -0
  19. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/conftest.py +0 -0
  20. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/example_modules/README.md +0 -0
  21. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/example_modules/bank_database.py +0 -0
  22. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/example_modules/fake_database.py +0 -0
  23. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/example_modules/weather_service.py +0 -0
  24. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/__init__.py +0 -0
  25. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/test_file_persistence.py +0 -0
  26. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/test_graph.py +0 -0
  27. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/test_mermaid.py +0 -0
  28. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/test_persistence.py +0 -0
  29. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/test_state.py +0 -0
  30. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/graph/test_utils.py +0 -0
  31. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/import_examples.py +0 -0
  32. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/json_body_serializer.py +0 -0
  33. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/__init__.py +0 -0
  34. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_anthropic/test_document_binary_content_input.yaml +0 -0
  35. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_anthropic/test_document_url_input.yaml +0 -0
  36. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_anthropic/test_image_url_input.yaml +0 -0
  37. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_anthropic/test_image_url_input_invalid_mime_type.yaml +0 -0
  38. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_anthropic/test_multiple_parallel_tool_calls.yaml +0 -0
  39. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_anthropic/test_text_document_url_input.yaml +0 -0
  40. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model.yaml +0 -0
  41. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_anthropic_model_without_tools.yaml +0 -0
  42. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_iter_stream.yaml +0 -0
  43. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_max_tokens.yaml +0 -0
  44. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_retry.yaml +0 -0
  45. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_stream.yaml +0 -0
  46. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_structured_response.yaml +0 -0
  47. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_bedrock_model_top_p.yaml +0 -0
  48. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_document_url_input.yaml +0 -0
  49. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_image_as_binary_content_input.yaml +0 -0
  50. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_image_url_input.yaml +0 -0
  51. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_text_as_binary_content_input.yaml +0 -0
  52. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_bedrock/test_text_document_url_input.yaml +0 -0
  53. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_gemini/test_document_url_input.yaml +0 -0
  54. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_gemini/test_image_as_binary_content_input.yaml +0 -0
  55. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_gemini/test_image_url_input.yaml +0 -0
  56. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_groq/test_image_as_binary_content_input.yaml +0 -0
  57. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_groq/test_image_url_input.yaml +0 -0
  58. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_openai/test_audio_as_binary_content_input.yaml +0 -0
  59. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_openai/test_document_url_input.yaml +0 -0
  60. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_openai/test_image_as_binary_content_input.yaml +0 -0
  61. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_openai/test_openai_o1_mini_system_role[developer].yaml +0 -0
  62. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/cassettes/test_openai/test_openai_o1_mini_system_role[system].yaml +0 -0
  63. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/mock_async_stream.py +0 -0
  64. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_bedrock.py +0 -0
  65. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_cohere.py +0 -0
  66. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_gemini.py +0 -0
  67. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_groq.py +0 -0
  68. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_instrumented.py +0 -0
  69. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_model.py +0 -0
  70. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_model_function.py +0 -0
  71. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_model_names.py +0 -0
  72. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_model_test.py +0 -0
  73. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_openai.py +0 -0
  74. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/models/test_vertexai.py +0 -0
  75. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/__init__.py +0 -0
  76. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/cassettes/test_azure/test_azure_provider_call.yaml +0 -0
  77. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_azure.py +0 -0
  78. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_bedrock.py +0 -0
  79. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_deepseek.py +0 -0
  80. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_google_gla.py +0 -0
  81. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/providers/test_google_vertex.py +0 -0
  82. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_agent.py +0 -0
  83. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_cli.py +0 -0
  84. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_deps.py +0 -0
  85. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_examples.py +0 -0
  86. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_format_as_xml.py +0 -0
  87. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_json_body_serializer.py +0 -0
  88. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_logfire.py +0 -0
  89. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_messages.py +0 -0
  90. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_parts_manager.py +0 -0
  91. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_streaming.py +0 -0
  92. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_tools.py +0 -0
  93. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_usage_limits.py +0 -0
  94. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/test_utils.py +0 -0
  95. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/typed_agent.py +0 -0
  96. {pydantic_ai-0.0.40 → pydantic_ai-0.0.41}/tests/typed_graph.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai
3
- Version: 0.0.40
3
+ Version: 0.0.41
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs
5
5
  Project-URL: Homepage, https://ai.pydantic.dev
6
6
  Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -28,9 +28,9 @@ Classifier: Topic :: Internet
28
28
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
29
29
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
30
30
  Requires-Python: >=3.9
31
- Requires-Dist: pydantic-ai-slim[anthropic,bedrock,cli,cohere,groq,mistral,openai,vertexai]==0.0.40
31
+ Requires-Dist: pydantic-ai-slim[anthropic,bedrock,cli,cohere,groq,mistral,openai,vertexai]==0.0.41
32
32
  Provides-Extra: examples
33
- Requires-Dist: pydantic-ai-examples==0.0.40; extra == 'examples'
33
+ Requires-Dist: pydantic-ai-examples==0.0.41; extra == 'examples'
34
34
  Provides-Extra: logfire
35
35
  Requires-Dist: logfire>=2.3; extra == 'logfire'
36
36
  Description-Content-Type: text/markdown
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "pydantic-ai"
7
- version = "0.0.40"
7
+ version = "0.0.41"
8
8
  description = "Agent Framework / shim to use Pydantic with LLMs"
9
9
  authors = [
10
10
  { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -36,7 +36,7 @@ classifiers = [
36
36
  ]
37
37
  requires-python = ">=3.9"
38
38
  dependencies = [
39
- "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere,bedrock,cli]==0.0.40",
39
+ "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere,bedrock,cli]==0.0.41",
40
40
  ]
41
41
 
42
42
  [project.urls]
@@ -46,7 +46,7 @@ Documentation = "https://ai.pydantic.dev"
46
46
  Changelog = "https://github.com/pydantic/pydantic-ai/releases"
47
47
 
48
48
  [project.optional-dependencies]
49
- examples = ["pydantic-ai-examples==0.0.40"]
49
+ examples = ["pydantic-ai-examples==0.0.41"]
50
50
  logfire = ["logfire>=2.3"]
51
51
 
52
52
  [tool.uv.sources]
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
7
7
  from datetime import timezone
8
8
  from functools import cached_property
9
9
  from typing import Any, TypeVar, Union, cast
10
+ from unittest.mock import patch
10
11
 
11
12
  import httpx
12
13
  import pytest
@@ -53,6 +54,7 @@ with try_import() as imports_successful:
53
54
  from anthropic.types.raw_message_delta_event import Delta
54
55
 
55
56
  from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
57
+ from pydantic_ai.providers.anthropic import AnthropicProvider
56
58
 
57
59
  # note: we use Union here so that casting works with Python 3.9
58
60
  MockAnthropicMessage = Union[AnthropicMessage, Exception]
@@ -68,7 +70,7 @@ T = TypeVar('T')
68
70
 
69
71
 
70
72
  def test_init():
71
- m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
73
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
72
74
  assert m.client.api_key == 'foobar'
73
75
  assert m.model_name == 'claude-3-5-haiku-latest'
74
76
  assert m.system == 'anthropic'
@@ -81,6 +83,7 @@ class MockAnthropic:
81
83
  stream: Sequence[MockRawMessageStreamEvent] | Sequence[Sequence[MockRawMessageStreamEvent]] | None = None
82
84
  index = 0
83
85
  chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
86
+ base_url: str | None = None
84
87
 
85
88
  @cached_property
86
89
  def messages(self) -> Any:
@@ -134,7 +137,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
134
137
  async def test_sync_request_text_response(allow_model_requests: None):
135
138
  c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
136
139
  mock_client = MockAnthropic.create_mock(c)
137
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
140
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
138
141
  agent = Agent(m)
139
142
 
140
143
  result = await agent.run('hello')
@@ -171,7 +174,7 @@ async def test_async_request_text_response(allow_model_requests: None):
171
174
  usage=AnthropicUsage(input_tokens=3, output_tokens=5),
172
175
  )
173
176
  mock_client = MockAnthropic.create_mock(c)
174
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
177
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
175
178
  agent = Agent(m)
176
179
 
177
180
  result = await agent.run('hello')
@@ -185,7 +188,7 @@ async def test_request_structured_response(allow_model_requests: None):
185
188
  usage=AnthropicUsage(input_tokens=3, output_tokens=5),
186
189
  )
187
190
  mock_client = MockAnthropic.create_mock(c)
188
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
191
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
189
192
  agent = Agent(m, result_type=list[int])
190
193
 
191
194
  result = await agent.run('hello')
@@ -235,7 +238,7 @@ async def test_request_tool_call(allow_model_requests: None):
235
238
  ]
236
239
 
237
240
  mock_client = MockAnthropic.create_mock(responses)
238
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
241
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
239
242
  agent = Agent(m, system_prompt='this is the system prompt')
240
243
 
241
244
  @agent.tool_plain
@@ -327,7 +330,7 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
327
330
  ]
328
331
 
329
332
  mock_client = MockAnthropic.create_mock(responses)
330
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
333
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
331
334
  agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
332
335
 
333
336
  @agent.tool_plain
@@ -366,7 +369,7 @@ async def test_multiple_parallel_tool_calls(allow_model_requests: None):
366
369
  # However, we do want to use the environment variable if present when rewriting VCR cassettes.
367
370
  api_key = os.environ.get('ANTHROPIC_API_KEY', 'mock-value')
368
371
  agent = Agent(
369
- AnthropicModel('claude-3-5-haiku-latest', api_key=api_key),
372
+ AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=api_key)),
370
373
  system_prompt=system_prompt,
371
374
  tools=[retrieve_entity_info],
372
375
  )
@@ -436,7 +439,7 @@ async def test_multiple_parallel_tool_calls(allow_model_requests: None):
436
439
  async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
437
440
  c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
438
441
  mock_client = MockAnthropic.create_mock(c)
439
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
442
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
440
443
  agent = Agent(m)
441
444
 
442
445
  result = await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}))
@@ -525,7 +528,7 @@ async def test_stream_structured(allow_model_requests: None):
525
528
  ]
526
529
 
527
530
  mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
528
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
531
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
529
532
  agent = Agent(m)
530
533
 
531
534
  tool_called = False
@@ -555,7 +558,7 @@ async def test_stream_structured(allow_model_requests: None):
555
558
 
556
559
  @pytest.mark.vcr()
557
560
  async def test_image_url_input(allow_model_requests: None, anthropic_api_key: str):
558
- m = AnthropicModel('claude-3-5-haiku-latest', api_key=anthropic_api_key)
561
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
559
562
  agent = Agent(m)
560
563
 
561
564
  result = await agent.run(
@@ -573,7 +576,7 @@ Potatoes are root vegetables that are staple foods in many cuisines around the w
573
576
 
574
577
  @pytest.mark.vcr()
575
578
  async def test_image_url_input_invalid_mime_type(allow_model_requests: None, anthropic_api_key: str):
576
- m = AnthropicModel('claude-3-5-haiku-latest', api_key=anthropic_api_key)
579
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
577
580
  agent = Agent(m)
578
581
 
579
582
  result = await agent.run(
@@ -593,7 +596,7 @@ async def test_image_url_input_invalid_mime_type(allow_model_requests: None, ant
593
596
  async def test_audio_as_binary_content_input(allow_model_requests: None, media_type: str):
594
597
  c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
595
598
  mock_client = MockAnthropic.create_mock(c)
596
- m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
599
+ m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
597
600
  agent = Agent(m)
598
601
 
599
602
  base64_content = b'//uQZ'
@@ -610,7 +613,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
610
613
  body={'error': 'test error'},
611
614
  )
612
615
  )
613
- m = AnthropicModel('claude-3-5-sonnet-latest', anthropic_client=mock_client)
616
+ m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(anthropic_client=mock_client))
614
617
  agent = Agent(m)
615
618
  with pytest.raises(ModelHTTPError) as exc_info:
616
619
  agent.run_sync('hello')
@@ -623,7 +626,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
623
626
  async def test_document_binary_content_input(
624
627
  allow_model_requests: None, anthropic_api_key: str, document_content: BinaryContent
625
628
  ):
626
- m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
629
+ m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
627
630
  agent = Agent(m)
628
631
 
629
632
  result = await agent.run(['What is the main content on this document?', document_content])
@@ -634,7 +637,7 @@ async def test_document_binary_content_input(
634
637
 
635
638
  @pytest.mark.vcr()
636
639
  async def test_document_url_input(allow_model_requests: None, anthropic_api_key: str):
637
- m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
640
+ m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
638
641
  agent = Agent(m)
639
642
 
640
643
  document_url = DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')
@@ -647,7 +650,7 @@ async def test_document_url_input(allow_model_requests: None, anthropic_api_key:
647
650
 
648
651
  @pytest.mark.vcr()
649
652
  async def test_text_document_url_input(allow_model_requests: None, anthropic_api_key: str):
650
- m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
653
+ m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
651
654
  agent = Agent(m)
652
655
 
653
656
  text_document_url = DocumentUrl(url='https://example-files.online-convert.com/document/txt/example.txt')
@@ -668,3 +671,17 @@ This document is a TXT test file that primarily contains information about the u
668
671
 
669
672
  The document is formatted as a test file with metadata including its purpose, file type, and version. It also includes attribution information indicating the content is from Wikipedia and is licensed under Attribution-ShareAlike 4.0.\
670
673
  """)
674
+
675
+
676
+ def test_init_with_provider():
677
+ provider = AnthropicProvider(api_key='api-key')
678
+ model = AnthropicModel('claude-3-opus-latest', provider=provider)
679
+ assert model.model_name == 'claude-3-opus-latest'
680
+ assert model.client == provider.client
681
+
682
+
683
+ def test_init_with_provider_string():
684
+ with patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env-api-key'}, clear=False):
685
+ model = AnthropicModel('claude-3-opus-latest', provider='anthropic')
686
+ assert model.model_name == 'claude-3-opus-latest'
687
+ assert model.client is not None
@@ -185,6 +185,82 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None:
185
185
  )
186
186
 
187
187
 
188
+ @pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed')
189
+ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None:
190
+ fallback_model = FallbackModel(failure_model_stream, success_model_stream)
191
+ agent = Agent(model=fallback_model, instrument=True)
192
+ async with agent.run_stream('input') as result:
193
+ assert [c async for c, _is_last in result.stream_structured(debounce_by=None)] == snapshot(
194
+ [
195
+ ModelResponse(
196
+ parts=[TextPart(content='hello ')],
197
+ model_name='function::success_response_stream',
198
+ timestamp=IsNow(tz=timezone.utc),
199
+ ),
200
+ ModelResponse(
201
+ parts=[TextPart(content='hello world')],
202
+ model_name='function::success_response_stream',
203
+ timestamp=IsNow(tz=timezone.utc),
204
+ ),
205
+ ModelResponse(
206
+ parts=[TextPart(content='hello world')],
207
+ model_name='function::success_response_stream',
208
+ timestamp=IsNow(tz=timezone.utc),
209
+ ),
210
+ ]
211
+ )
212
+ assert result.is_complete
213
+
214
+ assert capfire.exporter.exported_spans_as_dict() == snapshot(
215
+ [
216
+ {
217
+ 'name': 'preparing model request params',
218
+ 'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
219
+ 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
220
+ 'start_time': 2000000000,
221
+ 'end_time': 3000000000,
222
+ 'attributes': {
223
+ 'run_step': 1,
224
+ 'logfire.span_type': 'span',
225
+ 'logfire.msg': 'preparing model request params',
226
+ },
227
+ },
228
+ {
229
+ 'name': 'chat function::success_response_stream',
230
+ 'context': {'trace_id': 1, 'span_id': 5, 'is_remote': False},
231
+ 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
232
+ 'start_time': 4000000000,
233
+ 'end_time': 5000000000,
234
+ 'attributes': {
235
+ 'gen_ai.operation.name': 'chat',
236
+ 'logfire.span_type': 'span',
237
+ 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream',
238
+ 'gen_ai.system': 'function',
239
+ 'gen_ai.request.model': 'function::success_response_stream',
240
+ 'gen_ai.usage.input_tokens': 50,
241
+ 'gen_ai.usage.output_tokens': 2,
242
+ 'gen_ai.response.model': 'function::success_response_stream',
243
+ 'events': '[{"content": "input", "role": "user", "gen_ai.system": "function", "gen_ai.message.index": 0, "event.name": "gen_ai.user.message"}, {"index": 0, "message": {"role": "assistant", "content": "hello world"}, "gen_ai.system": "function", "event.name": "gen_ai.choice"}]',
244
+ 'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}}}',
245
+ },
246
+ },
247
+ {
248
+ 'name': 'agent run',
249
+ 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
250
+ 'parent': None,
251
+ 'start_time': 1000000000,
252
+ 'end_time': 6000000000,
253
+ 'attributes': {
254
+ 'model_name': 'fallback:function::failure_response_stream,function::success_response_stream',
255
+ 'agent_name': 'agent',
256
+ 'logfire.msg': 'agent run',
257
+ 'logfire.span_type': 'span',
258
+ },
259
+ },
260
+ ]
261
+ )
262
+
263
+
188
264
  def test_all_failed() -> None:
189
265
  fallback_model = FallbackModel(failure_model, failure_model)
190
266
  agent = Agent(model=fallback_model)
@@ -56,6 +56,7 @@ with try_import() as imports_successful:
56
56
  MistralModel,
57
57
  MistralStreamedResponse,
58
58
  )
59
+ from pydantic_ai.providers.mistral import MistralProvider
59
60
 
60
61
  # note: we use Union here so that casting works with Python 3.9
61
62
  MockChatCompletion = Union[MistralChatCompletionResponse, Exception]
@@ -98,13 +99,13 @@ class MockMistralAI:
98
99
  self, *_args: Any, stream: bool = False, **_kwargs: Any
99
100
  ) -> MistralChatCompletionResponse | MockAsyncStream[MockCompletionEvent]:
100
101
  if stream or self.stream:
101
- assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided'
102
+ assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
102
103
  if isinstance(self.stream[0], list):
103
104
  response = MockAsyncStream(iter(cast(list[MockCompletionEvent], self.stream[self.index])))
104
105
  else:
105
106
  response = MockAsyncStream(iter(cast(list[MockCompletionEvent], self.stream)))
106
107
  else:
107
- assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided'
108
+ assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided'
108
109
  if isinstance(self.completions, Sequence):
109
110
  raise_if_exception(self.completions[self.index])
110
111
  response = cast(MistralChatCompletionResponse, self.completions[self.index])
@@ -173,8 +174,14 @@ def func_chunk(
173
174
  #####################
174
175
 
175
176
 
177
+ def test_init_deprecated():
178
+ m = MistralModel('mistral-large-latest', api_key='foobar') # pyright: ignore[reportDeprecated]
179
+ assert m.model_name == 'mistral-large-latest'
180
+ assert m.base_url == 'https://api.mistral.ai'
181
+
182
+
176
183
  def test_init():
177
- m = MistralModel('mistral-large-latest', api_key='foobar')
184
+ m = MistralModel('mistral-large-latest', provider=MistralProvider(api_key='foobar'))
178
185
  assert m.model_name == 'mistral-large-latest'
179
186
  assert m.base_url == 'https://api.mistral.ai'
180
187
 
@@ -194,7 +201,7 @@ async def test_multiple_completions(allow_model_requests: None):
194
201
  completion_message(MistralAssistantMessage(content='hello again')),
195
202
  ]
196
203
  mock_client = MockMistralAI.create_mock(completions)
197
- model = MistralModel('mistral-large-latest', client=mock_client)
204
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
198
205
  agent = Agent(model=model)
199
206
 
200
207
  result = await agent.run('hello')
@@ -237,7 +244,7 @@ async def test_three_completions(allow_model_requests: None):
237
244
  completion_message(MistralAssistantMessage(content='final message')),
238
245
  ]
239
246
  mock_client = MockMistralAI.create_mock(completions)
240
- model = MistralModel('mistral-large-latest', client=mock_client)
247
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
241
248
  agent = Agent(model=model)
242
249
 
243
250
  result = await agent.run('hello')
@@ -296,7 +303,7 @@ async def test_stream_text(allow_model_requests: None):
296
303
  chunk([]),
297
304
  ]
298
305
  mock_client = MockMistralAI.create_stream_mock(stream)
299
- model = MistralModel('mistral-large-latest', client=mock_client)
306
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
300
307
  agent = Agent(model=model)
301
308
 
302
309
  async with agent.run_stream('') as result:
@@ -317,7 +324,7 @@ async def test_stream_text_finish_reason(allow_model_requests: None):
317
324
  text_chunk('.', finish_reason='stop'),
318
325
  ]
319
326
  mock_client = MockMistralAI.create_stream_mock(stream)
320
- model = MistralModel('mistral-large-latest', client=mock_client)
327
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
321
328
  agent = Agent(model=model)
322
329
 
323
330
  async with agent.run_stream('') as result:
@@ -335,7 +342,7 @@ async def test_no_delta(allow_model_requests: None):
335
342
  text_chunk('world'),
336
343
  ]
337
344
  mock_client = MockMistralAI.create_stream_mock(stream)
338
- model = MistralModel('mistral-large-latest', client=mock_client)
345
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
339
346
  agent = Agent(model=model)
340
347
 
341
348
  async with agent.run_stream('') as result:
@@ -372,7 +379,7 @@ async def test_request_model_structured_with_arguments_dict_response(allow_model
372
379
  usage=MistralUsageInfo(prompt_tokens=1, completion_tokens=2, total_tokens=3),
373
380
  )
374
381
  mock_client = MockMistralAI.create_mock(completion)
375
- model = MistralModel('mistral-large-latest', client=mock_client)
382
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
376
383
  agent = Agent(model=model, result_type=CityLocation)
377
384
 
378
385
  result = await agent.run('User prompt value')
@@ -430,7 +437,7 @@ async def test_request_model_structured_with_arguments_str_response(allow_model_
430
437
  )
431
438
  )
432
439
  mock_client = MockMistralAI.create_mock(completion)
433
- model = MistralModel('mistral-large-latest', client=mock_client)
440
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
434
441
  agent = Agent(model=model, result_type=CityLocation)
435
442
 
436
443
  result = await agent.run('User prompt value')
@@ -483,7 +490,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
483
490
  )
484
491
  )
485
492
  mock_client = MockMistralAI.create_mock(completion)
486
- model = MistralModel('mistral-large-latest', client=mock_client)
493
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
487
494
  agent = Agent(model=model, result_type=int, system_prompt='System prompt value')
488
495
 
489
496
  result = await agent.run('User prompt value')
@@ -568,7 +575,7 @@ async def test_stream_structured_with_all_type(allow_model_requests: None):
568
575
  ]
569
576
 
570
577
  mock_client = MockMistralAI.create_stream_mock(stream)
571
- model = MistralModel('mistral-large-latest', client=mock_client)
578
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
572
579
  agent = Agent(model, result_type=MyTypedDict)
573
580
 
574
581
  async with agent.run_stream('User prompt value') as result:
@@ -678,7 +685,7 @@ async def test_stream_result_type_primitif_dict(allow_model_requests: None):
678
685
  ]
679
686
 
680
687
  mock_client = MockMistralAI.create_stream_mock(stream)
681
- model = MistralModel('mistral-large-latest', client=mock_client)
688
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
682
689
  agent = Agent(model=model, result_type=MyTypedDict)
683
690
 
684
691
  async with agent.run_stream('User prompt value') as result:
@@ -734,7 +741,7 @@ async def test_stream_result_type_primitif_int(allow_model_requests: None):
734
741
  ]
735
742
 
736
743
  mock_client = MockMistralAI.create_stream_mock(stream)
737
- model = MistralModel('mistral-large-latest', client=mock_client)
744
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
738
745
  agent = Agent(model=model, result_type=int)
739
746
 
740
747
  async with agent.run_stream('User prompt value') as result:
@@ -793,7 +800,7 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
793
800
  ]
794
801
 
795
802
  mock_client = MockMistralAI.create_stream_mock(stream)
796
- model = MistralModel('mistral-large-latest', client=mock_client)
803
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
797
804
  agent = Agent(model, result_type=list[str])
798
805
 
799
806
  async with agent.run_stream('User prompt value') as result:
@@ -886,7 +893,7 @@ async def test_stream_result_type_basemodel_with_default_params(allow_model_requ
886
893
  ]
887
894
 
888
895
  mock_client = MockMistralAI.create_stream_mock(stream)
889
- model = MistralModel('mistral-large-latest', client=mock_client)
896
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
890
897
  agent = Agent(model=model, result_type=MyTypedBaseModel)
891
898
 
892
899
  async with agent.run_stream('User prompt value') as result:
@@ -971,7 +978,7 @@ async def test_stream_result_type_basemodel_with_required_params(allow_model_req
971
978
  ]
972
979
 
973
980
  mock_client = MockMistralAI.create_stream_mock(stream)
974
- model = MistralModel('mistral-large-latest', client=mock_client)
981
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
975
982
  agent = Agent(model=model, result_type=MyTypedBaseModel)
976
983
 
977
984
  async with agent.run_stream('User prompt value') as result:
@@ -1043,7 +1050,7 @@ async def test_request_tool_call(allow_model_requests: None):
1043
1050
  completion_message(MistralAssistantMessage(content='final response', role='assistant')),
1044
1051
  ]
1045
1052
  mock_client = MockMistralAI.create_mock(completion)
1046
- model = MistralModel('mistral-large-latest', client=mock_client)
1053
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1047
1054
  agent = Agent(model, system_prompt='this is the system prompt')
1048
1055
 
1049
1056
  @agent.tool_plain
@@ -1180,7 +1187,7 @@ async def test_request_tool_call_with_result_type(allow_model_requests: None):
1180
1187
  ),
1181
1188
  ]
1182
1189
  mock_client = MockMistralAI.create_mock(completion)
1183
- model = MistralModel('mistral-large-latest', client=mock_client)
1190
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1184
1191
  agent = Agent(model, system_prompt='this is the system prompt', result_type=MyTypedDict)
1185
1192
 
1186
1193
  @agent.tool_plain
@@ -1316,7 +1323,7 @@ async def test_stream_tool_call_with_return_type(allow_model_requests: None):
1316
1323
  ]
1317
1324
 
1318
1325
  mock_client = MockMistralAI.create_stream_mock(completion)
1319
- model = MistralModel('mistral-large-latest', client=mock_client)
1326
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1320
1327
  agent = Agent(model, system_prompt='this is the system prompt', result_type=MyTypedDict)
1321
1328
 
1322
1329
  @agent.tool_plain
@@ -1416,7 +1423,7 @@ async def test_stream_tool_call(allow_model_requests: None):
1416
1423
  ]
1417
1424
 
1418
1425
  mock_client = MockMistralAI.create_stream_mock(completion)
1419
- model = MistralModel('mistral-large-latest', client=mock_client)
1426
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1420
1427
  agent = Agent(model, system_prompt='this is the system prompt')
1421
1428
 
1422
1429
  @agent.tool_plain
@@ -1516,7 +1523,7 @@ async def test_stream_tool_call_with_retry(allow_model_requests: None):
1516
1523
  ]
1517
1524
 
1518
1525
  mock_client = MockMistralAI.create_stream_mock(completion)
1519
- model = MistralModel('mistral-large-latest', client=mock_client)
1526
+ model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1520
1527
  agent = Agent(model, system_prompt='this is the system prompt')
1521
1528
 
1522
1529
  @agent.tool_plain
@@ -1743,7 +1750,7 @@ def test_validate_required_json_schema(desc: str, schema: dict[str, Any], data:
1743
1750
  async def test_image_url_input(allow_model_requests: None):
1744
1751
  c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
1745
1752
  mock_client = MockMistralAI.create_mock(c)
1746
- m = MistralModel('mistral-large-latest', client=mock_client)
1753
+ m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1747
1754
  agent = Agent(m)
1748
1755
 
1749
1756
  result = await agent.run(
@@ -1779,7 +1786,7 @@ async def test_image_url_input(allow_model_requests: None):
1779
1786
  async def test_image_as_binary_content_input(allow_model_requests: None):
1780
1787
  c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
1781
1788
  mock_client = MockMistralAI.create_mock(c)
1782
- m = MistralModel('mistral-large-latest', client=mock_client)
1789
+ m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1783
1790
  agent = Agent(m)
1784
1791
 
1785
1792
  base64_content = (
@@ -1810,7 +1817,7 @@ async def test_image_as_binary_content_input(allow_model_requests: None):
1810
1817
  async def test_audio_as_binary_content_input(allow_model_requests: None):
1811
1818
  c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
1812
1819
  mock_client = MockMistralAI.create_mock(c)
1813
- m = MistralModel('mistral-large-latest', client=mock_client)
1820
+ m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1814
1821
  agent = Agent(m)
1815
1822
 
1816
1823
  base64_content = b'//uQZ'
@@ -1827,7 +1834,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
1827
1834
  body='test error',
1828
1835
  )
1829
1836
  )
1830
- m = MistralModel('mistral-large-latest', client=mock_client)
1837
+ m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
1831
1838
  agent = Agent(m)
1832
1839
  with pytest.raises(ModelHTTPError) as exc_info:
1833
1840
  agent.run_sync('hello')
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from unittest.mock import patch
5
+
6
+ import httpx
7
+ import pytest
8
+
9
+ from ..conftest import try_import
10
+
11
+ with try_import() as imports_successful:
12
+ from anthropic import AsyncAnthropic
13
+
14
+ from pydantic_ai.providers.anthropic import AnthropicProvider
15
+
16
+
17
+ pytestmark = pytest.mark.skipif(not imports_successful(), reason='need to install anthropic')
18
+
19
+
20
+ def test_anthropic_provider():
21
+ provider = AnthropicProvider(api_key='api-key')
22
+ assert provider.name == 'anthropic'
23
+ assert provider.base_url == 'https://api.anthropic.com'
24
+ assert isinstance(provider.client, AsyncAnthropic)
25
+ assert provider.client.api_key == 'api-key'
26
+
27
+
28
+ def test_anthropic_provider_need_api_key() -> None:
29
+ with patch.dict(os.environ, {}, clear=True):
30
+ with pytest.raises(
31
+ ValueError,
32
+ match=r'.*ANTHROPIC_API_KEY.*',
33
+ ):
34
+ AnthropicProvider()
35
+
36
+
37
+ def test_anthropic_provider_pass_http_client() -> None:
38
+ http_client = httpx.AsyncClient()
39
+ provider = AnthropicProvider(http_client=http_client, api_key='api-key')
40
+ assert isinstance(provider.client, AsyncAnthropic)
41
+ # Verify the http_client is being used by the AsyncAnthropic client
42
+ assert provider.client._client == http_client # type: ignore[reportPrivateUsage]
43
+
44
+
45
+ def test_anthropic_provider_pass_anthropic_client() -> None:
46
+ anthropic_client = AsyncAnthropic(api_key='api-key')
47
+ provider = AnthropicProvider(anthropic_client=anthropic_client)
48
+ assert provider.client == anthropic_client
49
+
50
+
51
+ def test_anthropic_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> None:
52
+ # Test with environment variable for base_url
53
+ custom_base_url = 'https://custom.anthropic.com/v1'
54
+ monkeypatch.setenv('ANTHROPIC_BASE_URL', custom_base_url)
55
+ provider = AnthropicProvider(api_key='api-key')
56
+ assert provider.base_url.rstrip('/') == custom_base_url.rstrip('/')
@@ -55,12 +55,3 @@ def test_groq_provider_with_env_base_url(monkeypatch: pytest.MonkeyPatch) -> Non
55
55
  monkeypatch.setenv('GROQ_BASE_URL', 'https://custom.groq.com/v1')
56
56
  provider = GroqProvider(api_key='api-key')
57
57
  assert provider.base_url == 'https://custom.groq.com/v1'
58
-
59
-
60
- def test_infer_groq_provider():
61
- with patch.dict(os.environ, {'GROQ_API_KEY': 'test-api-key'}, clear=False):
62
- from pydantic_ai.providers import infer_provider
63
-
64
- provider = infer_provider('groq')
65
- assert provider.name == 'groq'
66
- assert isinstance(provider, GroqProvider)
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ import re
5
+ from unittest.mock import patch
6
+
7
+ import httpx
8
+ import pytest
9
+
10
+ from ..conftest import try_import
11
+
12
+ with try_import() as imports_successful:
13
+ from mistralai import Mistral
14
+
15
+ from pydantic_ai.providers.mistral import MistralProvider
16
+
17
+
18
+ pytestmark = pytest.mark.skipif(not imports_successful(), reason='mistral not installed')
19
+
20
+
21
+ def test_mistral_provider():
22
+ provider = MistralProvider(api_key='api-key')
23
+ assert provider.name == 'mistral'
24
+ assert provider.base_url == 'https://api.mistral.ai'
25
+ assert isinstance(provider.client, Mistral)
26
+ assert provider.client.sdk_configuration.security.api_key == 'api-key' # pyright: ignore
27
+
28
+
29
+ def test_mistral_provider_need_api_key() -> None:
30
+ with patch.dict(os.environ, {}, clear=True):
31
+ with pytest.raises(
32
+ ValueError,
33
+ match=re.escape(
34
+ 'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
35
+ 'to use the Mistral provider.'
36
+ ),
37
+ ):
38
+ MistralProvider()
39
+
40
+
41
+ def test_mistral_provider_pass_http_client() -> None:
42
+ http_client = httpx.AsyncClient()
43
+ provider = MistralProvider(http_client=http_client, api_key='api-key')
44
+ assert provider.client.sdk_configuration.async_client == http_client
45
+
46
+
47
+ def test_mistral_provider_pass_groq_client() -> None:
48
+ mistral_client = Mistral(api_key='api-key')
49
+ provider = MistralProvider(mistral_client=mistral_client)
50
+ assert provider.client == mistral_client
51
+
52
+
53
+ def test_mistral_provider_with_base_url() -> None:
54
+ # Test with environment variable for base_url
55
+ provider = MistralProvider(
56
+ mistral_client=Mistral(api_key='test-api-key', server_url='https://custom.mistral.com/v1'),
57
+ )
58
+ assert provider.base_url == 'https://custom.mistral.com/v1'
@@ -11,22 +11,26 @@ from pydantic_ai.providers import Provider, infer_provider
11
11
  from ..conftest import try_import
12
12
 
13
13
  with try_import() as imports_successful:
14
+ from pydantic_ai.providers.anthropic import AnthropicProvider
14
15
  from pydantic_ai.providers.deepseek import DeepSeekProvider
15
16
  from pydantic_ai.providers.google_gla import GoogleGLAProvider
16
17
  from pydantic_ai.providers.google_vertex import GoogleVertexProvider
17
18
  from pydantic_ai.providers.groq import GroqProvider
19
+ from pydantic_ai.providers.mistral import MistralProvider
18
20
  from pydantic_ai.providers.openai import OpenAIProvider
19
21
 
20
22
  test_infer_provider_params = [
23
+ ('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'),
21
24
  ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'),
22
25
  ('openai', OpenAIProvider, None),
23
26
  ('google-vertex', GoogleVertexProvider, None),
24
27
  ('google-gla', GoogleGLAProvider, 'GEMINI_API_KEY'),
25
28
  ('groq', GroqProvider, 'GROQ_API_KEY'),
29
+ ('mistral', MistralProvider, 'MISTRAL_API_KEY'),
26
30
  ]
27
31
 
28
32
  if not imports_successful():
29
- test_infer_provider_params = []
33
+ test_infer_provider_params = [] # pragma: no cover
30
34
 
31
35
  pytestmark = pytest.mark.skipif(not imports_successful(), reason='need to install all extra packages')
32
36
 
@@ -57,8 +57,9 @@ def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
57
57
 
58
58
  def anthropic(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
59
59
  from pydantic_ai.models.anthropic import AnthropicModel
60
+ from pydantic_ai.providers.anthropic import AnthropicProvider
60
61
 
61
- return AnthropicModel('claude-3-5-sonnet-latest', http_client=http_client)
62
+ return AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(http_client=http_client))
62
63
 
63
64
 
64
65
  def ollama(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
@@ -72,8 +73,9 @@ def ollama(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
72
73
 
73
74
  def mistral(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
74
75
  from pydantic_ai.models.mistral import MistralModel
76
+ from pydantic_ai.providers.mistral import MistralProvider
75
77
 
76
- return MistralModel('mistral-small-latest', http_client=http_client)
78
+ return MistralModel('mistral-small-latest', provider=MistralProvider(http_client=http_client))
77
79
 
78
80
 
79
81
  # TODO(Marcelo): We've surpassed the limit of our API key on Cohere.
File without changes
File without changes
File without changes
File without changes