pydantic-ai 0.0.19__tar.gz → 0.0.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai has been flagged as potentially problematic.
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/PKG-INFO +5 -5
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/README.md +2 -2
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/pyproject.toml +11 -3
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/conftest.py +9 -2
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/graph/test_mermaid.py +32 -0
- pydantic_ai-0.0.21/tests/models/test_anthropic.py +451 -0
- pydantic_ai-0.0.21/tests/models/test_cohere.py +311 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_gemini.py +82 -36
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_groq.py +38 -22
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_mistral.py +65 -27
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_model.py +7 -1
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_model_function.py +71 -34
- pydantic_ai-0.0.21/tests/models/test_model_names.py +50 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_model_test.py +31 -10
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_openai.py +116 -27
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/test_vertexai.py +3 -3
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_agent.py +148 -96
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_deps.py +2 -2
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_examples.py +53 -60
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_live.py +12 -5
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_logfire.py +4 -14
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_parts_manager.py +26 -46
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_streaming.py +46 -25
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_tools.py +118 -25
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_usage_limits.py +7 -7
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_utils.py +12 -8
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/typed_graph.py +2 -2
- pydantic_ai-0.0.19/tests/models/test_anthropic.py +0 -243
- pydantic_ai-0.0.19/tests/models/test_ollama.py +0 -61
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/.gitignore +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/LICENSE +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/Makefile +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/__init__.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/example_modules/README.md +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/example_modules/bank_database.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/example_modules/fake_database.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/example_modules/weather_service.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/graph/__init__.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/graph/test_graph.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/graph/test_history.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/graph/test_state.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/import_examples.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/__init__.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/models/mock_async_stream.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/test_format_as_xml.py +0 -0
- {pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/typed_agent.py +0 -0
{pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai
-Version: 0.0.19
+Version: 0.0.21
 Summary: Agent Framework / shim to use Pydantic with LLMs
 Project-URL: Homepage, https://ai.pydantic.dev
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -32,9 +32,9 @@ Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: pydantic-ai-slim[anthropic,graph,groq,mistral,openai,vertexai]==0.0.19
+Requires-Dist: pydantic-ai-slim[anthropic,cohere,graph,groq,mistral,openai,vertexai]==0.0.21
 Provides-Extra: examples
-Requires-Dist: pydantic-ai-examples==0.0.19; extra == 'examples'
+Requires-Dist: pydantic-ai-examples==0.0.21; extra == 'examples'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Description-Content-Type: text/markdown
@@ -78,7 +78,7 @@ We built PydanticAI with one simple aim: to bring that FastAPI feeling to GenAI
 Built by the team behind [Pydantic](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more).

 * __Model-agnostic__
-Supports OpenAI, Anthropic, Gemini, Ollama, Groq, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).
+Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).

 * __Pydantic Logfire Integration__
 Seamlessly [integrates](https://ai.pydantic.dev/logfire/) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications.
@@ -116,7 +116,7 @@ from pydantic_ai import Agent

 # Define a very simple agent including the model to use, you can also set the model when running the agent.
 agent = Agent(
-    'gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash',
     # Register a static system prompt using a keyword argument to the agent.
     # For more complex dynamically-generated system prompts, see the example below.
     system_prompt='Be concise, reply with one sentence.',
{pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/README.md
@@ -37,7 +37,7 @@ We built PydanticAI with one simple aim: to bring that FastAPI feeling to GenAI
 Built by the team behind [Pydantic](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more).

 * __Model-agnostic__
-Supports OpenAI, Anthropic, Gemini, Ollama, Groq, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).
+Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/).

 * __Pydantic Logfire Integration__
 Seamlessly [integrates](https://ai.pydantic.dev/logfire/) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications.
@@ -75,7 +75,7 @@ from pydantic_ai import Agent

 # Define a very simple agent including the model to use, you can also set the model when running the agent.
 agent = Agent(
-    'gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash',
     # Register a static system prompt using a keyword argument to the agent.
     # For more complex dynamically-generated system prompts, see the example below.
     system_prompt='Be concise, reply with one sentence.',
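Both README hunks above move the hello-world snippet to the new provider-prefixed model string. For reference, a hedged, self-contained version of that snippet, assuming a `GEMINI_API_KEY` environment variable is configured; the prefix, model name, and system prompt come straight from the diff, while `run_sync` and the example question follow the README's surrounding example, which this diff doesn't show:

```python
from pydantic_ai import Agent

# 'google-gla:' is the new provider prefix for the Google Generative Language API;
# the bare 'gemini-1.5-flash' form is what this release replaces.
agent = Agent(
    'google-gla:gemini-1.5-flash',
    system_prompt='Be concise, reply with one sentence.',
)

result = agent.run_sync('Where does "hello world" come from?')
print(result.data)
```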
{pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "pydantic-ai"
-version = "0.0.19"
+version = "0.0.21"
 description = "Agent Framework / shim to use Pydantic with LLMs"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -37,7 +37,7 @@ classifiers = [
 ]
 requires-python = ">=3.9"

-dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral]==0.0.19"]
+dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral,cohere]==0.0.21"]

 [project.urls]
 Homepage = "https://ai.pydantic.dev"
@@ -46,7 +46,7 @@ Documentation = "https://ai.pydantic.dev"
 Changelog = "https://github.com/pydantic/pydantic-ai/releases"

 [project.optional-dependencies]
-examples = ["pydantic-ai-examples==0.0.19"]
+examples = ["pydantic-ai-examples==0.0.21"]
 logfire = ["logfire>=2.3"]

 [tool.uv.sources]
@@ -183,3 +183,11 @@ ignore_no_config = true
 [tool.inline-snapshot.shortcuts]
 snap-fix=["create", "fix"]
 snap=["create"]
+
+[tool.codespell]
+# Ref: https://github.com/codespell-project/codespell#using-a-config-file
+skip = '.git*,*.svg,*.lock,*.css'
+check-hidden = true
+# Ignore "formatting" like **L**anguage
+ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
+# ignore-words-list = ''
|
@@ -30,7 +30,14 @@ if TYPE_CHECKING:
|
|
|
30
30
|
def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
|
|
31
31
|
def IsFloat(*args: Any, **kwargs: Any) -> float: ...
|
|
32
32
|
else:
|
|
33
|
-
from dirty_equals import IsFloat, IsNow
|
|
33
|
+
from dirty_equals import IsFloat, IsNow as _IsNow
|
|
34
|
+
|
|
35
|
+
def IsNow(*args: Any, **kwargs: Any):
|
|
36
|
+
# Increase the default value of `delta` to 10 to reduce test flakiness on overburdened machines
|
|
37
|
+
if 'delta' not in kwargs:
|
|
38
|
+
kwargs['delta'] = 10
|
|
39
|
+
return _IsNow(*args, **kwargs)
|
|
40
|
+
|
|
34
41
|
|
|
35
42
|
try:
|
|
36
43
|
from logfire.testing import CaptureLogfire
|
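For context, a minimal sketch of what the widened default buys, assuming `dirty_equals` is installed; the wrapper above simply injects `delta=10` before delegating to `dirty_equals.IsNow`:

```python
from datetime import datetime, timedelta, timezone

from dirty_equals import IsNow

# dirty_equals defaults to a 2-second tolerance; the conftest wrapper bumps
# the default to 10 seconds so slow CI machines don't fail timestamp checks.
now = datetime.now(timezone.utc)
assert now == IsNow(tz=timezone.utc, delta=10)
assert now - timedelta(seconds=5) == IsNow(tz=timezone.utc, delta=10)  # within tolerance
assert now - timedelta(seconds=60) != IsNow(tz=timezone.utc, delta=10)  # outside tolerance
```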
|
@@ -166,7 +173,7 @@ def try_import() -> Iterator[Callable[[], bool]]:
|
|
|
166
173
|
import_success = True
|
|
167
174
|
|
|
168
175
|
|
|
169
|
-
@pytest.fixture
|
|
176
|
+
@pytest.fixture(autouse=True)
|
|
170
177
|
def set_event_loop() -> Iterator[None]:
|
|
171
178
|
new_loop = asyncio.new_event_loop()
|
|
172
179
|
asyncio.set_event_loop(new_loop)
|
|
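The `autouse=True` change means every test in scope now gets a fresh event loop without naming the fixture in its signature. A hedged sketch of the pattern (the fixture's teardown isn't visible in this hunk, so the `yield`/`close` lines below are an assumption about how such a fixture is typically finished):

```python
import asyncio
from typing import Iterator

import pytest


@pytest.fixture(autouse=True)
def set_event_loop() -> Iterator[None]:
    # autouse=True applies this fixture to every test automatically,
    # so each test runs against a brand-new event loop.
    new_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(new_loop)
    yield  # assumed teardown, not shown in the diff hunk
    new_loop.close()
```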
{pydantic_ai-0.0.19 → pydantic_ai-0.0.21}/tests/graph/test_mermaid.py
@@ -194,6 +194,38 @@ stateDiagram-v2
 """)


+def test_mermaid_code_all_nodes_no_direction():
+    assert graph3.mermaid_code() == snapshot("""\
+---
+title: graph3
+---
+stateDiagram-v2
+  AllNodes --> AllNodes
+  AllNodes --> Foo
+  AllNodes --> Bar
+  Foo --> Bar
+  Bar --> [*]\
+""")
+
+
+def test_mermaid_code_all_nodes_with_direction_lr():
+    assert graph3.mermaid_code(direction='LR') == snapshot("""\
+---
+title: graph3
+---
+stateDiagram-v2
+  direction LR
+  AllNodes --> AllNodes
+  AllNodes --> Foo
+  AllNodes --> Bar
+  Foo --> Bar
+  Bar --> [*]\
+""")
+
+
+# Tests for direction ends here
+
+
 def test_docstring_notes_classvar():
     assert Spam.docstring_notes is True
     assert repr(Spam()) == 'Spam()'
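The two new tests pin `mermaid_code()` output for a graph whose nodes fan out to every other node, and exercise the new `direction=` option. A rough sketch under assumptions: `graph3`, `AllNodes`, `Foo`, and `Bar` are defined elsewhere in the test module, so the snippet below builds a stand-in two-node graph with the `pydantic_graph` API used by this test suite; the names are illustrative, not the test module's actual definitions:

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class Foo(BaseNode):
    async def run(self, ctx: GraphRunContext) -> Bar:
        return Bar()


@dataclass
class Bar(BaseNode):
    async def run(self, ctx: GraphRunContext) -> End[None]:
        return End(None)


graph = Graph(nodes=(Foo, Bar))
# direction='LR' adds a `direction LR` line under `stateDiagram-v2`,
# which is exactly what test_mermaid_code_all_nodes_with_direction_lr asserts.
print(graph.mermaid_code(direction='LR'))
```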
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import timezone
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
from typing import Any, TypeVar, cast
|
|
8
|
+
|
|
9
|
+
import pytest
|
|
10
|
+
from inline_snapshot import snapshot
|
|
11
|
+
|
|
12
|
+
from pydantic_ai import Agent, ModelRetry
|
|
13
|
+
from pydantic_ai.messages import (
|
|
14
|
+
ModelRequest,
|
|
15
|
+
ModelResponse,
|
|
16
|
+
RetryPromptPart,
|
|
17
|
+
SystemPromptPart,
|
|
18
|
+
TextPart,
|
|
19
|
+
ToolCallPart,
|
|
20
|
+
ToolReturnPart,
|
|
21
|
+
UserPromptPart,
|
|
22
|
+
)
|
|
23
|
+
from pydantic_ai.result import Usage
|
|
24
|
+
from pydantic_ai.settings import ModelSettings
|
|
25
|
+
|
|
26
|
+
from ..conftest import IsNow, try_import
|
|
27
|
+
from .mock_async_stream import MockAsyncStream
|
|
28
|
+
|
|
29
|
+
with try_import() as imports_successful:
|
|
30
|
+
from anthropic import NOT_GIVEN, AsyncAnthropic
|
|
31
|
+
from anthropic.types import (
|
|
32
|
+
ContentBlock,
|
|
33
|
+
InputJSONDelta,
|
|
34
|
+
Message as AnthropicMessage,
|
|
35
|
+
MessageDeltaUsage,
|
|
36
|
+
RawContentBlockDeltaEvent,
|
|
37
|
+
RawContentBlockStartEvent,
|
|
38
|
+
RawContentBlockStopEvent,
|
|
39
|
+
RawMessageDeltaEvent,
|
|
40
|
+
RawMessageStartEvent,
|
|
41
|
+
RawMessageStopEvent,
|
|
42
|
+
RawMessageStreamEvent,
|
|
43
|
+
TextBlock,
|
|
44
|
+
ToolUseBlock,
|
|
45
|
+
Usage as AnthropicUsage,
|
|
46
|
+
)
|
|
47
|
+
from anthropic.types.raw_message_delta_event import Delta
|
|
48
|
+
|
|
49
|
+
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
|
|
50
|
+
|
|
51
|
+
pytestmark = [
|
|
52
|
+
pytest.mark.skipif(not imports_successful(), reason='anthropic not installed'),
|
|
53
|
+
pytest.mark.anyio,
|
|
54
|
+
]
|
|
55
|
+
|
|
56
|
+
# Type variable for generic AsyncStream
|
|
57
|
+
T = TypeVar('T')
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_init():
|
|
61
|
+
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
|
|
62
|
+
assert m.client.api_key == 'foobar'
|
|
63
|
+
assert m.name() == 'anthropic:claude-3-5-haiku-latest'
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
|
|
67
|
+
class MockAnthropic:
|
|
68
|
+
messages_: AnthropicMessage | list[AnthropicMessage] | None = None
|
|
69
|
+
stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]] | None = None
|
|
70
|
+
index = 0
|
|
71
|
+
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
|
|
72
|
+
|
|
73
|
+
@cached_property
|
|
74
|
+
def messages(self) -> Any:
|
|
75
|
+
return type('Messages', (), {'create': self.messages_create})
|
|
76
|
+
|
|
77
|
+
@classmethod
|
|
78
|
+
def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
|
|
79
|
+
return cast(AsyncAnthropic, cls(messages_=messages_))
|
|
80
|
+
|
|
81
|
+
@classmethod
|
|
82
|
+
def create_stream_mock(
|
|
83
|
+
cls, stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]]
|
|
84
|
+
) -> AsyncAnthropic:
|
|
85
|
+
return cast(AsyncAnthropic, cls(stream=stream))
|
|
86
|
+
|
|
87
|
+
async def messages_create(
|
|
88
|
+
self, *_args: Any, stream: bool = False, **kwargs: Any
|
|
89
|
+
) -> AnthropicMessage | MockAsyncStream[RawMessageStreamEvent]:
|
|
90
|
+
self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
|
|
91
|
+
|
|
92
|
+
if stream:
|
|
93
|
+
assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
|
|
94
|
+
# noinspection PyUnresolvedReferences
|
|
95
|
+
if isinstance(self.stream[0], list):
|
|
96
|
+
indexed_stream = cast(list[RawMessageStreamEvent], self.stream[self.index])
|
|
97
|
+
response = MockAsyncStream(iter(indexed_stream))
|
|
98
|
+
else:
|
|
99
|
+
response = MockAsyncStream(iter(cast(list[RawMessageStreamEvent], self.stream)))
|
|
100
|
+
else:
|
|
101
|
+
assert self.messages_ is not None, '`messages` must be provided'
|
|
102
|
+
if isinstance(self.messages_, list):
|
|
103
|
+
response = self.messages_[self.index]
|
|
104
|
+
else:
|
|
105
|
+
response = self.messages_
|
|
106
|
+
self.index += 1
|
|
107
|
+
return response
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> AnthropicMessage:
|
|
111
|
+
return AnthropicMessage(
|
|
112
|
+
id='123',
|
|
113
|
+
content=content,
|
|
114
|
+
model='claude-3-5-haiku-latest',
|
|
115
|
+
role='assistant',
|
|
116
|
+
stop_reason='end_turn',
|
|
117
|
+
type='message',
|
|
118
|
+
usage=usage,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
async def test_sync_request_text_response(allow_model_requests: None):
|
|
123
|
+
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
|
|
124
|
+
mock_client = MockAnthropic.create_mock(c)
|
|
125
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
126
|
+
agent = Agent(m)
|
|
127
|
+
|
|
128
|
+
result = await agent.run('hello')
|
|
129
|
+
assert result.data == 'world'
|
|
130
|
+
assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
|
|
131
|
+
|
|
132
|
+
# reset the index so we get the same response again
|
|
133
|
+
mock_client.index = 0 # type: ignore
|
|
134
|
+
|
|
135
|
+
result = await agent.run('hello', message_history=result.new_messages())
|
|
136
|
+
assert result.data == 'world'
|
|
137
|
+
assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
|
|
138
|
+
assert result.all_messages() == snapshot(
|
|
139
|
+
[
|
|
140
|
+
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
|
|
141
|
+
ModelResponse(
|
|
142
|
+
parts=[TextPart(content='world')],
|
|
143
|
+
model_name='claude-3-5-haiku-latest',
|
|
144
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
145
|
+
),
|
|
146
|
+
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
|
|
147
|
+
ModelResponse(
|
|
148
|
+
parts=[TextPart(content='world')],
|
|
149
|
+
model_name='claude-3-5-haiku-latest',
|
|
150
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
151
|
+
),
|
|
152
|
+
]
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
async def test_async_request_text_response(allow_model_requests: None):
|
|
157
|
+
c = completion_message(
|
|
158
|
+
[TextBlock(text='world', type='text')],
|
|
159
|
+
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
|
|
160
|
+
)
|
|
161
|
+
mock_client = MockAnthropic.create_mock(c)
|
|
162
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
163
|
+
agent = Agent(m)
|
|
164
|
+
|
|
165
|
+
result = await agent.run('hello')
|
|
166
|
+
assert result.data == 'world'
|
|
167
|
+
assert result.usage() == snapshot(Usage(requests=1, request_tokens=3, response_tokens=5, total_tokens=8))
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
async def test_request_structured_response(allow_model_requests: None):
|
|
171
|
+
c = completion_message(
|
|
172
|
+
[ToolUseBlock(id='123', input={'response': [1, 2, 3]}, name='final_result', type='tool_use')],
|
|
173
|
+
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
|
|
174
|
+
)
|
|
175
|
+
mock_client = MockAnthropic.create_mock(c)
|
|
176
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
177
|
+
agent = Agent(m, result_type=list[int])
|
|
178
|
+
|
|
179
|
+
result = await agent.run('hello')
|
|
180
|
+
assert result.data == [1, 2, 3]
|
|
181
|
+
assert result.all_messages() == snapshot(
|
|
182
|
+
[
|
|
183
|
+
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
|
|
184
|
+
ModelResponse(
|
|
185
|
+
parts=[
|
|
186
|
+
ToolCallPart(
|
|
187
|
+
tool_name='final_result',
|
|
188
|
+
args={'response': [1, 2, 3]},
|
|
189
|
+
tool_call_id='123',
|
|
190
|
+
)
|
|
191
|
+
],
|
|
192
|
+
model_name='claude-3-5-haiku-latest',
|
|
193
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
194
|
+
),
|
|
195
|
+
ModelRequest(
|
|
196
|
+
parts=[
|
|
197
|
+
ToolReturnPart(
|
|
198
|
+
tool_name='final_result',
|
|
199
|
+
content='Final result processed.',
|
|
200
|
+
tool_call_id='123',
|
|
201
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
202
|
+
)
|
|
203
|
+
]
|
|
204
|
+
),
|
|
205
|
+
]
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
async def test_request_tool_call(allow_model_requests: None):
|
|
210
|
+
responses = [
|
|
211
|
+
completion_message(
|
|
212
|
+
[ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
|
|
213
|
+
usage=AnthropicUsage(input_tokens=2, output_tokens=1),
|
|
214
|
+
),
|
|
215
|
+
completion_message(
|
|
216
|
+
[ToolUseBlock(id='2', input={'loc_name': 'London'}, name='get_location', type='tool_use')],
|
|
217
|
+
usage=AnthropicUsage(input_tokens=3, output_tokens=2),
|
|
218
|
+
),
|
|
219
|
+
completion_message(
|
|
220
|
+
[TextBlock(text='final response', type='text')],
|
|
221
|
+
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
|
|
222
|
+
),
|
|
223
|
+
]
|
|
224
|
+
|
|
225
|
+
mock_client = MockAnthropic.create_mock(responses)
|
|
226
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
227
|
+
agent = Agent(m, system_prompt='this is the system prompt')
|
|
228
|
+
|
|
229
|
+
@agent.tool_plain
|
|
230
|
+
async def get_location(loc_name: str) -> str:
|
|
231
|
+
if loc_name == 'London':
|
|
232
|
+
return json.dumps({'lat': 51, 'lng': 0})
|
|
233
|
+
else:
|
|
234
|
+
raise ModelRetry('Wrong location, please try again')
|
|
235
|
+
|
|
236
|
+
result = await agent.run('hello')
|
|
237
|
+
assert result.data == 'final response'
|
|
238
|
+
assert result.all_messages() == snapshot(
|
|
239
|
+
[
|
|
240
|
+
ModelRequest(
|
|
241
|
+
parts=[
|
|
242
|
+
SystemPromptPart(content='this is the system prompt'),
|
|
243
|
+
UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc)),
|
|
244
|
+
]
|
|
245
|
+
),
|
|
246
|
+
ModelResponse(
|
|
247
|
+
parts=[
|
|
248
|
+
ToolCallPart(
|
|
249
|
+
tool_name='get_location',
|
|
250
|
+
args={'loc_name': 'San Francisco'},
|
|
251
|
+
tool_call_id='1',
|
|
252
|
+
)
|
|
253
|
+
],
|
|
254
|
+
model_name='claude-3-5-haiku-latest',
|
|
255
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
256
|
+
),
|
|
257
|
+
ModelRequest(
|
|
258
|
+
parts=[
|
|
259
|
+
RetryPromptPart(
|
|
260
|
+
content='Wrong location, please try again',
|
|
261
|
+
tool_name='get_location',
|
|
262
|
+
tool_call_id='1',
|
|
263
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
264
|
+
)
|
|
265
|
+
]
|
|
266
|
+
),
|
|
267
|
+
ModelResponse(
|
|
268
|
+
parts=[
|
|
269
|
+
ToolCallPart(
|
|
270
|
+
tool_name='get_location',
|
|
271
|
+
args={'loc_name': 'London'},
|
|
272
|
+
tool_call_id='2',
|
|
273
|
+
)
|
|
274
|
+
],
|
|
275
|
+
model_name='claude-3-5-haiku-latest',
|
|
276
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
277
|
+
),
|
|
278
|
+
ModelRequest(
|
|
279
|
+
parts=[
|
|
280
|
+
ToolReturnPart(
|
|
281
|
+
tool_name='get_location',
|
|
282
|
+
content='{"lat": 51, "lng": 0}',
|
|
283
|
+
tool_call_id='2',
|
|
284
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
285
|
+
)
|
|
286
|
+
]
|
|
287
|
+
),
|
|
288
|
+
ModelResponse(
|
|
289
|
+
parts=[TextPart(content='final response')],
|
|
290
|
+
model_name='claude-3-5-haiku-latest',
|
|
291
|
+
timestamp=IsNow(tz=timezone.utc),
|
|
292
|
+
),
|
|
293
|
+
]
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def get_mock_chat_completion_kwargs(async_anthropic: AsyncAnthropic) -> list[dict[str, Any]]:
|
|
298
|
+
if isinstance(async_anthropic, MockAnthropic):
|
|
299
|
+
return async_anthropic.chat_completion_kwargs
|
|
300
|
+
else: # pragma: no cover
|
|
301
|
+
raise RuntimeError('Not a MockOpenAI instance')
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@pytest.mark.parametrize('parallel_tool_calls', [True, False])
|
|
305
|
+
async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
|
|
306
|
+
responses = [
|
|
307
|
+
completion_message(
|
|
308
|
+
[ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
|
|
309
|
+
usage=AnthropicUsage(input_tokens=2, output_tokens=1),
|
|
310
|
+
),
|
|
311
|
+
completion_message(
|
|
312
|
+
[TextBlock(text='final response', type='text')],
|
|
313
|
+
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
|
|
314
|
+
),
|
|
315
|
+
]
|
|
316
|
+
|
|
317
|
+
mock_client = MockAnthropic.create_mock(responses)
|
|
318
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
319
|
+
agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
|
|
320
|
+
|
|
321
|
+
@agent.tool_plain
|
|
322
|
+
async def get_location(loc_name: str) -> str:
|
|
323
|
+
if loc_name == 'London':
|
|
324
|
+
return json.dumps({'lat': 51, 'lng': 0})
|
|
325
|
+
else:
|
|
326
|
+
raise ModelRetry('Wrong location, please try again')
|
|
327
|
+
|
|
328
|
+
await agent.run('hello')
|
|
329
|
+
assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
|
|
330
|
+
not parallel_tool_calls
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
|
|
335
|
+
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
|
|
336
|
+
mock_client = MockAnthropic.create_mock(c)
|
|
337
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
338
|
+
agent = Agent(m)
|
|
339
|
+
|
|
340
|
+
result = await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}))
|
|
341
|
+
assert result.data == 'world'
|
|
342
|
+
assert get_mock_chat_completion_kwargs(mock_client)[0]['metadata']['user_id'] == '123'
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
async def test_stream_structured(allow_model_requests: None):
|
|
346
|
+
"""Test streaming structured responses with Anthropic's API.
|
|
347
|
+
|
|
348
|
+
This test simulates how Anthropic streams tool calls:
|
|
349
|
+
1. Message start
|
|
350
|
+
2. Tool block start with initial data
|
|
351
|
+
3. Tool block delta with additional data
|
|
352
|
+
4. Tool block stop
|
|
353
|
+
5. Update usage
|
|
354
|
+
6. Message stop
|
|
355
|
+
"""
|
|
356
|
+
stream: list[RawMessageStreamEvent] = [
|
|
357
|
+
RawMessageStartEvent(
|
|
358
|
+
type='message_start',
|
|
359
|
+
message=AnthropicMessage(
|
|
360
|
+
id='msg_123',
|
|
361
|
+
model='claude-3-5-haiku-latest',
|
|
362
|
+
role='assistant',
|
|
363
|
+
type='message',
|
|
364
|
+
content=[],
|
|
365
|
+
stop_reason=None,
|
|
366
|
+
usage=AnthropicUsage(input_tokens=20, output_tokens=0),
|
|
367
|
+
),
|
|
368
|
+
),
|
|
369
|
+
# Start tool block with initial data
|
|
370
|
+
RawContentBlockStartEvent(
|
|
371
|
+
type='content_block_start',
|
|
372
|
+
index=0,
|
|
373
|
+
content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
|
|
374
|
+
),
|
|
375
|
+
# Add more data through an incomplete JSON delta
|
|
376
|
+
RawContentBlockDeltaEvent(
|
|
377
|
+
type='content_block_delta',
|
|
378
|
+
index=0,
|
|
379
|
+
delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
|
|
380
|
+
),
|
|
381
|
+
RawContentBlockDeltaEvent(
|
|
382
|
+
type='content_block_delta',
|
|
383
|
+
index=0,
|
|
384
|
+
delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
|
|
385
|
+
),
|
|
386
|
+
# Mark tool block as complete
|
|
387
|
+
RawContentBlockStopEvent(type='content_block_stop', index=0),
|
|
388
|
+
# Update the top-level message with usage
|
|
389
|
+
RawMessageDeltaEvent(
|
|
390
|
+
type='message_delta',
|
|
391
|
+
delta=Delta(
|
|
392
|
+
stop_reason='end_turn',
|
|
393
|
+
),
|
|
394
|
+
usage=MessageDeltaUsage(
|
|
395
|
+
output_tokens=5,
|
|
396
|
+
),
|
|
397
|
+
),
|
|
398
|
+
# Mark message as complete
|
|
399
|
+
RawMessageStopEvent(type='message_stop'),
|
|
400
|
+
]
|
|
401
|
+
|
|
402
|
+
done_stream: list[RawMessageStreamEvent] = [
|
|
403
|
+
RawMessageStartEvent(
|
|
404
|
+
type='message_start',
|
|
405
|
+
message=AnthropicMessage(
|
|
406
|
+
id='msg_123',
|
|
407
|
+
model='claude-3-5-haiku-latest',
|
|
408
|
+
role='assistant',
|
|
409
|
+
type='message',
|
|
410
|
+
content=[],
|
|
411
|
+
stop_reason=None,
|
|
412
|
+
usage=AnthropicUsage(input_tokens=0, output_tokens=0),
|
|
413
|
+
),
|
|
414
|
+
),
|
|
415
|
+
# Text block with final data
|
|
416
|
+
RawContentBlockStartEvent(
|
|
417
|
+
type='content_block_start',
|
|
418
|
+
index=0,
|
|
419
|
+
content_block=TextBlock(type='text', text='FINAL_PAYLOAD'),
|
|
420
|
+
),
|
|
421
|
+
RawContentBlockStopEvent(type='content_block_stop', index=0),
|
|
422
|
+
RawMessageStopEvent(type='message_stop'),
|
|
423
|
+
]
|
|
424
|
+
|
|
425
|
+
mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
|
|
426
|
+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
|
|
427
|
+
agent = Agent(m)
|
|
428
|
+
|
|
429
|
+
tool_called = False
|
|
430
|
+
|
|
431
|
+
@agent.tool_plain
|
|
432
|
+
async def my_tool(first: str, second: str) -> int:
|
|
433
|
+
nonlocal tool_called
|
|
434
|
+
tool_called = True
|
|
435
|
+
return len(first) + len(second)
|
|
436
|
+
|
|
437
|
+
async with agent.run_stream('') as result:
|
|
438
|
+
assert not result.is_complete
|
|
439
|
+
chunks = [c async for c in result.stream(debounce_by=None)]
|
|
440
|
+
|
|
441
|
+
# The tool output doesn't echo any content to the stream, so we only get the final payload once when
|
|
442
|
+
# the block starts and once when it ends.
|
|
443
|
+
assert chunks == snapshot(
|
|
444
|
+
[
|
|
445
|
+
'FINAL_PAYLOAD',
|
|
446
|
+
'FINAL_PAYLOAD',
|
|
447
|
+
]
|
|
448
|
+
)
|
|
449
|
+
assert result.is_complete
|
|
450
|
+
assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
|
|
451
|
+
assert tool_called
|