planar-0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- planar/.__init__.py.un~ +0 -0
- planar/._version.py.un~ +0 -0
- planar/.app.py.un~ +0 -0
- planar/.cli.py.un~ +0 -0
- planar/.config.py.un~ +0 -0
- planar/.context.py.un~ +0 -0
- planar/.db.py.un~ +0 -0
- planar/.di.py.un~ +0 -0
- planar/.engine.py.un~ +0 -0
- planar/.files.py.un~ +0 -0
- planar/.log_context.py.un~ +0 -0
- planar/.log_metadata.py.un~ +0 -0
- planar/.logging.py.un~ +0 -0
- planar/.object_registry.py.un~ +0 -0
- planar/.otel.py.un~ +0 -0
- planar/.server.py.un~ +0 -0
- planar/.session.py.un~ +0 -0
- planar/.sqlalchemy.py.un~ +0 -0
- planar/.task_local.py.un~ +0 -0
- planar/.test_app.py.un~ +0 -0
- planar/.test_config.py.un~ +0 -0
- planar/.test_object_config.py.un~ +0 -0
- planar/.test_sqlalchemy.py.un~ +0 -0
- planar/.test_utils.py.un~ +0 -0
- planar/.util.py.un~ +0 -0
- planar/.utils.py.un~ +0 -0
- planar/__init__.py +26 -0
- planar/_version.py +1 -0
- planar/ai/.__init__.py.un~ +0 -0
- planar/ai/._models.py.un~ +0 -0
- planar/ai/.agent.py.un~ +0 -0
- planar/ai/.agent_utils.py.un~ +0 -0
- planar/ai/.events.py.un~ +0 -0
- planar/ai/.files.py.un~ +0 -0
- planar/ai/.models.py.un~ +0 -0
- planar/ai/.providers.py.un~ +0 -0
- planar/ai/.pydantic_ai.py.un~ +0 -0
- planar/ai/.pydantic_ai_agent.py.un~ +0 -0
- planar/ai/.pydantic_ai_provider.py.un~ +0 -0
- planar/ai/.step.py.un~ +0 -0
- planar/ai/.test_agent.py.un~ +0 -0
- planar/ai/.test_agent_serialization.py.un~ +0 -0
- planar/ai/.test_providers.py.un~ +0 -0
- planar/ai/.utils.py.un~ +0 -0
- planar/ai/__init__.py +15 -0
- planar/ai/agent.py +457 -0
- planar/ai/agent_utils.py +205 -0
- planar/ai/models.py +140 -0
- planar/ai/providers.py +1088 -0
- planar/ai/test_agent.py +1298 -0
- planar/ai/test_agent_serialization.py +229 -0
- planar/ai/test_providers.py +463 -0
- planar/ai/utils.py +102 -0
- planar/app.py +494 -0
- planar/cli.py +282 -0
- planar/config.py +544 -0
- planar/db/.db.py.un~ +0 -0
- planar/db/__init__.py +17 -0
- planar/db/alembic/env.py +136 -0
- planar/db/alembic/script.py.mako +28 -0
- planar/db/alembic/versions/3476068c153c_initial_system_tables_migration.py +339 -0
- planar/db/alembic.ini +128 -0
- planar/db/db.py +318 -0
- planar/files/.config.py.un~ +0 -0
- planar/files/.local.py.un~ +0 -0
- planar/files/.local_filesystem.py.un~ +0 -0
- planar/files/.model.py.un~ +0 -0
- planar/files/.models.py.un~ +0 -0
- planar/files/.s3.py.un~ +0 -0
- planar/files/.storage.py.un~ +0 -0
- planar/files/.test_files.py.un~ +0 -0
- planar/files/__init__.py +2 -0
- planar/files/models.py +162 -0
- planar/files/storage/.__init__.py.un~ +0 -0
- planar/files/storage/.base.py.un~ +0 -0
- planar/files/storage/.config.py.un~ +0 -0
- planar/files/storage/.context.py.un~ +0 -0
- planar/files/storage/.local_directory.py.un~ +0 -0
- planar/files/storage/.test_local_directory.py.un~ +0 -0
- planar/files/storage/.test_s3.py.un~ +0 -0
- planar/files/storage/base.py +61 -0
- planar/files/storage/config.py +44 -0
- planar/files/storage/context.py +15 -0
- planar/files/storage/local_directory.py +188 -0
- planar/files/storage/s3.py +220 -0
- planar/files/storage/test_local_directory.py +162 -0
- planar/files/storage/test_s3.py +299 -0
- planar/files/test_files.py +283 -0
- planar/human/.human.py.un~ +0 -0
- planar/human/.test_human.py.un~ +0 -0
- planar/human/__init__.py +2 -0
- planar/human/human.py +458 -0
- planar/human/models.py +80 -0
- planar/human/test_human.py +385 -0
- planar/logging/.__init__.py.un~ +0 -0
- planar/logging/.attributes.py.un~ +0 -0
- planar/logging/.formatter.py.un~ +0 -0
- planar/logging/.logger.py.un~ +0 -0
- planar/logging/.otel.py.un~ +0 -0
- planar/logging/.tracer.py.un~ +0 -0
- planar/logging/__init__.py +10 -0
- planar/logging/attributes.py +54 -0
- planar/logging/context.py +14 -0
- planar/logging/formatter.py +113 -0
- planar/logging/logger.py +114 -0
- planar/logging/otel.py +51 -0
- planar/modeling/.mixin.py.un~ +0 -0
- planar/modeling/.storage.py.un~ +0 -0
- planar/modeling/__init__.py +0 -0
- planar/modeling/field_helpers.py +59 -0
- planar/modeling/json_schema_generator.py +94 -0
- planar/modeling/mixins/__init__.py +10 -0
- planar/modeling/mixins/auditable.py +52 -0
- planar/modeling/mixins/test_auditable.py +97 -0
- planar/modeling/mixins/test_timestamp.py +134 -0
- planar/modeling/mixins/test_uuid_primary_key.py +52 -0
- planar/modeling/mixins/timestamp.py +53 -0
- planar/modeling/mixins/uuid_primary_key.py +19 -0
- planar/modeling/orm/.planar_base_model.py.un~ +0 -0
- planar/modeling/orm/__init__.py +18 -0
- planar/modeling/orm/planar_base_entity.py +29 -0
- planar/modeling/orm/query_filter_builder.py +122 -0
- planar/modeling/orm/reexports.py +15 -0
- planar/object_config/.object_config.py.un~ +0 -0
- planar/object_config/__init__.py +11 -0
- planar/object_config/models.py +114 -0
- planar/object_config/object_config.py +378 -0
- planar/object_registry.py +100 -0
- planar/registry_items.py +65 -0
- planar/routers/.__init__.py.un~ +0 -0
- planar/routers/.agents_router.py.un~ +0 -0
- planar/routers/.crud.py.un~ +0 -0
- planar/routers/.decision.py.un~ +0 -0
- planar/routers/.event.py.un~ +0 -0
- planar/routers/.file_attachment.py.un~ +0 -0
- planar/routers/.files.py.un~ +0 -0
- planar/routers/.files_router.py.un~ +0 -0
- planar/routers/.human.py.un~ +0 -0
- planar/routers/.info.py.un~ +0 -0
- planar/routers/.models.py.un~ +0 -0
- planar/routers/.object_config_router.py.un~ +0 -0
- planar/routers/.rule.py.un~ +0 -0
- planar/routers/.test_object_config_router.py.un~ +0 -0
- planar/routers/.test_workflow_router.py.un~ +0 -0
- planar/routers/.workflow.py.un~ +0 -0
- planar/routers/__init__.py +13 -0
- planar/routers/agents_router.py +197 -0
- planar/routers/entity_router.py +143 -0
- planar/routers/event.py +91 -0
- planar/routers/files.py +142 -0
- planar/routers/human.py +151 -0
- planar/routers/info.py +131 -0
- planar/routers/models.py +170 -0
- planar/routers/object_config_router.py +133 -0
- planar/routers/rule.py +108 -0
- planar/routers/test_agents_router.py +174 -0
- planar/routers/test_object_config_router.py +367 -0
- planar/routers/test_routes_security.py +169 -0
- planar/routers/test_rule_router.py +470 -0
- planar/routers/test_workflow_router.py +274 -0
- planar/routers/workflow.py +468 -0
- planar/rules/.decorator.py.un~ +0 -0
- planar/rules/.runner.py.un~ +0 -0
- planar/rules/.test_rules.py.un~ +0 -0
- planar/rules/__init__.py +23 -0
- planar/rules/decorator.py +184 -0
- planar/rules/models.py +355 -0
- planar/rules/rule_configuration.py +191 -0
- planar/rules/runner.py +64 -0
- planar/rules/test_rules.py +750 -0
- planar/scaffold_templates/app/__init__.py.j2 +0 -0
- planar/scaffold_templates/app/db/entities.py.j2 +11 -0
- planar/scaffold_templates/app/flows/process_invoice.py.j2 +67 -0
- planar/scaffold_templates/main.py.j2 +13 -0
- planar/scaffold_templates/planar.dev.yaml.j2 +34 -0
- planar/scaffold_templates/planar.prod.yaml.j2 +28 -0
- planar/scaffold_templates/pyproject.toml.j2 +10 -0
- planar/security/.jwt_middleware.py.un~ +0 -0
- planar/security/auth_context.py +148 -0
- planar/security/authorization.py +388 -0
- planar/security/default_policies.cedar +77 -0
- planar/security/jwt_middleware.py +116 -0
- planar/security/security_context.py +18 -0
- planar/security/tests/test_authorization_context.py +78 -0
- planar/security/tests/test_cedar_basics.py +41 -0
- planar/security/tests/test_cedar_policies.py +158 -0
- planar/security/tests/test_jwt_principal_context.py +179 -0
- planar/session.py +40 -0
- planar/sse/.constants.py.un~ +0 -0
- planar/sse/.example.html.un~ +0 -0
- planar/sse/.hub.py.un~ +0 -0
- planar/sse/.model.py.un~ +0 -0
- planar/sse/.proxy.py.un~ +0 -0
- planar/sse/constants.py +1 -0
- planar/sse/example.html +126 -0
- planar/sse/hub.py +216 -0
- planar/sse/model.py +8 -0
- planar/sse/proxy.py +257 -0
- planar/task_local.py +37 -0
- planar/test_app.py +51 -0
- planar/test_cli.py +372 -0
- planar/test_config.py +512 -0
- planar/test_object_config.py +527 -0
- planar/test_object_registry.py +14 -0
- planar/test_sqlalchemy.py +158 -0
- planar/test_utils.py +105 -0
- planar/testing/.client.py.un~ +0 -0
- planar/testing/.memory_storage.py.un~ +0 -0
- planar/testing/.planar_test_client.py.un~ +0 -0
- planar/testing/.predictable_tracer.py.un~ +0 -0
- planar/testing/.synchronizable_tracer.py.un~ +0 -0
- planar/testing/.test_memory_storage.py.un~ +0 -0
- planar/testing/.workflow_observer.py.un~ +0 -0
- planar/testing/__init__.py +0 -0
- planar/testing/memory_storage.py +78 -0
- planar/testing/planar_test_client.py +54 -0
- planar/testing/synchronizable_tracer.py +153 -0
- planar/testing/test_memory_storage.py +143 -0
- planar/testing/workflow_observer.py +73 -0
- planar/utils.py +70 -0
- planar/workflows/.__init__.py.un~ +0 -0
- planar/workflows/.builtin_steps.py.un~ +0 -0
- planar/workflows/.concurrency_tracing.py.un~ +0 -0
- planar/workflows/.context.py.un~ +0 -0
- planar/workflows/.contrib.py.un~ +0 -0
- planar/workflows/.decorators.py.un~ +0 -0
- planar/workflows/.durable_test.py.un~ +0 -0
- planar/workflows/.errors.py.un~ +0 -0
- planar/workflows/.events.py.un~ +0 -0
- planar/workflows/.exceptions.py.un~ +0 -0
- planar/workflows/.execution.py.un~ +0 -0
- planar/workflows/.human.py.un~ +0 -0
- planar/workflows/.lock.py.un~ +0 -0
- planar/workflows/.misc.py.un~ +0 -0
- planar/workflows/.model.py.un~ +0 -0
- planar/workflows/.models.py.un~ +0 -0
- planar/workflows/.notifications.py.un~ +0 -0
- planar/workflows/.orchestrator.py.un~ +0 -0
- planar/workflows/.runtime.py.un~ +0 -0
- planar/workflows/.serialization.py.un~ +0 -0
- planar/workflows/.step.py.un~ +0 -0
- planar/workflows/.step_core.py.un~ +0 -0
- planar/workflows/.sub_workflow_runner.py.un~ +0 -0
- planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
- planar/workflows/.test_concurrency.py.un~ +0 -0
- planar/workflows/.test_concurrency_detection.py.un~ +0 -0
- planar/workflows/.test_human.py.un~ +0 -0
- planar/workflows/.test_lock_timeout.py.un~ +0 -0
- planar/workflows/.test_orchestrator.py.un~ +0 -0
- planar/workflows/.test_race_conditions.py.un~ +0 -0
- planar/workflows/.test_serialization.py.un~ +0 -0
- planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
- planar/workflows/.test_workflow.py.un~ +0 -0
- planar/workflows/.tracing.py.un~ +0 -0
- planar/workflows/.types.py.un~ +0 -0
- planar/workflows/.util.py.un~ +0 -0
- planar/workflows/.utils.py.un~ +0 -0
- planar/workflows/.workflow.py.un~ +0 -0
- planar/workflows/.workflow_wrapper.py.un~ +0 -0
- planar/workflows/.wrappers.py.un~ +0 -0
- planar/workflows/__init__.py +42 -0
- planar/workflows/context.py +44 -0
- planar/workflows/contrib.py +190 -0
- planar/workflows/decorators.py +217 -0
- planar/workflows/events.py +185 -0
- planar/workflows/exceptions.py +34 -0
- planar/workflows/execution.py +198 -0
- planar/workflows/lock.py +229 -0
- planar/workflows/misc.py +5 -0
- planar/workflows/models.py +154 -0
- planar/workflows/notifications.py +96 -0
- planar/workflows/orchestrator.py +383 -0
- planar/workflows/query.py +256 -0
- planar/workflows/serialization.py +409 -0
- planar/workflows/step_core.py +373 -0
- planar/workflows/step_metadata.py +357 -0
- planar/workflows/step_testing_utils.py +86 -0
- planar/workflows/sub_workflow_runner.py +191 -0
- planar/workflows/test_concurrency_detection.py +120 -0
- planar/workflows/test_lock_timeout.py +140 -0
- planar/workflows/test_serialization.py +1195 -0
- planar/workflows/test_suspend_deserialization.py +231 -0
- planar/workflows/test_workflow.py +1967 -0
- planar/workflows/tracing.py +106 -0
- planar/workflows/wrappers.py +41 -0
- planar-0.5.0.dist-info/METADATA +285 -0
- planar-0.5.0.dist-info/RECORD +289 -0
- planar-0.5.0.dist-info/WHEEL +4 -0
- planar-0.5.0.dist-info/entry_points.txt +3 -0
planar/ai/test_agent.py
ADDED
@@ -0,0 +1,1298 @@
from datetime import timedelta
from typing import Any, cast
from unittest.mock import AsyncMock, Mock, patch
from uuid import uuid4

