pydantic-ai 0.0.20__tar.gz → 0.0.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai has been flagged as possibly problematic.
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/PKG-INFO +5 -5
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/README.md +2 -2
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/pyproject.toml +11 -3
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_anthropic.py +15 -5
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_cohere.py +3 -3
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_gemini.py +54 -27
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_groq.py +25 -26
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_mistral.py +13 -17
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_model.py +1 -2
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_model_function.py +12 -12
- pydantic_ai-0.0.21/tests/models/test_model_names.py +50 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_model_test.py +13 -2
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_openai.py +3 -3
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_agent.py +51 -53
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_examples.py +47 -54
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_live.py +12 -5
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_logfire.py +1 -13
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_parts_manager.py +26 -46
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_streaming.py +22 -24
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_tools.py +9 -10
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_usage_limits.py +1 -2
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_utils.py +1 -8
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/typed_graph.py +2 -2
- pydantic_ai-0.0.20/tests/models/test_ollama.py +0 -70
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/.gitignore +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/LICENSE +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/Makefile +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/__init__.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/conftest.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/example_modules/README.md +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/example_modules/bank_database.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/example_modules/fake_database.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/example_modules/weather_service.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/graph/__init__.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/graph/test_graph.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/graph/test_history.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/graph/test_mermaid.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/graph/test_state.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/import_examples.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/__init__.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/mock_async_stream.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_vertexai.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_deps.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/test_format_as_xml.py +0 -0
- {pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/typed_agent.py +0 -0
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai
-Version: 0.0.20
+Version: 0.0.21
 Summary: Agent Framework / shim to use Pydantic with LLMs
 Project-URL: Homepage, https://ai.pydantic.dev
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -32,9 +32,9 @@ Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: pydantic-ai-slim[anthropic,cohere,graph,groq,mistral,openai,vertexai]==0.0.20
+Requires-Dist: pydantic-ai-slim[anthropic,cohere,graph,groq,mistral,openai,vertexai]==0.0.21
 Provides-Extra: examples
-Requires-Dist: pydantic-ai-examples==0.0.20; extra == 'examples'
+Requires-Dist: pydantic-ai-examples==0.0.21; extra == 'examples'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Description-Content-Type: text/markdown
@@ -78,7 +78,7 @@ We built PydanticAI with one simple aim: to bring that FastAPI feeling to GenAI
 Built by the team behind [Pydantic](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more).

 * __Model-agnostic__
-  Supports OpenAI, Anthropic, Gemini, Ollama, Groq, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).
+  Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).

 * __Pydantic Logfire Integration__
   Seamlessly [integrates](https://ai.pydantic.dev/logfire/) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications.
@@ -116,7 +116,7 @@ from pydantic_ai import Agent

 # Define a very simple agent including the model to use, you can also set the model when running the agent.
 agent = Agent(
-    'gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash',
     # Register a static system prompt using a keyword argument to the agent.
     # For more complex dynamically-generated system prompts, see the example below.
     system_prompt='Be concise, reply with one sentence.',
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/README.md

@@ -37,7 +37,7 @@ We built PydanticAI with one simple aim: to bring that FastAPI feeling to GenAI
 Built by the team behind [Pydantic](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more).

 * __Model-agnostic__
-  Supports OpenAI, Anthropic, Gemini, Ollama, Groq, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).
+  Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).

 * __Pydantic Logfire Integration__
   Seamlessly [integrates](https://ai.pydantic.dev/logfire/) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications.
@@ -75,7 +75,7 @@ from pydantic_ai import Agent

 # Define a very simple agent including the model to use, you can also set the model when running the agent.
 agent = Agent(
-    'gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash',
     # Register a static system prompt using a keyword argument to the agent.
     # For more complex dynamically-generated system prompts, see the example below.
     system_prompt='Be concise, reply with one sentence.',
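The README change above switches the example to a provider-prefixed model string. A minimal sketch of the updated example in use; the prompt string and the printed attribute mirror the README excerpt above, and a Gemini API key in the environment is assumed:

```python
from pydantic_ai import Agent

# Provider-prefixed model name, as in the updated README excerpt above.
agent = Agent(
    'google-gla:gemini-1.5-flash',
    system_prompt='Be concise, reply with one sentence.',
)

# Illustrative call; assumes GEMINI_API_KEY (or equivalent) is configured.
result = agent.run_sync('Hello, world!')
print(result.data)
```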
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "pydantic-ai"
-version = "0.0.20"
+version = "0.0.21"
 description = "Agent Framework / shim to use Pydantic with LLMs"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -37,7 +37,7 @@ classifiers = [
 ]
 requires-python = ">=3.9"

-dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral,cohere]==0.0.20"]
+dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral,cohere]==0.0.21"]

 [project.urls]
 Homepage = "https://ai.pydantic.dev"
@@ -46,7 +46,7 @@ Documentation = "https://ai.pydantic.dev"
 Changelog = "https://github.com/pydantic/pydantic-ai/releases"

 [project.optional-dependencies]
-examples = ["pydantic-ai-examples==0.0.20"]
+examples = ["pydantic-ai-examples==0.0.21"]
 logfire = ["logfire>=2.3"]

 [tool.uv.sources]
@@ -183,3 +183,11 @@ ignore_no_config = true
 [tool.inline-snapshot.shortcuts]
 snap-fix=["create", "fix"]
 snap=["create"]
+
+[tool.codespell]
+# Ref: https://github.com/codespell-project/codespell#using-a-config-file
+skip = '.git*,*.svg,*.lock,*.css'
+check-hidden = true
+# Ignore "formatting" like **L**anguage
+ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
+# ignore-words-list = ''
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_anthropic.py

@@ -11,7 +11,6 @@ from inline_snapshot import snapshot

 from pydantic_ai import Agent, ModelRetry
 from pydantic_ai.messages import (
-    ArgsDict,
     ModelRequest,
     ModelResponse,
     RetryPromptPart,
@@ -47,7 +46,7 @@ with try_import() as imports_successful:
     )
     from anthropic.types.raw_message_delta_event import Delta

-    from pydantic_ai.models.anthropic import AnthropicModel
+    from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

 pytestmark = [
     pytest.mark.skipif(not imports_successful(), reason='anthropic not installed'),
@@ -186,7 +185,7 @@ async def test_request_structured_response(allow_model_requests: None):
             parts=[
                 ToolCallPart(
                     tool_name='final_result',
-                    args=
+                    args={'response': [1, 2, 3]},
                     tool_call_id='123',
                 )
             ],
@@ -248,7 +247,7 @@ async def test_request_tool_call(allow_model_requests: None):
             parts=[
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args={'loc_name': 'San Francisco'},
                     tool_call_id='1',
                 )
             ],
@@ -269,7 +268,7 @@
             parts=[
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args={'loc_name': 'London'},
                     tool_call_id='2',
                 )
             ],
@@ -332,6 +331,17 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
     )


+async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
+    c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
+    mock_client = MockAnthropic.create_mock(c)
+    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
+    agent = Agent(m)
+
+    result = await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}))
+    assert result.data == 'world'
+    assert get_mock_chat_completion_kwargs(mock_client)[0]['metadata']['user_id'] == '123'
+
+
 async def test_stream_structured(allow_model_requests: None):
     """Test streaming structured responses with Anthropic's API.

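The new test_anthropic_specific_metadata above exercises the AnthropicModelSettings added in this release. A hedged sketch of the same pattern outside the test harness; it assumes an Anthropic API key is available in the environment, and the metadata value is a placeholder:

```python
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

agent = Agent(AnthropicModel('claude-3-5-haiku-latest'))

# anthropic_metadata is forwarded to the Anthropic request, mirroring the test above.
result = agent.run_sync(
    'hello',
    model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}),
)
print(result.data)
```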
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_cohere.py

@@ -169,7 +169,7 @@ async def test_request_structured_response(allow_model_requests: None):
         ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[
-                ToolCallPart
+                ToolCallPart(
                     tool_name='final_result',
                     args='{"response": [1, 2, 123]}',
                     tool_call_id='123',
@@ -255,7 +255,7 @@ async def test_request_tool_call(allow_model_requests: None):
             ),
             ModelResponse(
                 parts=[
-                    ToolCallPart
+                    ToolCallPart(
                         tool_name='get_location',
                         args='{"loc_name": "San Fransisco"}',
                         tool_call_id='1',
@@ -276,7 +276,7 @@
             ),
             ModelResponse(
                 parts=[
-                    ToolCallPart
+                    ToolCallPart(
                         tool_name='get_location',
                         args='{"loc_name": "London"}',
                         tool_call_id='2',
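Across the test files in this diff, ToolCallPart is now constructed directly (the ArgsDict/ArgsJson imports are removed and the .from_raw_args calls are replaced), with args accepting either a dict or a JSON string. A small sketch of the constructor forms the tests use, with values taken from the tests in this diff:

```python
from pydantic_ai.messages import ModelResponse, ToolCallPart

response = ModelResponse(
    parts=[
        # args as a dict, positional form (as in the Gemini tests below)
        ToolCallPart('get_location', {'loc_name': 'London'}),
        # args as a JSON string, keyword form (as in the Cohere tests above)
        ToolCallPart(tool_name='final_result', args='{"response": [1, 2, 123]}', tool_call_id='123'),
    ]
)
```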
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_gemini.py

@@ -15,7 +15,6 @@ from typing_extensions import Literal, TypeAlias

 from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, UserError
 from pydantic_ai.messages import (
-    ArgsDict,
     ModelRequest,
     ModelResponse,
     RetryPromptPart,
@@ -474,9 +473,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):

 async def test_request_structured_response(get_gemini_client: GetGeminiClient):
     response = gemini_response(
-        _content_model_response(
-            ModelResponse(parts=[ToolCallPart.from_raw_args('final_result', {'response': [1, 2, 123]})])
-        )
+        _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2, 123]})]))
     )
     gemini_client = get_gemini_client(response)
     m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
@@ -491,7 +488,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
             parts=[
                 ToolCallPart(
                     tool_name='final_result',
-                    args=
+                    args={'response': [1, 2, 123]},
                 )
             ],
             model_name='gemini-1.5-flash',
@@ -511,16 +508,14 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
 async def test_request_tool_call(get_gemini_client: GetGeminiClient):
     responses = [
         gemini_response(
-            _content_model_response(
-                ModelResponse(parts=[ToolCallPart.from_raw_args('get_location', {'loc_name': 'San Fransisco'})])
-            )
+            _content_model_response(ModelResponse(parts=[ToolCallPart('get_location', {'loc_name': 'San Fransisco'})]))
         ),
         gemini_response(
             _content_model_response(
                 ModelResponse(
                     parts=[
-                        ToolCallPart
-                        ToolCallPart
+                        ToolCallPart('get_location', {'loc_name': 'London'}),
+                        ToolCallPart('get_location', {'loc_name': 'New York'}),
                     ]
                 )
             )
@@ -554,7 +549,7 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
             parts=[
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args={'loc_name': 'San Fransisco'},
                 )
             ],
             model_name='gemini-1.5-flash',
@@ -573,11 +568,11 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
             parts=[
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args={'loc_name': 'London'},
                 ),
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args={'loc_name': 'New York'},
                 ),
             ],
             model_name='gemini-1.5-flash',
@@ -664,9 +659,7 @@ async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
 async def test_stream_structured(get_gemini_client: GetGeminiClient):
     responses = [
         gemini_response(
-            _content_model_response(
-                ModelResponse(parts=[ToolCallPart.from_raw_args('final_result', {'response': [1, 2]})])
-            ),
+            _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2]})])),
         ),
     ]
     json_data = _gemini_streamed_response_ta.dump_json(responses, by_alias=True)
@@ -684,10 +677,10 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient):
 async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
     first_responses = [
         gemini_response(
-            _content_model_response(ModelResponse(parts=[ToolCallPart
+            _content_model_response(ModelResponse(parts=[ToolCallPart('foo', {'x': 'a'})])),
         ),
         gemini_response(
-            _content_model_response(ModelResponse(parts=[ToolCallPart
+            _content_model_response(ModelResponse(parts=[ToolCallPart('bar', {'y': 'b'})])),
         ),
     ]
     d1 = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True)
@@ -695,9 +688,7 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):

     second_responses = [
         gemini_response(
-            _content_model_response(
-                ModelResponse(parts=[ToolCallPart.from_raw_args('final_result', {'response': [1, 2]})])
-            ),
+            _content_model_response(ModelResponse(parts=[ToolCallPart('final_result', {'response': [1, 2]})])),
         ),
     ]
     d2 = _gemini_streamed_response_ta.dump_json(second_responses, by_alias=True)
@@ -727,8 +718,8 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
         ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[
-                ToolCallPart(tool_name='foo', args=
-                ToolCallPart(tool_name='bar', args=
+                ToolCallPart(tool_name='foo', args={'x': 'a'}),
+                ToolCallPart(tool_name='bar', args={'y': 'b'}),
             ],
             model_name='gemini-1.5-flash',
             timestamp=IsNow(tz=timezone.utc),
@@ -743,7 +734,7 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
             parts=[
                 ToolCallPart(
                     tool_name='final_result',
-                    args=
+                    args={'response': [1, 2]},
                 )
             ],
             model_name='gemini-1.5-flash',
@@ -772,7 +763,7 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
         _function_call_part_from_call(
             ToolCallPart(
                 tool_name='get_location',
-                args=
+                args={'loc_name': 'San Fransisco'},
             )
         ),
     ],
@@ -795,7 +786,7 @@ async def test_empty_text_ignored():
     content = _content_model_response(
         ModelResponse(
             parts=[
-                ToolCallPart
+                ToolCallPart('final_result', {'response': [1, 2, 123]}),
                 TextPart(content='xxx'),
             ]
         )
@@ -814,7 +805,7 @@ async def test_empty_text_ignored():
     content = _content_model_response(
         ModelResponse(
             parts=[
-                ToolCallPart
+                ToolCallPart('final_result', {'response': [1, 2, 123]}),
                 TextPart(content=''),
             ]
         )
@@ -826,3 +817,39 @@ async def test_empty_text_ignored():
         'parts': [{'function_call': {'name': 'final_result', 'args': {'response': [1, 2, 123]}}}],
     }
 )
+
+
+async def test_model_settings(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None) -> None:
+    def handler(request: httpx.Request) -> httpx.Response:
+        generation_config = json.loads(request.content)['generation_config']
+        assert generation_config == {
+            'max_output_tokens': 1,
+            'temperature': 0.1,
+            'top_p': 0.2,
+            'presence_penalty': 0.3,
+            'frequency_penalty': 0.4,
+        }
+        return httpx.Response(
+            200,
+            content=_gemini_response_ta.dump_json(
+                gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
+                by_alias=True,
+            ),
+            headers={'Content-Type': 'application/json'},
+        )
+
+    gemini_client = client_with_handler(handler)
+    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
+    agent = Agent(m)
+
+    result = await agent.run(
+        'hello',
+        model_settings={
+            'max_tokens': 1,
+            'temperature': 0.1,
+            'top_p': 0.2,
+            'presence_penalty': 0.3,
+            'frequency_penalty': 0.4,
+        },
+    )
+    assert result.data == 'world'
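The new test_model_settings above shows generic model settings being translated into Gemini's generation_config (max_tokens becomes max_output_tokens, and so on). A hedged sketch of the same call against a real model rather than the mocked HTTP client; the API key is a placeholder:

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel

agent = Agent(GeminiModel('gemini-1.5-flash', api_key='your-api-key'))

# These settings mirror the ones asserted in the test above.
result = agent.run_sync(
    'hello',
    model_settings={
        'max_tokens': 1,
        'temperature': 0.1,
        'top_p': 0.2,
        'presence_penalty': 0.3,
        'frequency_penalty': 0.4,
    },
)
print(result.data)
```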
{pydantic_ai-0.0.20 → pydantic_ai-0.0.21}/tests/models/test_groq.py

@@ -13,7 +13,6 @@ from typing_extensions import TypedDict

 from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior
 from pydantic_ai.messages import (
-    ArgsJson,
     ModelRequest,
     ModelResponse,
     RetryPromptPart,
@@ -51,9 +50,9 @@ pytestmark = [


 def test_init():
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', api_key='foobar')
     assert m.client.api_key == 'foobar'
-    assert m.name() == 'groq:llama-3.
+    assert m.name() == 'groq:llama-3.3-70b-versatile'


 @dataclass
@@ -103,7 +102,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
         id='123',
         choices=[Choice(finish_reason='stop', index=0, message=message)],
         created=1704067200,  # 2024-01-01
-        model='llama-3.
+        model='llama-3.3-70b-versatile',
         object='chat.completion',
         usage=usage,
     )
@@ -112,7 +111,7 @@
 async def test_request_simple_success(allow_model_requests: None):
     c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
     mock_client = MockGroq.create_mock(c)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m)

     result = await agent.run('hello')
@@ -130,13 +129,13 @@ async def test_request_simple_success(allow_model_requests: None):
         ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='world')],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
         ModelResponse(
             parts=[TextPart(content='world')],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]
@@ -149,7 +148,7 @@ async def test_request_simple_usage(allow_model_requests: None):
         usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
     )
     mock_client = MockGroq.create_mock(c)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m)

     result = await agent.run('Hello')
@@ -171,7 +170,7 @@ async def test_request_structured_response(allow_model_requests: None):
         )
     )
     mock_client = MockGroq.create_mock(c)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m, result_type=list[int])

     result = await agent.run('Hello')
@@ -183,11 +182,11 @@
             parts=[
                 ToolCallPart(
                     tool_name='final_result',
-                    args=
+                    args='{"response": [1, 2, 123]}',
                     tool_call_id='123',
                 )
             ],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -245,7 +244,7 @@ async def test_request_tool_call(allow_model_requests: None):
         completion_message(ChatCompletionMessage(content='final response', role='assistant')),
     ]
     mock_client = MockGroq.create_mock(responses)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m, system_prompt='this is the system prompt')

     @agent.tool_plain
@@ -269,11 +268,11 @@
             parts=[
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args='{"loc_name": "San Fransisco"}',
                     tool_call_id='1',
                 )
             ],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -290,11 +289,11 @@
             parts=[
                 ToolCallPart(
                     tool_name='get_location',
-                    args=
+                    args='{"loc_name": "London"}',
                     tool_call_id='2',
                 )
             ],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -309,7 +308,7 @@
         ),
         ModelResponse(
             parts=[TextPart(content='final response')],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
     ]
@@ -327,7 +326,7 @@ def chunk(delta: list[ChoiceDelta], finish_reason: FinishReason | None = None) -
         ],
         created=1704067200,  # 2024-01-01
         x_groq=None,
-        model='llama-3.
+        model='llama-3.3-70b-versatile',
         object='chat.completion.chunk',
         usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
     )
@@ -340,7 +339,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.Cha
 async def test_stream_text(allow_model_requests: None):
     stream = text_chunk('hello '), text_chunk('world'), chunk([])
     mock_client = MockGroq.create_mock_stream(stream)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m)

     async with agent.run_stream('') as result:
@@ -352,7 +351,7 @@ async def test_stream_text(allow_model_requests: None):
 async def test_stream_text_finish_reason(allow_model_requests: None):
     stream = text_chunk('hello '), text_chunk('world'), text_chunk('.', finish_reason='stop')
     mock_client = MockGroq.create_mock_stream(stream)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m)

     async with agent.run_stream('') as result:
@@ -399,7 +398,7 @@ async def test_stream_structured(allow_model_requests: None):
         chunk([]),
     )
     mock_client = MockGroq.create_mock_stream(stream)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m, result_type=MyTypedDict)

     async with agent.run_stream('') as result:
@@ -422,10 +421,10 @@
             parts=[
                 ToolCallPart(
                     tool_name='final_result',
-                    args=
+                    args='{"first": "One", "second": "Two"}',
                 )
             ],
-            model_name='llama-3.
+            model_name='llama-3.3-70b-versatile',
             timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
         ),
         ModelRequest(
@@ -450,7 +449,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
         struc_chunk(None, None, finish_reason='stop'),
     )
     mock_client = MockGroq.create_mock_stream(stream)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m, result_type=MyTypedDict)

     async with agent.run_stream('') as result:
@@ -470,7 +469,7 @@ async def test_no_content(allow_model_requests: None):
 async def test_no_content(allow_model_requests: None):
     stream = chunk([ChoiceDelta()]), chunk([ChoiceDelta()])
     mock_client = MockGroq.create_mock_stream(stream)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m, result_type=MyTypedDict)

     with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
@@ -481,7 +480,7 @@ async def test_no_content(allow_model_requests: None):
 async def test_no_delta(allow_model_requests: None):
     stream = chunk([]), text_chunk('hello '), text_chunk('world')
     mock_client = MockGroq.create_mock_stream(stream)
-    m = GroqModel('llama-3.
+    m = GroqModel('llama-3.3-70b-versatile', groq_client=mock_client)
     agent = Agent(m)

     async with agent.run_stream('') as result:
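The Groq tests now target llama-3.3-70b-versatile throughout. A minimal sketch mirroring test_init above; the API key is a placeholder:

```python
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel

# Model name and placeholder key taken from test_init in this diff.
m = GroqModel('llama-3.3-70b-versatile', api_key='foobar')
assert m.name() == 'groq:llama-3.3-70b-versatile'

agent = Agent(m)
```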