planar-0.5.0-py3-none-any.whl → planar-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +155 -283
  3. planar/ai/agent_base.py +170 -0
  4. planar/ai/agent_utils.py +7 -0
  5. planar/ai/pydantic_ai.py +638 -0
  6. planar/ai/test_agent_serialization.py +1 -1
  7. planar/app.py +64 -20
  8. planar/cli.py +39 -27
  9. planar/config.py +45 -36
  10. planar/db/db.py +2 -1
  11. planar/files/storage/azure_blob.py +343 -0
  12. planar/files/storage/base.py +7 -0
  13. planar/files/storage/config.py +70 -7
  14. planar/files/storage/s3.py +6 -6
  15. planar/files/storage/test_azure_blob.py +435 -0
  16. planar/logging/formatter.py +17 -4
  17. planar/logging/test_formatter.py +327 -0
  18. planar/registry_items.py +2 -1
  19. planar/routers/agents_router.py +3 -1
  20. planar/routers/files.py +11 -2
  21. planar/routers/models.py +14 -1
  22. planar/routers/test_agents_router.py +1 -1
  23. planar/routers/test_files_router.py +49 -0
  24. planar/routers/test_routes_security.py +5 -7
  25. planar/routers/test_workflow_router.py +270 -3
  26. planar/routers/workflow.py +95 -36
  27. planar/rules/models.py +36 -39
  28. planar/rules/test_data/account_dormancy_management.json +223 -0
  29. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  30. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  31. planar/rules/test_data/booking_fraud_detection.json +407 -0
  32. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  33. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  34. planar/rules/test_data/customer_lifetime_value.json +143 -0
  35. planar/rules/test_data/import_duties_calculator.json +289 -0
  36. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  37. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  38. planar/rules/test_data/order_consolidation_system.json +375 -0
  39. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  40. planar/rules/test_data/supply_chain_risk.json +253 -0
  41. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  42. planar/rules/test_rules.py +750 -6
  43. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  44. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  45. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  46. planar/security/auth_context.py +21 -0
  47. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  48. planar/security/authorization.py +9 -15
  49. planar/security/tests/test_auth_middleware.py +162 -0
  50. planar/sse/proxy.py +4 -9
  51. planar/test_app.py +92 -1
  52. planar/test_cli.py +81 -59
  53. planar/test_config.py +17 -14
  54. planar/testing/fixtures.py +325 -0
  55. planar/testing/planar_test_client.py +5 -2
  56. planar/utils.py +41 -1
  57. planar/workflows/execution.py +1 -1
  58. planar/workflows/orchestrator.py +5 -0
  59. planar/workflows/serialization.py +12 -6
  60. planar/workflows/step_core.py +3 -1
  61. planar/workflows/test_serialization.py +9 -1
  62. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/METADATA +30 -5
  63. planar-0.8.0.dist-info/RECORD +166 -0
  64. planar/.__init__.py.un~ +0 -0
  65. planar/._version.py.un~ +0 -0
  66. planar/.app.py.un~ +0 -0
  67. planar/.cli.py.un~ +0 -0
  68. planar/.config.py.un~ +0 -0
  69. planar/.context.py.un~ +0 -0
  70. planar/.db.py.un~ +0 -0
  71. planar/.di.py.un~ +0 -0
  72. planar/.engine.py.un~ +0 -0
  73. planar/.files.py.un~ +0 -0
  74. planar/.log_context.py.un~ +0 -0
  75. planar/.log_metadata.py.un~ +0 -0
  76. planar/.logging.py.un~ +0 -0
  77. planar/.object_registry.py.un~ +0 -0
  78. planar/.otel.py.un~ +0 -0
  79. planar/.server.py.un~ +0 -0
  80. planar/.session.py.un~ +0 -0
  81. planar/.sqlalchemy.py.un~ +0 -0
  82. planar/.task_local.py.un~ +0 -0
  83. planar/.test_app.py.un~ +0 -0
  84. planar/.test_config.py.un~ +0 -0
  85. planar/.test_object_config.py.un~ +0 -0
  86. planar/.test_sqlalchemy.py.un~ +0 -0
  87. planar/.test_utils.py.un~ +0 -0
  88. planar/.util.py.un~ +0 -0
  89. planar/.utils.py.un~ +0 -0
  90. planar/ai/.__init__.py.un~ +0 -0
  91. planar/ai/._models.py.un~ +0 -0
  92. planar/ai/.agent.py.un~ +0 -0
  93. planar/ai/.agent_utils.py.un~ +0 -0
  94. planar/ai/.events.py.un~ +0 -0
  95. planar/ai/.files.py.un~ +0 -0
  96. planar/ai/.models.py.un~ +0 -0
  97. planar/ai/.providers.py.un~ +0 -0
  98. planar/ai/.pydantic_ai.py.un~ +0 -0
  99. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  100. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  101. planar/ai/.step.py.un~ +0 -0
  102. planar/ai/.test_agent.py.un~ +0 -0
  103. planar/ai/.test_agent_serialization.py.un~ +0 -0
  104. planar/ai/.test_providers.py.un~ +0 -0
  105. planar/ai/.utils.py.un~ +0 -0
  106. planar/ai/providers.py +0 -1088
  107. planar/ai/test_agent.py +0 -1298
  108. planar/ai/test_providers.py +0 -463
  109. planar/db/.db.py.un~ +0 -0
  110. planar/files/.config.py.un~ +0 -0
  111. planar/files/.local.py.un~ +0 -0
  112. planar/files/.local_filesystem.py.un~ +0 -0
  113. planar/files/.model.py.un~ +0 -0
  114. planar/files/.models.py.un~ +0 -0
  115. planar/files/.s3.py.un~ +0 -0
  116. planar/files/.storage.py.un~ +0 -0
  117. planar/files/.test_files.py.un~ +0 -0
  118. planar/files/storage/.__init__.py.un~ +0 -0
  119. planar/files/storage/.base.py.un~ +0 -0
  120. planar/files/storage/.config.py.un~ +0 -0
  121. planar/files/storage/.context.py.un~ +0 -0
  122. planar/files/storage/.local_directory.py.un~ +0 -0
  123. planar/files/storage/.test_local_directory.py.un~ +0 -0
  124. planar/files/storage/.test_s3.py.un~ +0 -0
  125. planar/human/.human.py.un~ +0 -0
  126. planar/human/.test_human.py.un~ +0 -0
  127. planar/logging/.__init__.py.un~ +0 -0
  128. planar/logging/.attributes.py.un~ +0 -0
  129. planar/logging/.formatter.py.un~ +0 -0
  130. planar/logging/.logger.py.un~ +0 -0
  131. planar/logging/.otel.py.un~ +0 -0
  132. planar/logging/.tracer.py.un~ +0 -0
  133. planar/modeling/.mixin.py.un~ +0 -0
  134. planar/modeling/.storage.py.un~ +0 -0
  135. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  136. planar/object_config/.object_config.py.un~ +0 -0
  137. planar/routers/.__init__.py.un~ +0 -0
  138. planar/routers/.agents_router.py.un~ +0 -0
  139. planar/routers/.crud.py.un~ +0 -0
  140. planar/routers/.decision.py.un~ +0 -0
  141. planar/routers/.event.py.un~ +0 -0
  142. planar/routers/.file_attachment.py.un~ +0 -0
  143. planar/routers/.files.py.un~ +0 -0
  144. planar/routers/.files_router.py.un~ +0 -0
  145. planar/routers/.human.py.un~ +0 -0
  146. planar/routers/.info.py.un~ +0 -0
  147. planar/routers/.models.py.un~ +0 -0
  148. planar/routers/.object_config_router.py.un~ +0 -0
  149. planar/routers/.rule.py.un~ +0 -0
  150. planar/routers/.test_object_config_router.py.un~ +0 -0
  151. planar/routers/.test_workflow_router.py.un~ +0 -0
  152. planar/routers/.workflow.py.un~ +0 -0
  153. planar/rules/.decorator.py.un~ +0 -0
  154. planar/rules/.runner.py.un~ +0 -0
  155. planar/rules/.test_rules.py.un~ +0 -0
  156. planar/security/.jwt_middleware.py.un~ +0 -0
  157. planar/sse/.constants.py.un~ +0 -0
  158. planar/sse/.example.html.un~ +0 -0
  159. planar/sse/.hub.py.un~ +0 -0
  160. planar/sse/.model.py.un~ +0 -0
  161. planar/sse/.proxy.py.un~ +0 -0
  162. planar/testing/.client.py.un~ +0 -0
  163. planar/testing/.memory_storage.py.un~ +0 -0
  164. planar/testing/.planar_test_client.py.un~ +0 -0
  165. planar/testing/.predictable_tracer.py.un~ +0 -0
  166. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  167. planar/testing/.test_memory_storage.py.un~ +0 -0
  168. planar/testing/.workflow_observer.py.un~ +0 -0
  169. planar/workflows/.__init__.py.un~ +0 -0
  170. planar/workflows/.builtin_steps.py.un~ +0 -0
  171. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  172. planar/workflows/.context.py.un~ +0 -0
  173. planar/workflows/.contrib.py.un~ +0 -0
  174. planar/workflows/.decorators.py.un~ +0 -0
  175. planar/workflows/.durable_test.py.un~ +0 -0
  176. planar/workflows/.errors.py.un~ +0 -0
  177. planar/workflows/.events.py.un~ +0 -0
  178. planar/workflows/.exceptions.py.un~ +0 -0
  179. planar/workflows/.execution.py.un~ +0 -0
  180. planar/workflows/.human.py.un~ +0 -0
  181. planar/workflows/.lock.py.un~ +0 -0
  182. planar/workflows/.misc.py.un~ +0 -0
  183. planar/workflows/.model.py.un~ +0 -0
  184. planar/workflows/.models.py.un~ +0 -0
  185. planar/workflows/.notifications.py.un~ +0 -0
  186. planar/workflows/.orchestrator.py.un~ +0 -0
  187. planar/workflows/.runtime.py.un~ +0 -0
  188. planar/workflows/.serialization.py.un~ +0 -0
  189. planar/workflows/.step.py.un~ +0 -0
  190. planar/workflows/.step_core.py.un~ +0 -0
  191. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  192. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  193. planar/workflows/.test_concurrency.py.un~ +0 -0
  194. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  195. planar/workflows/.test_human.py.un~ +0 -0
  196. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  197. planar/workflows/.test_orchestrator.py.un~ +0 -0
  198. planar/workflows/.test_race_conditions.py.un~ +0 -0
  199. planar/workflows/.test_serialization.py.un~ +0 -0
  200. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  201. planar/workflows/.test_workflow.py.un~ +0 -0
  202. planar/workflows/.tracing.py.un~ +0 -0
  203. planar/workflows/.types.py.un~ +0 -0
  204. planar/workflows/.util.py.un~ +0 -0
  205. planar/workflows/.utils.py.un~ +0 -0
  206. planar/workflows/.workflow.py.un~ +0 -0
  207. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  208. planar/workflows/.wrappers.py.un~ +0 -0
  209. planar-0.5.0.dist-info/RECORD +0 -289
  210. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/WHEEL +0 -0
  211. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/entry_points.txt +0 -0