import pytest
from pydantic import BaseModel, Field
from sqlmodel.ext.asyncio.session import AsyncSession

from planar.ai import Agent
from planar.ai.agent import (
    AgentRunResult,
)
from planar.ai.agent_utils import create_tool_definition, extract_files_from_model
from planar.ai.models import (
    AgentConfig,
    AssistantMessage,
    CompletionResponse,
    SystemMessage,
    ToolCall,
    ToolMessage,
    UserMessage,
)
from planar.ai.providers import OpenAI
from planar.app import PlanarApp
from planar.config import sqlite_config
from planar.files.models import PlanarFile
from planar.testing.planar_test_client import PlanarTestClient
from planar.workflows.decorators import workflow
from planar.workflows.execution import execute
from planar.workflows.models import Workflow
from planar.workflows.step_core import Suspend, suspend

app = PlanarApp(
    config=sqlite_config(":memory:"),
    title="Planar app for testing agents",
    description="Testing",
)


@pytest.fixture(name="app")
def app_fixture():
    yield app


# Test data and models (not test classes themselves)
# Using different names to avoid pytest collection warnings
class InputModel(BaseModel):
    text: str
    value: int


class OutputModel(BaseModel):
    message: str
    score: int


# Mock data for receipt analysis tests
MOCK_RECEIPT_DATA = {
    "merchant_name": "Coffee Shop",
    "date": "2025-03-11",
    "total_amount": 42.99,
    "items": [
        {"name": "Coffee", "price": 4.99, "quantity": 2},
        {"name": "Pastry", "price": 3.99, "quantity": 1},
    ],
    "payment_method": "Credit Card",
    "receipt_number": "R-123456",
}


@pytest.fixture
def mock_providers():
    """Mock both OpenAI and Anthropic providers to return test responses."""

    # Create a factory to produce provider mocks with consistent tracking
    def create_provider_mock():
        mock = Mock()
        mock.call_count = 0
        return mock

    # Create mocks for each provider
    provider_mocks = {
        "openai": create_provider_mock(),
        "anthropic": create_provider_mock(),
    }

    # Shared mock response generator
    async def generate_response(
        output_type=None, tools=None, planar_files=None, is_first_call=True
    ):
        """Generate appropriate mock responses based on request parameters"""
        # Tool-based multi-turn conversation
        if tools:
            if is_first_call:
                return CompletionResponse(
                    content=None,
                    tool_calls=[
                        cast(
                            ToolCall,
                            {
                                "id": "call_1",
                                "name": "tool1",
                                "arguments": {"param": "test_param"},
                            },
                        )
                    ],
                )
            elif output_type == OutputModel:
                return CompletionResponse(
                    content=OutputModel(message="Multi-turn response", score=90),
                    tool_calls=None,
                )
            else:
                return CompletionResponse(
                    content="Final tool response",
                    tool_calls=None,
                )

        # Planar file processing
        elif planar_files:
            if output_type and issubclass(output_type, BaseModel):
                # If a specific output model is requested, return a predetermined mock instance
                if output_type == OutputModel:
                    return CompletionResponse(
                        content=OutputModel(message="Analyzed file content", score=98),
                        tool_calls=None,
                    )
                else:
                    # Check file content type for different response types
                    file_type = None
                    if len(planar_files) > 0:
                        file_type = planar_files[0].content_type

                    # Generate mock response based on file type
                    if file_type == "application/pdf":
                        mock_data = {**MOCK_RECEIPT_DATA, "document_type": "pdf"}
                    else:  # Image types
                        mock_data = {**MOCK_RECEIPT_DATA, "document_type": "image"}

                    # Only include fields that exist in the model
                    filtered_data = {
                        k: v
                        for k, v in mock_data.items()
                        if k in output_type.model_fields
                    }

                    return CompletionResponse(
                        content=output_type.model_validate(filtered_data),
                        tool_calls=None,
                    )
            else:
                file_type = planar_files[0].content_type if planar_files else None
                if file_type == "application/pdf":
                    return CompletionResponse(
                        content="Description of the PDF document",
                        tool_calls=None,
                    )
                else:
                    return CompletionResponse(
                        content="Description of the image content",
                        tool_calls=None,
                    )

        # Structured output (single turn)
        elif output_type == OutputModel:
            return CompletionResponse(
                content=OutputModel(message="Test", score=95),
                tool_calls=None,
            )

        # Default simple response
        else:
            return CompletionResponse(
                content="Mock LLM response",
                tool_calls=None,
            )

    # Create a factory function for patched provider methods
    def create_provider_patch(provider_key):
        """Create patched complete method for the specified provider"""

        async def patched_complete(*args, **kwargs):
            # Get the provider's mock
            mock = provider_mocks[provider_key]

            # Update call tracking
            mock.call_count += 1
            mock.call_args = (args, kwargs)
            mock.call_args_list.append(cast(Any, (args, kwargs)))

            messages = kwargs.get("messages", [])
            planar_files = None
            for msg in messages:
                if isinstance(msg, UserMessage) and msg.files:
                    planar_files = msg.files
                    break

            # Generate appropriate response
            return await generate_response(
                output_type=kwargs.get("output_type"),
                tools=kwargs.get("tools"),
                planar_files=planar_files,
                is_first_call=(mock.call_count == 1),
            )

        return patched_complete

    # Apply patches
    with (
        patch(
            "planar.ai.providers.OpenAIProvider.complete",
            create_provider_patch("openai"),
        ),
        patch(
            "planar.ai.providers.AnthropicProvider.complete",
            create_provider_patch("anthropic"),
        ),
    ):
        yield (provider_mocks["openai"], provider_mocks["anthropic"])