planar/ai/test_agent.py DELETED
@@ -1,1298 +0,0 @@
1
- from datetime import timedelta
2
- from typing import Any, cast
3
- from unittest.mock import AsyncMock, Mock, patch
4
- from uuid import uuid4
5
-
6
- import pytest
7
- from pydantic import BaseModel, Field
8
- from sqlmodel.ext.asyncio.session import AsyncSession
9
-
10
- from planar.ai import Agent
11
- from planar.ai.agent import (
12
- AgentRunResult,
13
- )
14
- from planar.ai.agent_utils import create_tool_definition, extract_files_from_model
15
- from planar.ai.models import (
16
- AgentConfig,
17
- AssistantMessage,
18
- CompletionResponse,
19
- SystemMessage,
20
- ToolCall,
21
- ToolMessage,
22
- UserMessage,
23
- )
24
- from planar.ai.providers import OpenAI
25
- from planar.app import PlanarApp
26
- from planar.config import sqlite_config
27
- from planar.files.models import PlanarFile
28
- from planar.testing.planar_test_client import PlanarTestClient
29
- from planar.workflows.decorators import workflow
30
- from planar.workflows.execution import execute
31
- from planar.workflows.models import Workflow
32
- from planar.workflows.step_core import Suspend, suspend
33
-
34
- app = PlanarApp(
35
- config=sqlite_config(":memory:"),
36
- title="Planar app for testing agents",
37
- description="Testing",
38
- )
39
-
40
-
41
- @pytest.fixture(name="app")
42
- def app_fixture():
43
- yield app
44
-
45
-
46
- # Test data and models (not test classes themselves)
47
- # Using different names to avoid pytest collection warnings
48
- class InputModel(BaseModel):
49
- text: str
50
- value: int
51
-
52
-
53
- class OutputModel(BaseModel):
54
- message: str
55
- score: int
56
-
57
-
58
- # Mock data for receipt analysis tests
59
- MOCK_RECEIPT_DATA = {
60
- "merchant_name": "Coffee Shop",
61
- "date": "2025-03-11",
62
- "total_amount": 42.99,
63
- "items": [
64
- {"name": "Coffee", "price": 4.99, "quantity": 2},
65
- {"name": "Pastry", "price": 3.99, "quantity": 1},
66
- ],
67
- "payment_method": "Credit Card",
68
- "receipt_number": "R-123456",
69
- }
70
-
71
-
72
- @pytest.fixture
73
- def mock_providers():
74
- """Mock both OpenAI and Anthropic providers to return test responses."""
75
-
76
- # Create a factory to produce provider mocks with consistent tracking
77
- def create_provider_mock():
78
- mock = Mock()
79
- mock.call_count = 0
80
- return mock
81
-
82
- # Create mocks for each provider
83
- provider_mocks = {
84
- "openai": create_provider_mock(),
85
- "anthropic": create_provider_mock(),
86
- }
87
-
88
- # Shared mock response generator
89
- async def generate_response(
90
- output_type=None, tools=None, planar_files=None, is_first_call=True
91
- ):
92
- """Generate appropriate mock responses based on request parameters"""
93
- # Tool-based multi-turn conversation
94
- if tools:
95
- if is_first_call:
96
- return CompletionResponse(
97
- content=None,
98
- tool_calls=[
99
- cast(
100
- ToolCall,
101
- {
102
- "id": "call_1",
103
- "name": "tool1",
104
- "arguments": {"param": "test_param"},
105
- },
106
- )
107
- ],
108
- )
109
- elif output_type == OutputModel:
110
- return CompletionResponse(
111
- content=OutputModel(message="Multi-turn response", score=90),
112
- tool_calls=None,
113
- )
114
- else:
115
- return CompletionResponse(
116
- content="Final tool response",
117
- tool_calls=None,
118
- )
119
-
120
- # Planar file processing
121
- elif planar_files:
122
- if output_type and issubclass(output_type, BaseModel):
123
- # If a specific output model is requested, return a predetermined mock instance
124
- if output_type == OutputModel:
125
- return CompletionResponse(
126
- content=OutputModel(message="Analyzed file content", score=98),
127
- tool_calls=None,
128
- )
129
- else:
130
- # Check file content type for different response types
131
- file_type = None
132
- if len(planar_files) > 0:
133
- file_type = planar_files[0].content_type
134
-
135
- # Generate mock response based on file type
136
- if file_type == "application/pdf":
137
- mock_data = {**MOCK_RECEIPT_DATA, "document_type": "pdf"}
138
- else: # Image types
139
- mock_data = {**MOCK_RECEIPT_DATA, "document_type": "image"}
140
-
141
- # Only include fields that exist in the model
142
- filtered_data = {
143
- k: v
144
- for k, v in mock_data.items()
145
- if k in output_type.model_fields
146
- }
147
-
148
- return CompletionResponse(
149
- content=output_type.model_validate(filtered_data),
150
- tool_calls=None,
151
- )
152
- else:
153
- file_type = planar_files[0].content_type if planar_files else None
154
- if file_type == "application/pdf":
155
- return CompletionResponse(
156
- content="Description of the PDF document",
157
- tool_calls=None,
158
- )
159
- else:
160
- return CompletionResponse(
161
- content="Description of the image content",
162
- tool_calls=None,
163
- )
164
-
165
- # Structured output (single turn)
166
- elif output_type == OutputModel:
167
- return CompletionResponse(
168
- content=OutputModel(message="Test", score=95),
169
- tool_calls=None,
170
- )
171
-
172
- # Default simple response
173
- else:
174
- return CompletionResponse(
175
- content="Mock LLM response",
176
- tool_calls=None,
177
- )
178
-
179
- # Create a factory function for patched provider methods
180
- def create_provider_patch(provider_key):
181
- """Create patched complete method for the specified provider"""
182
-
183
- async def patched_complete(*args, **kwargs):
184
- # Get the provider's mock
185
- mock = provider_mocks[provider_key]
186
-
187
- # Update call tracking
188
- mock.call_count += 1
189
- mock.call_args = (args, kwargs)
190
- mock.call_args_list.append(cast(Any, (args, kwargs)))
191
-
192
- messages = kwargs.get("messages", [])
193
- planar_files = None
194
- for msg in messages:
195
- if isinstance(msg, UserMessage) and msg.files:
196
- planar_files = msg.files
197
- break
198
-
199
- # Generate appropriate response
200
- return await generate_response(
201
- output_type=kwargs.get("output_type"),
202
- tools=kwargs.get("tools"),
203
- planar_files=planar_files,
204
- is_first_call=(mock.call_count == 1),
205
- )
206
-
207
- return patched_complete
208
-
209
- # Apply patches
210
- with (
211
- patch(
212
- "planar.ai.providers.OpenAIProvider.complete",
213
- create_provider_patch("openai"),
214
- ),
215
- patch(
216
- "planar.ai.providers.AnthropicProvider.complete",
217
- create_provider_patch("anthropic"),
218
- ),
219
- ):
220
- yield (provider_mocks["openai"], provider_mocks["anthropic"])
221
-
222
-
223
- DEFAULT_CONFIG = AgentConfig(
224
- system_prompt="Default system prompt",
225
- user_prompt="Default user prompt: {{ input }}",
226
- model="openai:gpt-4.1",
227
- max_turns=3,
228
- )
229
-
230
-
231
- @pytest.fixture
232
- def mock_get_agent_config():
233
- """Mock the get_agent_config function to return empty config by default."""
234
- mock = AsyncMock(return_value=DEFAULT_CONFIG)
235
- with patch("planar.ai.agent.get_agent_config", mock):
236
- yield mock
237
-
238
-
239
- def test_agent_initialization():
240
- """Test that the Agent class initializes with correct parameters."""
241
- agent = Agent(
242
- name="test_agent",
243
- system_prompt="Test system prompt: {{ param1 }}",
244
- user_prompt="Test user prompt: {{ param2 }}",
245
- model="test:model",
246
- max_turns=3,
247
- )
248
-
249
- # Verify initialization
250
- assert agent.name == "test_agent"
251
- assert agent.system_prompt == "Test system prompt: {{ param1 }}"
252
- assert agent.user_prompt == "Test user prompt: {{ param2 }}"
253
- assert agent.model == "test:model"
254
- assert agent.max_turns == 3
255
- assert agent.tools == []
256
- assert agent.input_type is None
257
- assert agent.output_type is None
258
- assert agent.model_parameters == {}
259
-
260
-
261
- async def test_agent_call_simple(session: AsyncSession, mock_providers):
262
- """Test that an agent can be called in a workflow for a simple string response."""
263
- openai_mock, anthropic_mock = mock_providers
264
-
265
- # Create an agent
266
- test_agent = Agent(
267
- name="test_agent",
268
- system_prompt="Process this request",
269
- user_prompt="Input: {{ input }}",
270
- model="openai:gpt-4.1", # Using a real provider name
271
- )
272
-
273
- # Define a workflow that uses the agent
274
- @workflow()
275
- async def test_workflow(input_text: str):
276
- result = await test_agent(input_value=input_text)
277
- assert isinstance(result, AgentRunResult)
278
- return result.output
279
-
280
- with patch(
281
- "planar.ai.agent.get_agent_config",
282
- AsyncMock(return_value=test_agent.to_config()),
283
- ) as mock_config:
284
- # Start and execute the workflow
285
- wf = await test_workflow.start("test input")
286
- result = await execute(wf)
287
-
288
- # Verify the result
289
- assert result == "Mock LLM response"
290
-
291
- # Verify the workflow completed successfully
292
- updated_wf = await session.get(Workflow, wf.id)
293
- assert updated_wf is not None
294
- assert updated_wf.result == "Mock LLM response"
295
-
296
- # Verify get_agent_config was called with the agent name
297
- assert mock_config.called
298
-
299
- # Verify complete was called with the formatted messages
300
- assert openai_mock.call_count == 1 # called once
301
- args, kwargs = openai_mock.call_args
302
- messages = kwargs.get("messages")
303
- assert any(
304
- isinstance(m, SystemMessage) and m.content == "Process this request"
305
- for m in messages
306
- )
307
- assert any(
308
- isinstance(m, UserMessage) and m.content == "Input: test input"
309
- for m in messages
310
- )
311
-
312
-
313
- async def test_prompt_injection_protection(session: AsyncSession, mock_providers):
314
- """Ensure unsafe template expressions raise an error before model call."""
315
- openai_mock, _ = mock_providers
316
-
317
- inj_agent = Agent(
318
- name="inj_agent",
319
- system_prompt="Hi",
320
- user_prompt="{{ input.__class__.__mro__[1] }}",
321
- )
322
-
323
- @workflow()
324
- async def inj_workflow(text: str):
325
- return await inj_agent(text)
326
-
327
- with patch(
328
- "planar.ai.agent.get_agent_config",
329
- AsyncMock(return_value=inj_agent.to_config()),
330
- ):
331
- wf = await inj_workflow.start("test")
332
- with pytest.raises(ValueError):
333
- await execute(wf)
334
-
335
- assert openai_mock.call_count == 0
336
-
337
-
338
- async def test_agent_with_structured_output(session: AsyncSession, mock_providers):
339
- """Test agent with structured output using a Pydantic model."""
340
- openai_mock, anthropic_mock = mock_providers
341
-
342
- # Create an agent with structured output
343
- test_agent = Agent(
344
- name="structured_agent",
345
- system_prompt="Provide structured analysis",
346
- user_prompt="Analyze: {{ input }}",
347
- output_type=OutputModel,
348
- model="openai:gpt-4.1",
349
- )
350
-
351
- @workflow()
352
- async def structured_workflow(input_text: str):
353
- result = await test_agent(input_value=input_text)
354
- await suspend(interval=timedelta(seconds=0.1))
355
- return {"message": result.output.message, "score": result.output.score}
356
-
357
- with patch(
358
- "planar.ai.agent.get_agent_config",
359
- AsyncMock(return_value=test_agent.to_config()),
360
- ):
361
- wf = await structured_workflow.start("test structured input")
362
- result = await execute(wf)
363
- assert isinstance(result, Suspend)
364
- result = await execute(wf)
365
-
366
- assert isinstance(result, dict)
367
- assert result["message"] == "Test"
368
- assert result["score"] == 95
369
-
370
- updated_wf = await session.get(Workflow, wf.id)
371
- assert updated_wf is not None
372
- assert updated_wf.result == {"message": "Test", "score": 95}
373
-
374
- # Verify the correct provider method was called with right params
375
- assert openai_mock.call_count == 1 # called once
376
- args, kwargs = openai_mock.call_args
377
- assert kwargs["output_type"] == OutputModel
378
- messages = kwargs["messages"]
379
- assert any(
380
- isinstance(m, SystemMessage) and m.content == "Provide structured analysis"
381
- for m in messages
382
- )
383
- assert any(
384
- isinstance(m, UserMessage) and m.content == "Analyze: test structured input"
385
- for m in messages
386
- )
387
-
388
-
389
- async def test_agent_with_input_validation(
390
- session: AsyncSession, mock_get_agent_config, mock_providers
391
- ):
392
- """Test agent with input validation using a Pydantic model."""
393
- openai_mock, anthropic_mock = mock_providers
394
-
395
- # Create an agent with input validation
396
- test_agent = Agent(
397
- name="validated_input_agent",
398
- system_prompt="Process validated input",
399
- user_prompt="Text: {{ input.text }}, Value: {{ input.value }}",
400
- input_type=InputModel,
401
- model="openai:gpt-4.1",
402
- )
403
-
404
- # Define a workflow that uses the agent
405
- @workflow()
406
- async def validation_workflow(input_text: str, input_value: int):
407
- result = await test_agent(
408
- input_value=InputModel(text=input_text, value=input_value)
409
- )
410
- return result.output
411
-
412
- # Start and execute the workflow
413
- wf = await validation_workflow.start("test input", 42)
414
- result = await execute(wf)
415
-
416
- # Verify the result
417
- assert result == "Mock LLM response"
418
-
419
- # Verify the agent validates input
420
- # Define a workflow missing the required 'value' parameter
421
- @workflow()
422
- async def invalid_workflow(input_text: str):
423
- # This call should raise a validation error at runtime
424
- # Ignore the type error to test validation
425
- return await test_agent(input_value=input_text) # type: ignore
426
-
427
- # Start the workflow - this doesn't execute the agent validation yet
428
- invalid_wf = await invalid_workflow.start("missing value")
429
-
430
- # Now actually execute the workflow, which should raise ValueError
431
- with pytest.raises(ValueError):
432
- await execute(invalid_wf)
433
-
434
-
435
- async def test_agent_with_tools(
436
- mock_providers,
437
- client: PlanarTestClient,
438
- app,
439
- ):
440
- """Test agent with tools for multi-turn conversations."""
441
- openai_mock, anthropic_mock = mock_providers
442
-
443
- # Define some tools
444
- async def tool1(param: str) -> str:
445
- """Test tool 1"""
446
- return f"Tool 1 result: {param}"
447
-
448
- async def tool2(num: int) -> int:
449
- """Test tool 2"""
450
- return num * 2
451
-
452
- # Create an agent with tools
453
- test_agent = Agent(
454
- name="tools_agent",
455
- system_prompt="Use tools to solve the problem",
456
- user_prompt="Problem: {{ input }}",
457
- tools=[tool1, tool2],
458
- output_type=OutputModel,
459
- max_turns=3,
460
- model="anthropic:claude-3-sonnet", # Test the Anthropic provider this time
461
- )
462
-
463
- # then register it with app
464
- app.register_agent(test_agent)
465
-
466
- # Define a workflow that uses the agent
467
- @workflow()
468
- async def tools_workflow(problem: str):
469
- result = await test_agent(input_value=problem)
470
- return {"message": result.output.message, "score": result.output.score}
471
-
472
- with patch(
473
- "planar.ai.agent.get_agent_config",
474
- AsyncMock(return_value=test_agent.to_config()),
475
- ):
476
- # Start and execute the workflow
477
- wf = await tools_workflow.start("complex problem")
478
- result = await execute(wf)
479
-
480
- # Verify the result
481
- assert isinstance(result, dict)
482
- assert result["message"] == "Multi-turn response"
483
- assert result["score"] == 90
484
-
485
- # Verify complete was called twice (once for tool call, once for final response)
486
- assert anthropic_mock.call_count == 2
487
-
488
- # First call should include tools
489
- args, first_call_kwargs = anthropic_mock.call_args_list[0]
490
- assert len(first_call_kwargs["tools"]) == 2
491
- assert first_call_kwargs["output_type"] == OutputModel
492
-
493
- response = await client.get(
494
- f"/planar/v1/workflows/{wf.function_name}/runs/{wf.id}/steps"
495
- )
496
- data = response.json()
497
-
498
- step = data["items"][0]
499
- assert step["step_id"] == 1
500
- assert step["function_name"] == "planar.ai.agent.Agent.run_step"
501
- assert step["display_name"] == test_agent.name
502
-
503
-
504
- async def test_config_override(session: AsyncSession, mock_providers):
505
- """Test that agent correctly applies configuration overrides."""
506
- openai_mock, anthropic_mock = mock_providers
507
-
508
- # Create a custom mock for agent_config with overrides
509
- override_config = AgentConfig(
510
- system_prompt="Overridden system prompt",
511
- user_prompt="Overridden user prompt: {{ input }}",
512
- model="anthropic:claude-3-opus", # Change from OpenAI to Anthropic
513
- max_turns=5,
514
- )
515
-
516
- # Create an agent with defaults that will be overridden
517
- test_agent = Agent(
518
- name="override_agent",
519
- system_prompt="Original system prompt",
520
- user_prompt="Original user prompt: {{ input }}",
521
- model="openai:gpt-4.1", # Start with OpenAI
522
- max_turns=1,
523
- )
524
-
525
- @workflow()
526
- async def override_workflow(input_text: str):
527
- result = await test_agent(input_text)
528
- return result.output
529
-
530
- with patch(
531
- "planar.ai.agent.get_agent_config",
532
- AsyncMock(return_value=override_config),
533
- ) as mock_config:
534
- wf = await override_workflow.start("override test")
535
- result = await execute(wf)
536
-
537
- # Verify the result
538
- assert result == "Mock LLM response"
539
-
540
- # Verify get_agent_config was called
541
- assert mock_config.called
542
-
543
- # Since we overrode to anthropic, that provider should be used
544
- assert anthropic_mock.call_count == 1 # called once
545
- assert openai_mock.call_count == 0 # not called
546
-
547
- # Verify the messages include the overridden prompts
548
- args, kwargs = anthropic_mock.call_args
549
- messages = kwargs["messages"]
550
- assert any(
551
- isinstance(m, SystemMessage) and m.content == "Overridden system prompt"
552
- for m in messages
553
- )
554
- assert any(
555
- isinstance(m, UserMessage)
556
- and m.content == "Overridden user prompt: override test"
557
- for m in messages
558
- )
559
-
560
-
561
- async def test_agent_with_model_parameters(session: AsyncSession, mock_providers):
562
- """Test that an agent can be configured with model parameters."""
563
- openai_mock, anthropic_mock = mock_providers
564
-
565
- # Create an agent with model parameters
566
- test_agent = Agent(
567
- name="params_agent",
568
- system_prompt="Test with parameters",
569
- user_prompt="Input: {{ input }}",
570
- model=OpenAI.gpt_4_1,
571
- model_parameters={"temperature": 0.2, "top_p": 0.95},
572
- )
573
-
574
- # Define a workflow that uses the agent
575
- @workflow()
576
- async def params_workflow(input_text: str):
577
- result = await test_agent(input_value=input_text)
578
- return result.output
579
-
580
- with patch(
581
- "planar.ai.agent.get_agent_config",
582
- AsyncMock(return_value=test_agent.to_config()),
583
- ):
584
- # Start and execute the workflow
585
- wf = await params_workflow.start("test input")
586
- result = await execute(wf)
587
-
588
- # Verify the result
589
- assert result == "Mock LLM response"
590
-
591
- # Check that model parameters were handled correctly
592
- # (in a real implementation, this would affect the call to the LLM provider)
593
- assert test_agent.model_parameters == {"temperature": 0.2, "top_p": 0.95}
594
-
595
- # Verify the model parameters are passed to the provider
596
- args, kwargs = openai_mock.call_args
597
- assert "temperature" in kwargs.get("model_spec").parameters
598
- assert kwargs.get("model_spec").parameters["temperature"] == 0.2
599
- assert kwargs.get("model_spec").parameters["top_p"] == 0.95
600
-
601
-
602
- async def test_tool_response_formatting(
603
- session: AsyncSession, mock_get_agent_config, mock_providers
604
- ):
605
- """Test that tool responses are correctly formatted in multi-turn conversations."""
606
- openai_mock, _ = mock_providers
607
-
608
- # Define a tool that returns a specific response - must match name in mock
609
- async def tool1(param: str) -> str:
610
- """Test tool with simple string return"""
611
- return f"Tool result for: {param}"
612
-
613
- # Create an agent with the tool
614
- test_agent = Agent(
615
- name="tool_response_agent",
616
- system_prompt="Use tools to process the query",
617
- user_prompt="Query: {{ input }}",
618
- tools=[tool1], # Name matches what the mock will call
619
- model="openai:gpt-4.1",
620
- max_turns=3,
621
- )
622
-
623
- # Define a workflow using the agent
624
- @workflow()
625
- async def tool_workflow(query: str):
626
- result = await test_agent(input_value=query)
627
- return result.output
628
-
629
- # Start and execute the workflow
630
- wf = await tool_workflow.start("test query")
631
- result = await execute(wf)
632
-
633
- # Verify result
634
- assert result == "Final tool response"
635
-
636
- # Verify complete was called twice
637
- assert openai_mock.call_count == 2
638
-
639
- # Extract the messages from the second call to check for proper tool response formatting
640
- args, second_call_kwargs = openai_mock.call_args_list[1]
641
- messages = second_call_kwargs.get("messages")
642
-
643
- # Check that there's a ToolMessage in the conversation
644
- tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
645
- assert len(tool_messages) == 1
646
-
647
- # Verify the content of the tool message matches our tool's output
648
- assert (
649
- tool_messages[0].content is not None
650
- and "Tool result for: test_param" in tool_messages[0].content
651
- )
652
-
653
- # Verify that the message was formatted using the format_tool_response method
654
- assert tool_messages[0].tool_call_id is not None
655
-
656
-
657
- async def test_structured_output_with_tools(
658
- session: AsyncSession, mock_get_agent_config, mock_providers
659
- ):
660
- """Test that structured output works correctly with tool calling."""
661
- openai_mock, anthropic_mock = mock_providers
662
-
663
- # Define a tool function - must match name in mock
664
- async def tool1(param: str) -> dict:
665
- """Fetch data for the given ID"""
666
- return {"id": param, "value": f"data-{param}"}
667
-
668
- # Create a test agent with structured output and tools
669
- test_agent = Agent(
670
- name="structured_tool_agent",
671
- system_prompt="Process the input and return structured data",
672
- user_prompt="Process: {{ input }}",
673
- tools=[tool1],
674
- output_type=OutputModel,
675
- model="openai:gpt-4.1",
676
- max_turns=3,
677
- )
678
-
679
- # Define workflow
680
- @workflow()
681
- async def structured_tool_workflow(data: str):
682
- result = await test_agent(input_value=data)
683
- return {"message": result.output.message, "score": result.output.score}
684
-
685
- # Start and execute the workflow
686
- wf = await structured_tool_workflow.start("test-data")
687
- result = await execute(wf)
688
-
689
- # Verify result structure
690
- assert isinstance(result, dict)
691
- assert result["message"] == "Multi-turn response"
692
- assert result["score"] == 90
693
-
694
- # Verify calls to complete
695
- assert openai_mock.call_count == 2
696
-
697
- # Check first call (should include tools and output_type)
698
- args, first_call_kwargs = openai_mock.call_args_list[0]
699
- assert first_call_kwargs["output_type"] == OutputModel
700
- assert len(first_call_kwargs["tools"]) == 1
701
-
702
- # Check second call after tool response
703
- args, second_call_kwargs = openai_mock.call_args_list[1]
704
- assert (
705
- second_call_kwargs["output_type"] == OutputModel
706
- ) # Should still request structured output
707
-
708
- # Verify messages in second call include the tool response
709
- messages = second_call_kwargs["messages"]
710
- assert any(isinstance(m, ToolMessage) for m in messages)
711
-
712
- # Verify assistant message with tool calls is included
713
- assistant_messages = [
714
- m for m in messages if isinstance(m, AssistantMessage) and m.tool_calls
715
- ]
716
- assert len(assistant_messages) == 1
717
-
718
-
719
- async def test_tool_error_catching(
720
- session: AsyncSession, mock_get_agent_config, mock_providers
721
- ):
722
- """Test that workflow can catch and handle errors from tool execution."""
723
- openai_mock, anthropic_mock = mock_providers
724
-
725
- # Define a tool that raises an exception - must match name in mock
726
- async def tool1(param: str) -> str:
727
- """This tool always fails"""
728
- raise ValueError(f"Tool error for: {param}")
729
-
730
- # Create an agent with the failing tool
731
- test_agent = Agent(
732
- name="error_handling_agent",
733
- system_prompt="Use tools to process this",
734
- user_prompt="Process: {{ input }}",
735
- tools=[tool1],
736
- model="openai:gpt-4.1",
737
- max_turns=3,
738
- )
739
-
740
- # Define a workflow that catches the error
741
- @workflow()
742
- async def error_handling_workflow(value: str):
743
- try:
744
- result = await test_agent(input_value=value)
745
- return {"status": "success", "output": result.output}
746
- except ValueError as e:
747
- # Workflow catches the error and returns a graceful response
748
- return {"status": "error", "message": str(e)}
749
-
750
- # Start and execute the workflow
751
- wf = await error_handling_workflow.start("test value")
752
- result = await execute(wf)
753
-
754
- # Verify the workflow caught the error
755
- assert isinstance(result, dict) # Make sure result is a dictionary before indexing
756
- assert result.get("status") == "error"
757
- assert "Tool error for:" in result.get("message", "")
758
-
759
- # Verify the API was called once to get the tool call
760
- assert openai_mock.call_count == 1
761
-
762
-
763
- def test_tool_validation():
764
- """Test that different types of functions are supported as tools."""
765
-
766
- # Create some simple Pydantic models for reference
767
- class ValidToolParams(BaseModel):
768
- param: str
769
-
770
- class UntypedToolParams(BaseModel):
771
- param: Any
772
-
773
- # Define a regular function - should work
774
- async def valid_tool(param: str) -> str:
775
- """A valid tool function"""
776
- return f"Result for {param}"
777
-
778
- # This should succeed (not a bound method)
779
- tool_def = create_tool_definition(valid_tool)
780
- assert tool_def.name == "valid_tool"
781
- assert tool_def.description == "A valid tool function"
782
-
783
- # Verify parameter structure
784
- tool_schema = tool_def.parameters
785
- reference_schema = ValidToolParams.model_json_schema()
786
-
787
- # Check required fields
788
- assert tool_schema["required"] == reference_schema["required"]
789
- # Check param is string type
790
- assert tool_schema["properties"]["param"]["type"] == "string"
791
-
792
- # Define a function without type annotations - should work
793
- async def untyped_tool(param):
794
- """An untyped tool function"""
795
- return f"Result for {param}"
796
-
797
- # This should succeed with Any type in the schema
798
- untyped_tool_def = create_tool_definition(untyped_tool)
799
- assert untyped_tool_def.name == "untyped_tool"
800
- assert untyped_tool_def.description == "An untyped tool function"
801
-
802
- # Define a class with methods for testing different method types
803
- class ToolOwner:
804
- async def bound_method(self, param: str) -> str:
805
- """A bound instance method"""
806
- return f"Result for {param}"
807
-
808
- @staticmethod
809
- async def static_method(param: str) -> str:
810
- """A static method"""
811
- return f"Static result for {param}"
812
-
813
- @classmethod
814
- async def class_method(cls, param: str) -> str:
815
- """A class method"""
816
- return f"Class result for {param}"
817
-
818
- # Create an instance and get the bound method
819
- owner = ToolOwner()
820
- bound_method = owner.bound_method
821
-
822
- # Test bound instance methods
823
- bound_tool_def = create_tool_definition(bound_method)
824
- assert bound_tool_def.name == "bound_method"
825
- bound_schema = bound_tool_def.parameters
826
- assert bound_schema["properties"]["param"]["type"] == "string"
827
-
828
- # Test static methods
829
- static_tool_def = create_tool_definition(ToolOwner.static_method)
830
- assert static_tool_def.name == "static_method"
831
- static_schema = static_tool_def.parameters
832
- assert static_schema["properties"]["param"]["type"] == "string"
833
-
834
- # Test class methods
835
- class_tool_def = create_tool_definition(ToolOwner.class_method)
836
- assert class_tool_def.name == "class_method"
837
- class_schema = class_tool_def.parameters
838
- assert class_schema["properties"]["param"]["type"] == "string"
839
-
840
-
841
- # Common models for file-based tests
842
- class ReceiptItem(BaseModel):
843
- name: str = Field(description="Name of the item")
844
- price: float | None = Field(description="Price of the item", default=None)
845
- quantity: int | None = Field(description="Quantity of the item", default=None)
846
-
847
-
848
- class ReceiptData(BaseModel):
849
- merchant_name: str = Field(description="Name of the merchant/store")
850
- date: str = Field(description="Date of the transaction")
851
- total_amount: float = Field(description="Total amount of the transaction")
852
- items: list[ReceiptItem] = Field(
853
- description="List of items purchased with prices if available"
854
- )
855
- payment_method: str | None = Field(
856
- description="Payment method if specified", default=None
857
- )
858
- receipt_number: str | None = Field(
859
- description="Receipt number if available", default=None
860
- )
861
- document_type: str | None = Field(
862
- description="Type of document (pdf or image)", default=None
863
- )
864
-
865
-
866
- @pytest.fixture
867
- def planar_files():
868
- """Create PlanarFile instances for testing."""
869
- image_file = PlanarFile(
870
- id=uuid4(),
871
- filename="receipt.jpg",
872
- content_type="image/jpeg",
873
- size=1024,
874
- )
875
-
876
- pdf_file = PlanarFile(
877
- id=uuid4(),
878
- filename="invoice.pdf",
879
- content_type="application/pdf",
880
- size=2048,
881
- )
882
-
883
- return {"image": image_file, "pdf": pdf_file}
884
-
885
-
886
- async def test_agent_with_direct_planar_file(
887
- session: AsyncSession, mock_get_agent_config, mock_providers, planar_files
888
- ):
889
- """Test agent with a PlanarFile in a Pydantic input model."""
890
- openai_mock, anthropic_mock = mock_providers
891
- image_file = planar_files["image"]
892
-
893
- # Create an agent for receipt analysis
894
- receipt_agent = Agent(
895
- name="receipt_analyzer",
896
- system_prompt="You are an expert receipt analyzer.",
897
- user_prompt="Please analyze this receipt.",
898
- output_type=ReceiptData,
899
- input_type=PlanarFile,
900
- model=OpenAI.gpt_4_1,
901
- )
902
-
903
- # Define a workflow using the agent
904
- @workflow()
905
- async def receipt_analysis_workflow(file: PlanarFile):
906
- # Pass it as input_value
907
- result = await receipt_agent(input_value=file)
908
- return result.output
909
-
910
- # Start and execute the workflow
911
- wf = await receipt_analysis_workflow.start(image_file)
912
- result = await execute(wf)
913
-
914
- # Verify the result is the correct type
915
- assert isinstance(result, ReceiptData)
916
-
917
- # Verify the result structure
918
- assert result.merchant_name == "Coffee Shop"
919
- assert result.date == "2025-03-11"
920
- assert result.total_amount == 42.99
921
- assert result.document_type == "image" # Should detect it's an image
922
- assert isinstance(result.items, list)
923
- assert len(result.items) == 2
924
- assert result.items[0].name == "Coffee"
925
- assert result.items[0].price == 4.99
926
-
927
- # Verify that provider's complete method was called once
928
- assert openai_mock.call_count == 1
929
- args, kwargs = openai_mock.call_args
930
-
931
- # Files are passed in the messages, not directly as planar_files parameter
932
- messages = kwargs.get("messages", [])
933
- user_messages = [m for m in messages if isinstance(m, UserMessage)]
934
- assert len(user_messages) == 1
935
- assert user_messages[0].files is not None
936
- assert len(user_messages[0].files) == 1
937
- assert user_messages[0].files[0] == image_file
938
-
939
-
940
- class DocumentInput(BaseModel):
941
- """Model with a single PlanarFile field."""
942
-
943
- file: PlanarFile
944
- instructions: str | None = None
945
-
946
-
947
- async def test_agent_with_planar_file_in_model(
948
- session: AsyncSession, mock_providers, planar_files
949
- ):
950
- """Test agent with a PlanarFile field in a Pydantic model."""
951
- openai_mock, anthropic_mock = mock_providers
952
- pdf_file = planar_files["pdf"]
953
-
954
- # Create an agent for document analysis
955
- document_agent = Agent(
956
- name="document_analyzer",
957
- system_prompt="You are an expert document analyzer. Extract all information from the document.",
958
- user_prompt="Please analyze this document. {{ input.instructions }}",
959
- output_type=ReceiptData,
960
- model=OpenAI.gpt_4_1,
961
- input_type=DocumentInput,
962
- )
963
-
964
- # Define a workflow using the agent
965
- @workflow()
966
- async def document_analysis_workflow(
967
- file: PlanarFile, instructions: str | None = None
968
- ):
969
- input_model = DocumentInput(file=file, instructions=instructions)
970
- result = await document_agent(input_value=input_model)
971
- return result.output
972
-
973
- with patch(
974
- "planar.ai.agent.get_agent_config",
975
- AsyncMock(return_value=document_agent.to_config()),
976
- ):
977
- # Start and execute the workflow with instructions
978
- wf = await document_analysis_workflow.start(
979
- pdf_file, instructions="Focus on payment details"
980
- )
981
- result = await execute(wf)
982
-
983
- # Verify the result is the correct type
984
- assert isinstance(result, ReceiptData)
985
-
986
- # Verify the result structure
987
- assert result.merchant_name == "Coffee Shop"
988
- assert result.date == "2025-03-11"
989
- assert result.total_amount == 42.99
990
- assert result.document_type == "pdf" # Should detect it's a PDF
991
- assert isinstance(result.items, list)
992
- assert len(result.items) == 2
993
-
994
- # Verify that provider's complete method was called once
995
- assert openai_mock.call_count == 1
996
- args, kwargs = openai_mock.call_args
997
-
998
- # Files are passed in the messages, not directly as planar_files parameter
999
- messages = kwargs.get("messages", [])
1000
- user_messages = [m for m in messages if isinstance(m, UserMessage)]
1001
- assert len(user_messages) == 1
1002
- assert user_messages[0].files is not None
1003
- assert len(user_messages[0].files) == 1
1004
- assert user_messages[0].files[0] == pdf_file
1005
-
1006
- # Verify the user prompt includes the instructions
1007
- messages = kwargs.get("messages", [])
1008
- user_messages = [m for m in messages if isinstance(m, UserMessage)]
1009
- assert len(user_messages) == 1
1010
- assert user_messages[0].content is not None
1011
- assert (
1012
- user_messages[0].content
1013
- == "Please analyze this document. Focus on payment details"
1014
- )
1015
-
1016
-
1017
- class MultiFileInput(BaseModel):
1018
- """Model with a list of PlanarFile field."""
1019
-
1020
- files: list[PlanarFile]
1021
- batch_name: str
1022
-
1023
-
1024
- async def test_agent_with_planar_file_list(
1025
- session: AsyncSession, mock_get_agent_config, mock_providers, planar_files
1026
- ):
1027
- """Test agent with a list of PlanarFile objects in a Pydantic model."""
1028
- openai_mock, anthropic_mock = mock_providers
1029
- image_file = planar_files["image"]
1030
- pdf_file = planar_files["pdf"]
1031
-
1032
- # Create an agent for batch document analysis
1033
- batch_agent = Agent(
1034
- name="batch_analyzer",
1035
- system_prompt="You are a batch document processor. Analyze all provided files.",
1036
- user_prompt="Process batch: {{ input.batch_name }}",
1037
- output_type=str, # Just return a string description
1038
- model=OpenAI.gpt_4_1,
1039
- input_type=MultiFileInput,
1040
- )
1041
-
1042
- # Define a workflow using the agent
1043
- @workflow()
1044
- async def batch_analysis_workflow(files: list[PlanarFile], batch_name: str):
1045
- # Create a model instance with the file list
1046
- input_model = MultiFileInput(files=files, batch_name=batch_name)
1047
- # Call the agent with the model as input_value
1048
- result = await batch_agent(input_value=input_model)
1049
- return result.output
1050
-
1051
- with patch(
1052
- "planar.ai.agent.get_agent_config",
1053
- AsyncMock(return_value=batch_agent.to_config()),
1054
- ):
1055
- # Start and execute the workflow with multiple files
1056
- wf = await batch_analysis_workflow.start(
1057
- [image_file, pdf_file], batch_name="Receipt and Invoice"
1058
- )
1059
- result = await execute(wf)
1060
-
1061
- # Verify the result is a string
1062
- assert isinstance(result, str)
1063
- # Our mock may return either of these responses
1064
- assert result in [
1065
- "Description of the image content",
1066
- "Description of the PDF document",
1067
- "Mock LLM response",
1068
- ]
1069
-
1070
- # Verify that provider's complete method was called once
1071
- assert openai_mock.call_count == 1
1072
- args, kwargs = openai_mock.call_args
1073
-
1074
- messages = kwargs.get("messages", [])
1075
- user_messages = [m for m in messages if isinstance(m, UserMessage)]
1076
- assert len(user_messages) == 1
1077
- assert user_messages[0].files is not None
1078
- assert len(user_messages[0].files) == 2
1079
- assert image_file in user_messages[0].files
1080
- assert pdf_file in user_messages[0].files
1081
-
1082
- # Verify the user prompt includes the batch name
1083
- messages = kwargs.get("messages", [])
1084
- user_messages = [m for m in messages if isinstance(m, UserMessage)]
1085
- assert len(user_messages) == 1
1086
- assert user_messages[0].content == "Process batch: Receipt and Invoice"
1087
-
1088
-
1089
- def test_extract_files_from_model():
1090
- """Test that files are correctly extracted from Pydantic models."""
1091
- image_file = PlanarFile(
1092
- id=uuid4(),
1093
- filename="test_image.jpg",
1094
- content_type="image/jpeg",
1095
- size=1024,
1096
- )
1097
-
1098
- pdf_file = PlanarFile(
1099
- id=uuid4(),
1100
- filename="test_document.pdf",
1101
- content_type="application/pdf",
1102
- size=2048,
1103
- )
1104
-
1105
- # Test model with PlanarFile directly
1106
- files = extract_files_from_model(image_file)
1107
- assert len(files) == 1
1108
- assert files[0] == image_file
1109
-
1110
- # Test model with PlanarFile as field
1111
- class ModelWithFile(BaseModel):
1112
- name: str
1113
- description: str
1114
- file: PlanarFile
1115
- other_data: int
1116
-
1117
- model_with_file = ModelWithFile(
1118
- name="Test Model",
1119
- description="A test model with a file",
1120
- file=pdf_file,
1121
- other_data=42,
1122
- )
1123
-
1124
- files = extract_files_from_model(model_with_file)
1125
- assert len(files) == 1
1126
- assert files[0] == pdf_file
1127
-
1128
- # Test model with list of PlanarFile objects
1129
- class ModelWithFileList(BaseModel):
1130
- name: str
1131
- files: list[PlanarFile]
1132
-
1133
- model_with_file_list = ModelWithFileList(
1134
- name="Test Model with File List",
1135
- files=[image_file, pdf_file],
1136
- )
1137
-
1138
- files = extract_files_from_model(model_with_file_list)
1139
- assert len(files) == 2
1140
- assert image_file in files
1141
- assert pdf_file in files
1142
-
1143
- # Test mixed list with non-PlanarFile items
1144
- class ModelWithMixedList(BaseModel):
1145
- name: str
1146
- items: list
1147
-
1148
- model_with_mixed_list = ModelWithMixedList(
1149
- name="Test Model with Mixed List",
1150
- items=[image_file, "not a file", 123, pdf_file],
1151
- )
1152
-
1153
- files = extract_files_from_model(model_with_mixed_list)
1154
- assert len(files) == 2
1155
- assert image_file in files
1156
- assert pdf_file in files
1157
-
1158
- # Test model with no files
1159
- class ModelWithoutFiles(BaseModel):
1160
- name: str
1161
- value: int
1162
-
1163
- model_without_files = ModelWithoutFiles(name="No Files", value=42)
1164
- files = extract_files_from_model(model_without_files)
1165
- assert len(files) == 0
1166
-
1167
- files = extract_files_from_model("test string")
1168
- assert len(files) == 0
1169
-
1170
- # Test nested BaseModel structure with PlanarFile
1171
- class NestedModel(BaseModel):
1172
- description: str
1173
- file: PlanarFile
1174
-
1175
- nested_model = NestedModel(
1176
- description="A nested model with a file",
1177
- file=image_file,
1178
- )
1179
-
1180
- class AnotherNestedModel(BaseModel):
1181
- data: str
1182
- files: list[PlanarFile]
1183
-
1184
- class ComplexModel(BaseModel):
1185
- name: str
1186
- first_nested: NestedModel
1187
- second_nested: AnotherNestedModel
1188
-
1189
- another_nested = AnotherNestedModel(
1190
- data="Some data",
1191
- files=[pdf_file],
1192
- )
1193
-
1194
- complex_model = ComplexModel(
1195
- name="Complex Model",
1196
- first_nested=nested_model,
1197
- second_nested=another_nested,
1198
- )
1199
-
1200
- files = extract_files_from_model(complex_model)
1201
- assert len(files) == 2
1202
- assert image_file in files
1203
- assert pdf_file in files
1204
-
1205
-
1206
- def test_tool_parameter_serialization():
1207
- """Test that tool parameters are correctly serialized to JSON schema."""
1208
-
1209
- # Create a reference Pydantic model with various parameter types
1210
- class ComplexToolParams(BaseModel):
1211
- str_param: str
1212
- int_param: int
1213
- float_param: float
1214
- bool_param: bool
1215
- list_param: list[str]
1216
- dict_param: dict[str, int]
1217
- union_param: str | int
1218
- optional_param: str | None = None
1219
- untyped_param: Any = None
1220
-
1221
- # Define a function with various parameter types
1222
- async def complex_tool(
1223
- str_param: str,
1224
- int_param: int,
1225
- float_param: float,
1226
- bool_param: bool,
1227
- list_param: list[str],
1228
- dict_param: dict[str, int],
1229
- union_param: str | int,
1230
- annotated_param: str = Field(description="This is an annotated parameter"),
1231
- optional_param: str | None = None,
1232
- complex_param: ComplexToolParams = Field(
1233
- description="A complex parameter with various types"
1234
- ),
1235
- untyped_param=None,
1236
- ) -> dict[str, Any]:
1237
- """A tool with various parameter types"""
1238
- return {"result": "success"}
1239
-
1240
- # Create tool definition
1241
- tool_def = create_tool_definition(complex_tool)
1242
-
1243
- # Verify basic tool properties
1244
- assert tool_def.name == "complex_tool"
1245
- assert tool_def.description == "A tool with various parameter types"
1246
-
1247
- # Get schema from the tool parameters
1248
- tool_schema = tool_def.parameters
1249
-
1250
- # Verify schema structure
1251
- assert "properties" in tool_schema
1252
- assert "required" in tool_schema
1253
-
1254
- # Verify parameter types are correctly mapped
1255
- props = tool_schema["properties"]
1256
- assert props["str_param"]["type"] == "string"
1257
- assert props["int_param"]["type"] == "integer"
1258
- assert props["float_param"]["type"] == "number"
1259
- assert props["bool_param"]["type"] == "boolean"
1260
- assert props["list_param"]["type"] == "array"
1261
- assert props["list_param"]["items"]["type"] == "string"
1262
- assert props["dict_param"]["type"] == "object"
1263
- assert props["dict_param"]["additionalProperties"]["type"] == "integer"
1264
- assert props["union_param"]["anyOf"][0]["type"] == "string"
1265
- assert props["union_param"]["anyOf"][1]["type"] == "integer"
1266
- assert props["annotated_param"]["type"] == "string"
1267
- assert props["annotated_param"]["description"] == "This is an annotated parameter"
1268
- assert props["complex_param"]["$ref"] == "#/$defs/ComplexToolParams"
1269
- assert (
1270
- tool_schema["$defs"]["ComplexToolParams"]
1271
- == ComplexToolParams.model_json_schema()
1272
- )
1273
-
1274
- # Verify required parameters
1275
- required = tool_schema["required"]
1276
- assert "str_param" in required
1277
- assert "int_param" in required
1278
- assert "float_param" in required
1279
- assert "bool_param" in required
1280
- assert "list_param" in required
1281
- assert "dict_param" in required
1282
- assert "union_param" in required
1283
- assert "optional_param" not in required # Has default value
1284
- assert "untyped_param" not in required # Has default value
1285
-
1286
- # Now we should be able to fully serialize the ToolDefinition
1287
- parsed = tool_def.model_dump(mode="json")
1288
-
1289
- # Verify the JSON structure is valid
1290
- assert "name" in parsed
1291
- assert "parameters" in parsed
1292
- assert isinstance(parsed["parameters"], dict)
1293
-
1294
- # Verify parameters were converted to JSON schema
1295
- assert "properties" in parsed["parameters"]
1296
- assert "required" in parsed["parameters"]
1297
- assert parsed["name"] == "complex_tool"
1298
- assert parsed["parameters"]["title"] == "Complex_toolParameters"