DEFAULT_CONFIG = AgentConfig(
    system_prompt="Default system prompt",
    user_prompt="Default user prompt: {{ input }}",
    model="openai:gpt-4.1",
    max_turns=3,
)


@pytest.fixture
def mock_get_agent_config():
    """Mock the get_agent_config function to return empty config by default."""
    mock = AsyncMock(return_value=DEFAULT_CONFIG)
    with patch("planar.ai.agent.get_agent_config", mock):
        yield mock


def test_agent_initialization():
    """Test that the Agent class initializes with correct parameters."""
    agent = Agent(
        name="test_agent",
        system_prompt="Test system prompt: {{ param1 }}",
        user_prompt="Test user prompt: {{ param2 }}",
        model="test:model",
        max_turns=3,
    )

    # Verify initialization
    assert agent.name == "test_agent"
    assert agent.system_prompt == "Test system prompt: {{ param1 }}"
    assert agent.user_prompt == "Test user prompt: {{ param2 }}"
    assert agent.model == "test:model"
    assert agent.max_turns == 3
    assert agent.tools == []
    assert agent.input_type is None
    assert agent.output_type is None
    assert agent.model_parameters == {}


async def test_agent_call_simple(session: AsyncSession, mock_providers):
    """Test that an agent can be called in a workflow for a simple string response."""
    openai_mock, anthropic_mock = mock_providers

    # Create an agent
    test_agent = Agent(
        name="test_agent",
        system_prompt="Process this request",
        user_prompt="Input: {{ input }}",
        model="openai:gpt-4.1",  # Using a real provider name
    )

    # Define a workflow that uses the agent
    @workflow()
    async def test_workflow(input_text: str):
        result = await test_agent(input_value=input_text)
        assert isinstance(result, AgentRunResult)
        return result.output

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=test_agent.to_config()),
    ) as mock_config:
        # Start and execute the workflow
        wf = await test_workflow.start("test input")
        result = await execute(wf)

        # Verify the result
        assert result == "Mock LLM response"

        # Verify the workflow completed successfully
        updated_wf = await session.get(Workflow, wf.id)
        assert updated_wf is not None
        assert updated_wf.result == "Mock LLM response"

        # Verify get_agent_config was called with the agent name
        assert mock_config.called

        # Verify complete was called with the formatted messages
        assert openai_mock.call_count == 1  # called once
        args, kwargs = openai_mock.call_args
        messages = kwargs.get("messages")
        assert any(
            isinstance(m, SystemMessage) and m.content == "Process this request"
            for m in messages
        )
        assert any(
            isinstance(m, UserMessage) and m.content == "Input: test input"
            for m in messages
        )


async def test_prompt_injection_protection(session: AsyncSession, mock_providers):
    """Ensure unsafe template expressions raise an error before model call."""
    openai_mock, _ = mock_providers

    inj_agent = Agent(
        name="inj_agent",
        system_prompt="Hi",
        user_prompt="{{ input.__class__.__mro__[1] }}",
    )

    @workflow()
    async def inj_workflow(text: str):
        return await inj_agent(text)

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=inj_agent.to_config()),
    ):
        wf = await inj_workflow.start("test")
        with pytest.raises(ValueError):
            await execute(wf)

    assert openai_mock.call_count == 0


async def test_agent_with_structured_output(session: AsyncSession, mock_providers):
    """Test agent with structured output using a Pydantic model."""
    openai_mock, anthropic_mock = mock_providers

    # Create an agent with structured output
    test_agent = Agent(
        name="structured_agent",
        system_prompt="Provide structured analysis",
        user_prompt="Analyze: {{ input }}",
        output_type=OutputModel,
        model="openai:gpt-4.1",
    )

    @workflow()
    async def structured_workflow(input_text: str):
        result = await test_agent(input_value=input_text)
        await suspend(interval=timedelta(seconds=0.1))
        return {"message": result.output.message, "score": result.output.score}

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=test_agent.to_config()),
    ):
        wf = await structured_workflow.start("test structured input")
        result = await execute(wf)
        assert isinstance(result, Suspend)
        result = await execute(wf)

        assert isinstance(result, dict)
        assert result["message"] == "Test"
        assert result["score"] == 95

        updated_wf = await session.get(Workflow, wf.id)
        assert updated_wf is not None
        assert updated_wf.result == {"message": "Test", "score": 95}

        # Verify the correct provider method was called with right params
        assert openai_mock.call_count == 1  # called once
        args, kwargs = openai_mock.call_args
        assert kwargs["output_type"] == OutputModel
        messages = kwargs["messages"]
        assert any(
            isinstance(m, SystemMessage) and m.content == "Provide structured analysis"
            for m in messages
        )
        assert any(
            isinstance(m, UserMessage) and m.content == "Analyze: test structured input"
            for m in messages
        )


async def test_agent_with_input_validation(
    session: AsyncSession, mock_get_agent_config, mock_providers
):
    """Test agent with input validation using a Pydantic model."""
    openai_mock, anthropic_mock = mock_providers

    # Create an agent with input validation
    test_agent = Agent(
        name="validated_input_agent",
        system_prompt="Process validated input",
        user_prompt="Text: {{ input.text }}, Value: {{ input.value }}",
        input_type=InputModel,
        model="openai:gpt-4.1",
    )

    # Define a workflow that uses the agent
    @workflow()
    async def validation_workflow(input_text: str, input_value: int):
        result = await test_agent(
            input_value=InputModel(text=input_text, value=input_value)
        )
        return result.output

    # Start and execute the workflow
    wf = await validation_workflow.start("test input", 42)
    result = await execute(wf)

    # Verify the result
    assert result == "Mock LLM response"

    # Verify the agent validates input
    # Define a workflow missing the required 'value' parameter
    @workflow()
    async def invalid_workflow(input_text: str):
        # This call should raise a validation error at runtime
        # Ignore the type error to test validation
        return await test_agent(input_value=input_text)  # type: ignore

    # Start the workflow - this doesn't execute the agent validation yet
    invalid_wf = await invalid_workflow.start("missing value")

    # Now actually execute the workflow, which should raise ValueError
    with pytest.raises(ValueError):
        await execute(invalid_wf)


async def test_agent_with_tools(
    mock_providers,
    client: PlanarTestClient,
    app,
):
    """Test agent with tools for multi-turn conversations."""
    openai_mock, anthropic_mock = mock_providers

    # Define some tools
    async def tool1(param: str) -> str:
        """Test tool 1"""
        return f"Tool 1 result: {param}"

    async def tool2(num: int) -> int:
        """Test tool 2"""
        return num * 2

    # Create an agent with tools
    test_agent = Agent(
        name="tools_agent",
        system_prompt="Use tools to solve the problem",
        user_prompt="Problem: {{ input }}",
        tools=[tool1, tool2],
        output_type=OutputModel,
        max_turns=3,
        model="anthropic:claude-3-sonnet",  # Test the Anthropic provider this time
    )

    # then register it with app
    app.register_agent(test_agent)

    # Define a workflow that uses the agent
    @workflow()
    async def tools_workflow(problem: str):
        result = await test_agent(input_value=problem)
        return {"message": result.output.message, "score": result.output.score}

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=test_agent.to_config()),
    ):
        # Start and execute the workflow
        wf = await tools_workflow.start("complex problem")
        result = await execute(wf)

        # Verify the result
        assert isinstance(result, dict)
        assert result["message"] == "Multi-turn response"
        assert result["score"] == 90

        # Verify complete was called twice (once for tool call, once for final response)
        assert anthropic_mock.call_count == 2

        # First call should include tools
        args, first_call_kwargs = anthropic_mock.call_args_list[0]
        assert len(first_call_kwargs["tools"]) == 2
        assert first_call_kwargs["output_type"] == OutputModel

    response = await client.get(
        f"/planar/v1/workflows/{wf.function_name}/runs/{wf.id}/steps"
    )
    data = response.json()

    step = data["items"][0]
    assert step["step_id"] == 1
    assert step["function_name"] == "planar.ai.agent.Agent.run_step"
    assert step["display_name"] == test_agent.name


async def test_config_override(session: AsyncSession, mock_providers):
    """Test that agent correctly applies configuration overrides."""
    openai_mock, anthropic_mock = mock_providers

    # Create a custom mock for agent_config with overrides
    override_config = AgentConfig(
        system_prompt="Overridden system prompt",
        user_prompt="Overridden user prompt: {{ input }}",
        model="anthropic:claude-3-opus",  # Change from OpenAI to Anthropic
        max_turns=5,
    )

    # Create an agent with defaults that will be overridden
    test_agent = Agent(
        name="override_agent",
        system_prompt="Original system prompt",
        user_prompt="Original user prompt: {{ input }}",
        model="openai:gpt-4.1",  # Start with OpenAI
        max_turns=1,
    )

    @workflow()
    async def override_workflow(input_text: str):
        result = await test_agent(input_text)
        return result.output

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=override_config),
    ) as mock_config:
        wf = await override_workflow.start("override test")
        result = await execute(wf)

        # Verify the result
        assert result == "Mock LLM response"

        # Verify get_agent_config was called
        assert mock_config.called

        # Since we overrode to anthropic, that provider should be used
        assert anthropic_mock.call_count == 1  # called once
        assert openai_mock.call_count == 0  # not called

        # Verify the messages include the overridden prompts
        args, kwargs = anthropic_mock.call_args
        messages = kwargs["messages"]
        assert any(
            isinstance(m, SystemMessage) and m.content == "Overridden system prompt"
            for m in messages
        )
        assert any(
            isinstance(m, UserMessage)
            and m.content == "Overridden user prompt: override test"
            for m in messages
        )


async def test_agent_with_model_parameters(session: AsyncSession, mock_providers):
    """Test that an agent can be configured with model parameters."""
    openai_mock, anthropic_mock = mock_providers

    # Create an agent with model parameters
    test_agent = Agent(
        name="params_agent",
        system_prompt="Test with parameters",
        user_prompt="Input: {{ input }}",
        model=OpenAI.gpt_4_1,
        model_parameters={"temperature": 0.2, "top_p": 0.95},
    )

    # Define a workflow that uses the agent
    @workflow()
    async def params_workflow(input_text: str):
        result = await test_agent(input_value=input_text)
        return result.output

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=test_agent.to_config()),
    ):
        # Start and execute the workflow
        wf = await params_workflow.start("test input")
        result = await execute(wf)

        # Verify the result
        assert result == "Mock LLM response"

        # Check that model parameters were handled correctly
        # (in a real implementation, this would affect the call to the LLM provider)
        assert test_agent.model_parameters == {"temperature": 0.2, "top_p": 0.95}

        # Verify the model parameters are passed to the provider
        args, kwargs = openai_mock.call_args
        assert "temperature" in kwargs.get("model_spec").parameters
        assert kwargs.get("model_spec").parameters["temperature"] == 0.2
        assert kwargs.get("model_spec").parameters["top_p"] == 0.95


async def test_tool_response_formatting(
    session: AsyncSession, mock_get_agent_config, mock_providers
):
    """Test that tool responses are correctly formatted in multi-turn conversations."""
    openai_mock, _ = mock_providers

    # Define a tool that returns a specific response - must match name in mock
    async def tool1(param: str) -> str:
        """Test tool with simple string return"""
        return f"Tool result for: {param}"

    # Create an agent with the tool
    test_agent = Agent(
        name="tool_response_agent",
        system_prompt="Use tools to process the query",
        user_prompt="Query: {{ input }}",
        tools=[tool1],  # Name matches what the mock will call
        model="openai:gpt-4.1",
        max_turns=3,
    )

    # Define a workflow using the agent
    @workflow()
    async def tool_workflow(query: str):
        result = await test_agent(input_value=query)
        return result.output

    # Start and execute the workflow
    wf = await tool_workflow.start("test query")
    result = await execute(wf)

    # Verify result
    assert result == "Final tool response"

    # Verify complete was called twice
    assert openai_mock.call_count == 2

    # Extract the messages from the second call to check for proper tool response formatting
    args, second_call_kwargs = openai_mock.call_args_list[1]
    messages = second_call_kwargs.get("messages")

    # Check that there's a ToolMessage in the conversation
    tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
    assert len(tool_messages) == 1

    # Verify the content of the tool message matches our tool's output
    assert (
        tool_messages[0].content is not None
        and "Tool result for: test_param" in tool_messages[0].content
    )

    # Verify that the message was formatted using the format_tool_response method
    assert tool_messages[0].tool_call_id is not None


async def test_structured_output_with_tools(
    session: AsyncSession, mock_get_agent_config, mock_providers
):
    """Test that structured output works correctly with tool calling."""
    openai_mock, anthropic_mock = mock_providers

    # Define a tool function - must match name in mock
    async def tool1(param: str) -> dict:
        """Fetch data for the given ID"""
        return {"id": param, "value": f"data-{param}"}

    # Create a test agent with structured output and tools
    test_agent = Agent(
        name="structured_tool_agent",
        system_prompt="Process the input and return structured data",
        user_prompt="Process: {{ input }}",
        tools=[tool1],
        output_type=OutputModel,
        model="openai:gpt-4.1",
        max_turns=3,
    )

    # Define workflow
    @workflow()
    async def structured_tool_workflow(data: str):
        result = await test_agent(input_value=data)
        return {"message": result.output.message, "score": result.output.score}

    # Start and execute the workflow
    wf = await structured_tool_workflow.start("test-data")
    result = await execute(wf)

    # Verify result structure
    assert isinstance(result, dict)
    assert result["message"] == "Multi-turn response"
    assert result["score"] == 90

    # Verify calls to complete
    assert openai_mock.call_count == 2

    # Check first call (should include tools and output_type)
    args, first_call_kwargs = openai_mock.call_args_list[0]
    assert first_call_kwargs["output_type"] == OutputModel
    assert len(first_call_kwargs["tools"]) == 1

    # Check second call after tool response
    args, second_call_kwargs = openai_mock.call_args_list[1]
    assert (
        second_call_kwargs["output_type"] == OutputModel
    )  # Should still request structured output

    # Verify messages in second call include the tool response
    messages = second_call_kwargs["messages"]
    assert any(isinstance(m, ToolMessage) for m in messages)

    # Verify assistant message with tool calls is included
    assistant_messages = [
        m for m in messages if isinstance(m, AssistantMessage) and m.tool_calls
    ]
    assert len(assistant_messages) == 1


async def test_tool_error_catching(
    session: AsyncSession, mock_get_agent_config, mock_providers
):
    """Test that workflow can catch and handle errors from tool execution."""
    openai_mock, anthropic_mock = mock_providers

    # Define a tool that raises an exception - must match name in mock
    async def tool1(param: str) -> str:
        """This tool always fails"""
        raise ValueError(f"Tool error for: {param}")

    # Create an agent with the failing tool
    test_agent = Agent(
        name="error_handling_agent",
        system_prompt="Use tools to process this",
        user_prompt="Process: {{ input }}",
        tools=[tool1],
        model="openai:gpt-4.1",
        max_turns=3,
    )

    # Define a workflow that catches the error
    @workflow()
    async def error_handling_workflow(value: str):
        try:
            result = await test_agent(input_value=value)
            return {"status": "success", "output": result.output}
        except ValueError as e:
            # Workflow catches the error and returns a graceful response
            return {"status": "error", "message": str(e)}

    # Start and execute the workflow
    wf = await error_handling_workflow.start("test value")
    result = await execute(wf)

    # Verify the workflow caught the error
    assert isinstance(result, dict)  # Make sure result is a dictionary before indexing
    assert result.get("status") == "error"
    assert "Tool error for:" in result.get("message", "")

    # Verify the API was called once to get the tool call
    assert openai_mock.call_count == 1


def test_tool_validation():
    """Test that different types of functions are supported as tools."""

    # Create some simple Pydantic models for reference
    class ValidToolParams(BaseModel):
        param: str

    class UntypedToolParams(BaseModel):
        param: Any

    # Define a regular function - should work
    async def valid_tool(param: str) -> str:
        """A valid tool function"""
        return f"Result for {param}"

    # This should succeed (not a bound method)
    tool_def = create_tool_definition(valid_tool)
    assert tool_def.name == "valid_tool"
    assert tool_def.description == "A valid tool function"

    # Verify parameter structure
    tool_schema = tool_def.parameters
    reference_schema = ValidToolParams.model_json_schema()

    # Check required fields
    assert tool_schema["required"] == reference_schema["required"]
    # Check param is string type
    assert tool_schema["properties"]["param"]["type"] == "string"

    # Define a function without type annotations - should work
    async def untyped_tool(param):
        """An untyped tool function"""
        return f"Result for {param}"

    # This should succeed with Any type in the schema
    untyped_tool_def = create_tool_definition(untyped_tool)
    assert untyped_tool_def.name == "untyped_tool"
    assert untyped_tool_def.description == "An untyped tool function"

    # Define a class with methods for testing different method types
    class ToolOwner:
        async def bound_method(self, param: str) -> str:
            """A bound instance method"""
            return f"Result for {param}"

        @staticmethod
        async def static_method(param: str) -> str:
            """A static method"""
            return f"Static result for {param}"

        @classmethod
        async def class_method(cls, param: str) -> str:
            """A class method"""
            return f"Class result for {param}"

    # Create an instance and get the bound method
    owner = ToolOwner()
    bound_method = owner.bound_method

    # Test bound instance methods
    bound_tool_def = create_tool_definition(bound_method)
    assert bound_tool_def.name == "bound_method"
    bound_schema = bound_tool_def.parameters
    assert bound_schema["properties"]["param"]["type"] == "string"

    # Test static methods
    static_tool_def = create_tool_definition(ToolOwner.static_method)
    assert static_tool_def.name == "static_method"
    static_schema = static_tool_def.parameters
    assert static_schema["properties"]["param"]["type"] == "string"

    # Test class methods
    class_tool_def = create_tool_definition(ToolOwner.class_method)
    assert class_tool_def.name == "class_method"
    class_schema = class_tool_def.parameters
    assert class_schema["properties"]["param"]["type"] == "string"


# Common models for file-based tests
class ReceiptItem(BaseModel):
    name: str = Field(description="Name of the item")
    price: float | None = Field(description="Price of the item", default=None)
    quantity: int | None = Field(description="Quantity of the item", default=None)


class ReceiptData(BaseModel):
    merchant_name: str = Field(description="Name of the merchant/store")
    date: str = Field(description="Date of the transaction")
    total_amount: float = Field(description="Total amount of the transaction")
    items: list[ReceiptItem] = Field(
        description="List of items purchased with prices if available"
    )
    payment_method: str | None = Field(
        description="Payment method if specified", default=None
    )
    receipt_number: str | None = Field(
        description="Receipt number if available", default=None
    )
    document_type: str | None = Field(
        description="Type of document (pdf or image)", default=None
    )


@pytest.fixture
def planar_files():
    """Create PlanarFile instances for testing."""
    image_file = PlanarFile(
        id=uuid4(),
        filename="receipt.jpg",
        content_type="image/jpeg",
        size=1024,
    )

    pdf_file = PlanarFile(
        id=uuid4(),
        filename="invoice.pdf",
        content_type="application/pdf",
        size=2048,
    )

    return {"image": image_file, "pdf": pdf_file}


async def test_agent_with_direct_planar_file(
    session: AsyncSession, mock_get_agent_config, mock_providers, planar_files
):
    """Test agent with a PlanarFile in a Pydantic input model."""
    openai_mock, anthropic_mock = mock_providers
    image_file = planar_files["image"]

    # Create an agent for receipt analysis
    receipt_agent = Agent(
        name="receipt_analyzer",
        system_prompt="You are an expert receipt analyzer.",
        user_prompt="Please analyze this receipt.",
        output_type=ReceiptData,
        input_type=PlanarFile,
        model=OpenAI.gpt_4_1,
    )

    # Define a workflow using the agent
    @workflow()
    async def receipt_analysis_workflow(file: PlanarFile):
        # Pass it as input_value
        result = await receipt_agent(input_value=file)
        return result.output

    # Start and execute the workflow
    wf = await receipt_analysis_workflow.start(image_file)
    result = await execute(wf)

    # Verify the result is the correct type
    assert isinstance(result, ReceiptData)

    # Verify the result structure
    assert result.merchant_name == "Coffee Shop"
    assert result.date == "2025-03-11"
    assert result.total_amount == 42.99
    assert result.document_type == "image"  # Should detect it's an image
    assert isinstance(result.items, list)
    assert len(result.items) == 2
    assert result.items[0].name == "Coffee"
    assert result.items[0].price == 4.99

    # Verify that provider's complete method was called once
    assert openai_mock.call_count == 1
    args, kwargs = openai_mock.call_args

    # Files are passed in the messages, not directly as planar_files parameter
    messages = kwargs.get("messages", [])
    user_messages = [m for m in messages if isinstance(m, UserMessage)]
    assert len(user_messages) == 1
    assert user_messages[0].files is not None
    assert len(user_messages[0].files) == 1
    assert user_messages[0].files[0] == image_file


class DocumentInput(BaseModel):
    """Model with a single PlanarFile field."""

    file: PlanarFile
    instructions: str | None = None


async def test_agent_with_planar_file_in_model(
    session: AsyncSession, mock_providers, planar_files
):
    """Test agent with a PlanarFile field in a Pydantic model."""
    openai_mock, anthropic_mock = mock_providers
    pdf_file = planar_files["pdf"]

    # Create an agent for document analysis
    document_agent = Agent(
        name="document_analyzer",
        system_prompt="You are an expert document analyzer. Extract all information from the document.",
        user_prompt="Please analyze this document. {{ input.instructions }}",
        output_type=ReceiptData,
        model=OpenAI.gpt_4_1,
        input_type=DocumentInput,
    )

    # Define a workflow using the agent
    @workflow()
    async def document_analysis_workflow(
        file: PlanarFile, instructions: str | None = None
    ):
        input_model = DocumentInput(file=file, instructions=instructions)
        result = await document_agent(input_value=input_model)
        return result.output

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=document_agent.to_config()),
    ):
        # Start and execute the workflow with instructions
        wf = await document_analysis_workflow.start(
            pdf_file, instructions="Focus on payment details"
        )
        result = await execute(wf)

        # Verify the result is the correct type
        assert isinstance(result, ReceiptData)

        # Verify the result structure
        assert result.merchant_name == "Coffee Shop"
        assert result.date == "2025-03-11"
        assert result.total_amount == 42.99
        assert result.document_type == "pdf"  # Should detect it's a PDF
        assert isinstance(result.items, list)
        assert len(result.items) == 2

        # Verify that provider's complete method was called once
        assert openai_mock.call_count == 1
        args, kwargs = openai_mock.call_args

        # Files are passed in the messages, not directly as planar_files parameter
        messages = kwargs.get("messages", [])
        user_messages = [m for m in messages if isinstance(m, UserMessage)]
        assert len(user_messages) == 1
        assert user_messages[0].files is not None
        assert len(user_messages[0].files) == 1
        assert user_messages[0].files[0] == pdf_file

        # Verify the user prompt includes the instructions
        messages = kwargs.get("messages", [])
        user_messages = [m for m in messages if isinstance(m, UserMessage)]
        assert len(user_messages) == 1
        assert user_messages[0].content is not None
        assert (
            user_messages[0].content
            == "Please analyze this document. Focus on payment details"
        )


class MultiFileInput(BaseModel):
    """Model with a list of PlanarFile field."""

    files: list[PlanarFile]
    batch_name: str


async def test_agent_with_planar_file_list(
    session: AsyncSession, mock_get_agent_config, mock_providers, planar_files
):
    """Test agent with a list of PlanarFile objects in a Pydantic model."""
    openai_mock, anthropic_mock = mock_providers
    image_file = planar_files["image"]
    pdf_file = planar_files["pdf"]

    # Create an agent for batch document analysis
    batch_agent = Agent(
        name="batch_analyzer",
        system_prompt="You are a batch document processor. Analyze all provided files.",
        user_prompt="Process batch: {{ input.batch_name }}",
        output_type=str,  # Just return a string description
        model=OpenAI.gpt_4_1,
        input_type=MultiFileInput,
    )

    # Define a workflow using the agent
    @workflow()
    async def batch_analysis_workflow(files: list[PlanarFile], batch_name: str):
        # Create a model instance with the file list
        input_model = MultiFileInput(files=files, batch_name=batch_name)
        # Call the agent with the model as input_value
        result = await batch_agent(input_value=input_model)
        return result.output

    with patch(
        "planar.ai.agent.get_agent_config",
        AsyncMock(return_value=batch_agent.to_config()),
    ):
        # Start and execute the workflow with multiple files
        wf = await batch_analysis_workflow.start(
            [image_file, pdf_file], batch_name="Receipt and Invoice"
        )
        result = await execute(wf)

        # Verify the result is a string
        assert isinstance(result, str)
        # Our mock may return either of these responses
        assert result in [
            "Description of the image content",
            "Description of the PDF document",
            "Mock LLM response",
        ]

        # Verify that provider's complete method was called once
        assert openai_mock.call_count == 1
        args, kwargs = openai_mock.call_args

        messages = kwargs.get("messages", [])
        user_messages = [m for m in messages if isinstance(m, UserMessage)]
        assert len(user_messages) == 1
        assert user_messages[0].files is not None
        assert len(user_messages[0].files) == 2
        assert image_file in user_messages[0].files
        assert pdf_file in user_messages[0].files

        # Verify the user prompt includes the batch name
        messages = kwargs.get("messages", [])
        user_messages = [m for m in messages if isinstance(m, UserMessage)]
        assert len(user_messages) == 1
        assert user_messages[0].content == "Process batch: Receipt and Invoice"


def test_extract_files_from_model():
    """Test that files are correctly extracted from Pydantic models."""
    image_file = PlanarFile(
        id=uuid4(),
        filename="test_image.jpg",
        content_type="image/jpeg",
        size=1024,
    )

    pdf_file = PlanarFile(
        id=uuid4(),
        filename="test_document.pdf",
        content_type="application/pdf",
        size=2048,
    )

    # Test model with PlanarFile directly
    files = extract_files_from_model(image_file)
    assert len(files) == 1
    assert files[0] == image_file

    # Test model with PlanarFile as field
    class ModelWithFile(BaseModel):
        name: str
        description: str
        file: PlanarFile
        other_data: int

    model_with_file = ModelWithFile(
        name="Test Model",
        description="A test model with a file",
        file=pdf_file,
        other_data=42,
    )

    files = extract_files_from_model(model_with_file)
    assert len(files) == 1
    assert files[0] == pdf_file

    # Test model with list of PlanarFile objects
    class ModelWithFileList(BaseModel):
        name: str
        files: list[PlanarFile]

    model_with_file_list = ModelWithFileList(
        name="Test Model with File List",
        files=[image_file, pdf_file],
    )

    files = extract_files_from_model(model_with_file_list)
    assert len(files) == 2
    assert image_file in files
    assert pdf_file in files

    # Test mixed list with non-PlanarFile items
    class ModelWithMixedList(BaseModel):
        name: str
        items: list

    model_with_mixed_list = ModelWithMixedList(
        name="Test Model with Mixed List",
        items=[image_file, "not a file", 123, pdf_file],
    )

    files = extract_files_from_model(model_with_mixed_list)
    assert len(files) == 2
    assert image_file in files
    assert pdf_file in files

    # Test model with no files
    class ModelWithoutFiles(BaseModel):
        name: str
        value: int

    model_without_files = ModelWithoutFiles(name="No Files", value=42)
    files = extract_files_from_model(model_without_files)
    assert len(files) == 0

    files = extract_files_from_model("test string")
    assert len(files) == 0

    # Test nested BaseModel structure with PlanarFile
    class NestedModel(BaseModel):
        description: str
        file: PlanarFile

    nested_model = NestedModel(
        description="A nested model with a file",
        file=image_file,
    )

    class AnotherNestedModel(BaseModel):
        data: str
        files: list[PlanarFile]

    class ComplexModel(BaseModel):
        name: str
        first_nested: NestedModel
        second_nested: AnotherNestedModel

    another_nested = AnotherNestedModel(
        data="Some data",
        files=[pdf_file],
    )

    complex_model = ComplexModel(
        name="Complex Model",
        first_nested=nested_model,
        second_nested=another_nested,
    )

    files = extract_files_from_model(complex_model)
    assert len(files) == 2
    assert image_file in files
    assert pdf_file in files


def test_tool_parameter_serialization():
    """Test that tool parameters are correctly serialized to JSON schema."""

    # Create a reference Pydantic model with various parameter types
    class ComplexToolParams(BaseModel):
        str_param: str
        int_param: int
        float_param: float
        bool_param: bool
        list_param: list[str]
        dict_param: dict[str, int]
        union_param: str | int
        optional_param: str | None = None
        untyped_param: Any = None

    # Define a function with various parameter types
    async def complex_tool(
        str_param: str,
        int_param: int,
        float_param: float,
        bool_param: bool,
        list_param: list[str],
        dict_param: dict[str, int],
        union_param: str | int,
        annotated_param: str = Field(description="This is an annotated parameter"),
        optional_param: str | None = None,
        complex_param: ComplexToolParams = Field(
            description="A complex parameter with various types"
        ),
        untyped_param=None,
    ) -> dict[str, Any]:
        """A tool with various parameter types"""
        return {"result": "success"}

    # Create tool definition
    tool_def = create_tool_definition(complex_tool)

    # Verify basic tool properties
    assert tool_def.name == "complex_tool"
    assert tool_def.description == "A tool with various parameter types"

    # Get schema from the tool parameters
    tool_schema = tool_def.parameters

    # Verify schema structure
    assert "properties" in tool_schema
    assert "required" in tool_schema

    # Verify parameter types are correctly mapped
    props = tool_schema["properties"]
    assert props["str_param"]["type"] == "string"
    assert props["int_param"]["type"] == "integer"
    assert props["float_param"]["type"] == "number"
    assert props["bool_param"]["type"] == "boolean"
    assert props["list_param"]["type"] == "array"
    assert props["list_param"]["items"]["type"] == "string"
    assert props["dict_param"]["type"] == "object"
    assert props["dict_param"]["additionalProperties"]["type"] == "integer"
    assert props["union_param"]["anyOf"][0]["type"] == "string"
    assert props["union_param"]["anyOf"][1]["type"] == "integer"
    assert props["annotated_param"]["type"] == "string"
    assert props["annotated_param"]["description"] == "This is an annotated parameter"
    assert props["complex_param"]["$ref"] == "#/$defs/ComplexToolParams"
    assert (
        tool_schema["$defs"]["ComplexToolParams"]
        == ComplexToolParams.model_json_schema()
    )

    # Verify required parameters
    required = tool_schema["required"]
    assert "str_param" in required
    assert "int_param" in required
    assert "float_param" in required
    assert "bool_param" in required
    assert "list_param" in required
    assert "dict_param" in required
    assert "union_param" in required
    assert "optional_param" not in required  # Has default value
    assert "untyped_param" not in required  # Has default value

    # Now we should be able to fully serialize the ToolDefinition
    parsed = tool_def.model_dump(mode="json")

    # Verify the JSON structure is valid
    assert "name" in parsed
    assert "parameters" in parsed
    assert isinstance(parsed["parameters"], dict)

    # Verify parameters were converted to JSON schema
    assert "properties" in parsed["parameters"]
    assert "required" in parsed["parameters"]
    assert parsed["name"] == "complex_tool"
    assert parsed["parameters"]["title"] == "Complex_toolParameters"