planar-0.5.0-py3-none-any.whl → planar-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +155 -283
  3. planar/ai/agent_base.py +170 -0
  4. planar/ai/agent_utils.py +7 -0
  5. planar/ai/pydantic_ai.py +638 -0
  6. planar/ai/test_agent_serialization.py +1 -1
  7. planar/app.py +64 -20
  8. planar/cli.py +39 -27
  9. planar/config.py +45 -36
  10. planar/db/db.py +2 -1
  11. planar/files/storage/azure_blob.py +343 -0
  12. planar/files/storage/base.py +7 -0
  13. planar/files/storage/config.py +70 -7
  14. planar/files/storage/s3.py +6 -6
  15. planar/files/storage/test_azure_blob.py +435 -0
  16. planar/logging/formatter.py +17 -4
  17. planar/logging/test_formatter.py +327 -0
  18. planar/registry_items.py +2 -1
  19. planar/routers/agents_router.py +3 -1
  20. planar/routers/files.py +11 -2
  21. planar/routers/models.py +14 -1
  22. planar/routers/test_agents_router.py +1 -1
  23. planar/routers/test_files_router.py +49 -0
  24. planar/routers/test_routes_security.py +5 -7
  25. planar/routers/test_workflow_router.py +270 -3
  26. planar/routers/workflow.py +95 -36
  27. planar/rules/models.py +36 -39
  28. planar/rules/test_data/account_dormancy_management.json +223 -0
  29. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  30. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  31. planar/rules/test_data/booking_fraud_detection.json +407 -0
  32. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  33. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  34. planar/rules/test_data/customer_lifetime_value.json +143 -0
  35. planar/rules/test_data/import_duties_calculator.json +289 -0
  36. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  37. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  38. planar/rules/test_data/order_consolidation_system.json +375 -0
  39. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  40. planar/rules/test_data/supply_chain_risk.json +253 -0
  41. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  42. planar/rules/test_rules.py +750 -6
  43. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  44. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  45. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  46. planar/security/auth_context.py +21 -0
  47. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  48. planar/security/authorization.py +9 -15
  49. planar/security/tests/test_auth_middleware.py +162 -0
  50. planar/sse/proxy.py +4 -9
  51. planar/test_app.py +92 -1
  52. planar/test_cli.py +81 -59
  53. planar/test_config.py +17 -14
  54. planar/testing/fixtures.py +325 -0
  55. planar/testing/planar_test_client.py +5 -2
  56. planar/utils.py +41 -1
  57. planar/workflows/execution.py +1 -1
  58. planar/workflows/orchestrator.py +5 -0
  59. planar/workflows/serialization.py +12 -6
  60. planar/workflows/step_core.py +3 -1
  61. planar/workflows/test_serialization.py +9 -1
  62. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/METADATA +30 -5
  63. planar-0.8.0.dist-info/RECORD +166 -0
  64. planar/.__init__.py.un~ +0 -0
  65. planar/._version.py.un~ +0 -0
  66. planar/.app.py.un~ +0 -0
  67. planar/.cli.py.un~ +0 -0
  68. planar/.config.py.un~ +0 -0
  69. planar/.context.py.un~ +0 -0
  70. planar/.db.py.un~ +0 -0
  71. planar/.di.py.un~ +0 -0
  72. planar/.engine.py.un~ +0 -0
  73. planar/.files.py.un~ +0 -0
  74. planar/.log_context.py.un~ +0 -0
  75. planar/.log_metadata.py.un~ +0 -0
  76. planar/.logging.py.un~ +0 -0
  77. planar/.object_registry.py.un~ +0 -0
  78. planar/.otel.py.un~ +0 -0
  79. planar/.server.py.un~ +0 -0
  80. planar/.session.py.un~ +0 -0
  81. planar/.sqlalchemy.py.un~ +0 -0
  82. planar/.task_local.py.un~ +0 -0
  83. planar/.test_app.py.un~ +0 -0
  84. planar/.test_config.py.un~ +0 -0
  85. planar/.test_object_config.py.un~ +0 -0
  86. planar/.test_sqlalchemy.py.un~ +0 -0
  87. planar/.test_utils.py.un~ +0 -0
  88. planar/.util.py.un~ +0 -0
  89. planar/.utils.py.un~ +0 -0
  90. planar/ai/.__init__.py.un~ +0 -0
  91. planar/ai/._models.py.un~ +0 -0
  92. planar/ai/.agent.py.un~ +0 -0
  93. planar/ai/.agent_utils.py.un~ +0 -0
  94. planar/ai/.events.py.un~ +0 -0
  95. planar/ai/.files.py.un~ +0 -0
  96. planar/ai/.models.py.un~ +0 -0
  97. planar/ai/.providers.py.un~ +0 -0
  98. planar/ai/.pydantic_ai.py.un~ +0 -0
  99. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  100. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  101. planar/ai/.step.py.un~ +0 -0
  102. planar/ai/.test_agent.py.un~ +0 -0
  103. planar/ai/.test_agent_serialization.py.un~ +0 -0
  104. planar/ai/.test_providers.py.un~ +0 -0
  105. planar/ai/.utils.py.un~ +0 -0
  106. planar/ai/providers.py +0 -1088
  107. planar/ai/test_agent.py +0 -1298
  108. planar/ai/test_providers.py +0 -463
  109. planar/db/.db.py.un~ +0 -0
  110. planar/files/.config.py.un~ +0 -0
  111. planar/files/.local.py.un~ +0 -0
  112. planar/files/.local_filesystem.py.un~ +0 -0
  113. planar/files/.model.py.un~ +0 -0
  114. planar/files/.models.py.un~ +0 -0
  115. planar/files/.s3.py.un~ +0 -0
  116. planar/files/.storage.py.un~ +0 -0
  117. planar/files/.test_files.py.un~ +0 -0
  118. planar/files/storage/.__init__.py.un~ +0 -0
  119. planar/files/storage/.base.py.un~ +0 -0
  120. planar/files/storage/.config.py.un~ +0 -0
  121. planar/files/storage/.context.py.un~ +0 -0
  122. planar/files/storage/.local_directory.py.un~ +0 -0
  123. planar/files/storage/.test_local_directory.py.un~ +0 -0
  124. planar/files/storage/.test_s3.py.un~ +0 -0
  125. planar/human/.human.py.un~ +0 -0
  126. planar/human/.test_human.py.un~ +0 -0
  127. planar/logging/.__init__.py.un~ +0 -0
  128. planar/logging/.attributes.py.un~ +0 -0
  129. planar/logging/.formatter.py.un~ +0 -0
  130. planar/logging/.logger.py.un~ +0 -0
  131. planar/logging/.otel.py.un~ +0 -0
  132. planar/logging/.tracer.py.un~ +0 -0
  133. planar/modeling/.mixin.py.un~ +0 -0
  134. planar/modeling/.storage.py.un~ +0 -0
  135. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  136. planar/object_config/.object_config.py.un~ +0 -0
  137. planar/routers/.__init__.py.un~ +0 -0
  138. planar/routers/.agents_router.py.un~ +0 -0
  139. planar/routers/.crud.py.un~ +0 -0
  140. planar/routers/.decision.py.un~ +0 -0
  141. planar/routers/.event.py.un~ +0 -0
  142. planar/routers/.file_attachment.py.un~ +0 -0
  143. planar/routers/.files.py.un~ +0 -0
  144. planar/routers/.files_router.py.un~ +0 -0
  145. planar/routers/.human.py.un~ +0 -0
  146. planar/routers/.info.py.un~ +0 -0
  147. planar/routers/.models.py.un~ +0 -0
  148. planar/routers/.object_config_router.py.un~ +0 -0
  149. planar/routers/.rule.py.un~ +0 -0
  150. planar/routers/.test_object_config_router.py.un~ +0 -0
  151. planar/routers/.test_workflow_router.py.un~ +0 -0
  152. planar/routers/.workflow.py.un~ +0 -0
  153. planar/rules/.decorator.py.un~ +0 -0
  154. planar/rules/.runner.py.un~ +0 -0
  155. planar/rules/.test_rules.py.un~ +0 -0
  156. planar/security/.jwt_middleware.py.un~ +0 -0
  157. planar/sse/.constants.py.un~ +0 -0
  158. planar/sse/.example.html.un~ +0 -0
  159. planar/sse/.hub.py.un~ +0 -0
  160. planar/sse/.model.py.un~ +0 -0
  161. planar/sse/.proxy.py.un~ +0 -0
  162. planar/testing/.client.py.un~ +0 -0
  163. planar/testing/.memory_storage.py.un~ +0 -0
  164. planar/testing/.planar_test_client.py.un~ +0 -0
  165. planar/testing/.predictable_tracer.py.un~ +0 -0
  166. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  167. planar/testing/.test_memory_storage.py.un~ +0 -0
  168. planar/testing/.workflow_observer.py.un~ +0 -0
  169. planar/workflows/.__init__.py.un~ +0 -0
  170. planar/workflows/.builtin_steps.py.un~ +0 -0
  171. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  172. planar/workflows/.context.py.un~ +0 -0
  173. planar/workflows/.contrib.py.un~ +0 -0
  174. planar/workflows/.decorators.py.un~ +0 -0
  175. planar/workflows/.durable_test.py.un~ +0 -0
  176. planar/workflows/.errors.py.un~ +0 -0
  177. planar/workflows/.events.py.un~ +0 -0
  178. planar/workflows/.exceptions.py.un~ +0 -0
  179. planar/workflows/.execution.py.un~ +0 -0
  180. planar/workflows/.human.py.un~ +0 -0
  181. planar/workflows/.lock.py.un~ +0 -0
  182. planar/workflows/.misc.py.un~ +0 -0
  183. planar/workflows/.model.py.un~ +0 -0
  184. planar/workflows/.models.py.un~ +0 -0
  185. planar/workflows/.notifications.py.un~ +0 -0
  186. planar/workflows/.orchestrator.py.un~ +0 -0
  187. planar/workflows/.runtime.py.un~ +0 -0
  188. planar/workflows/.serialization.py.un~ +0 -0
  189. planar/workflows/.step.py.un~ +0 -0
  190. planar/workflows/.step_core.py.un~ +0 -0
  191. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  192. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  193. planar/workflows/.test_concurrency.py.un~ +0 -0
  194. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  195. planar/workflows/.test_human.py.un~ +0 -0
  196. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  197. planar/workflows/.test_orchestrator.py.un~ +0 -0
  198. planar/workflows/.test_race_conditions.py.un~ +0 -0
  199. planar/workflows/.test_serialization.py.un~ +0 -0
  200. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  201. planar/workflows/.test_workflow.py.un~ +0 -0
  202. planar/workflows/.tracing.py.un~ +0 -0
  203. planar/workflows/.types.py.un~ +0 -0
  204. planar/workflows/.util.py.un~ +0 -0
  205. planar/workflows/.utils.py.un~ +0 -0
  206. planar/workflows/.workflow.py.un~ +0 -0
  207. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  208. planar/workflows/.wrappers.py.un~ +0 -0
  209. planar-0.5.0.dist-info/RECORD +0 -289
  210. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/WHEEL +0 -0
  211. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/entry_points.txt +0 -0
planar/ai/pydantic_ai.py
@@ -0,0 +1,638 @@
+ import base64
+ import json
+ import os
+ import re
+ import textwrap
+ from typing import Any, Literal, Protocol, Type, cast
+
+ from pydantic import BaseModel, ValidationError
+ from pydantic_ai import BinaryContent
+ from pydantic_ai._output import OutputObjectDefinition, OutputToolset
+ from pydantic_ai.direct import model_request_stream
+ from pydantic_ai.messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelRequestPart,
+     ModelResponse,
+     ModelResponsePart,
+     PartDeltaEvent,
+     PartStartEvent,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     TextPartDelta,
+     ThinkingPart,
+     ThinkingPartDelta,
+     ToolCallPart,
+     ToolCallPartDelta,
+     ToolReturnPart,
+     UserContent,
+     UserPromptPart,
+ )
+ from pydantic_ai.models import KnownModelName, Model, ModelRequestParameters
+ from pydantic_ai.settings import ModelSettings
+ from pydantic_ai.tools import ToolDefinition
+ from pydantic_core import ErrorDetails
+
+ from planar.ai import models as m
+ from planar.files.models import PlanarFile
+ from planar.logging import get_logger
+ from planar.utils import partition
+
+ logger = get_logger(__name__)
+
+ OUTPUT_TOOL_NAME = "send_final_response"
+ OUTPUT_TOOL_DESCRIPTION = """Called to provide the final response which ends this conversation.
+ Call it with the final JSON response!"""
+
+ NATIVE_STRUCTURED_OUTPUT_MODELS = re.compile(
+     r"""
+     gpt-4o
+     """,
+     re.VERBOSE | re.IGNORECASE,
+ )
+
+
+ def format_validation_errors(errors: list[ErrorDetails], function: bool) -> str:
+     lines = [
+         f"You called {OUTPUT_TOOL_NAME} with JSON that doesn't pass validation:"
+         if function
+         else "You returned JSON that did not pass validation:",
+         "",
+     ]
+     for error in errors:
+         msg = error["msg"]
+         field_path = ".".join([str(loc) for loc in error["loc"]])
+         input = error["input"]
+         lines.append(f"- {field_path}: {msg} (input: {json.dumps(input)})")
+
+     return "\n".join(lines)
+
+
+ async def openai_try_upload_file(
+     model: KnownModelName | Model, file: PlanarFile
+ ) -> m.FileIdContent | None:
+     # Currently pydanticAI doesn't support passing file_ids, but leaving the
+     # implementation here for when they add support.
+     return None
+
+     if file.content_type != "application/pdf":
+         # old implementation only does this for pdf files, so keep the behavior for now
+         return None
+
+     if isinstance(model, str) and not model.startswith("openai:"):
+         # not using openai provider
+         return None
+
+     try:
+         # make this code work with openai as optional dependency
+         from pydantic_ai.models.openai import OpenAIModel
+     except ImportError:
+         return None
+
+     if os.getenv("OPENAI_BASE_URL", None) is not None:
+         # cannot use OpenAI file upload if using a custom base url
+         return None
+
+     if (
+         isinstance(model, OpenAIModel)
+         and model.client.base_url.host != "api.openai.com"
+     ):
+         # same as above
+         return None
+
+     logger.debug("uploading pdf file to openai", filename=file.filename)
+
+     # use a separate AsyncClient instance since the model might be provided as a string
+     from openai import AsyncClient
+
+     client = AsyncClient()
+
+     # upload the file to the provider
+     openai_file = await client.files.create(
+         file=(
+             file.filename,
+             await file.get_content(),
+             file.content_type,
+         ),
+         purpose="user_data",
+     )
+     logger.info(
+         "uploaded pdf file to openai",
+         filename=file.filename,
+         openai_file_id=openai_file.id,
+     )
+     return m.FileIdContent(content=openai_file.id)
+
+
+ async def build_file_map(
+     model: KnownModelName | Model, messages: list[m.ModelMessage]
+ ) -> m.FileMap:
+     logger.debug("building file map", num_messages=len(messages))
+     file_dict = {}
+
+     for message_idx, message in enumerate(messages):
+         if isinstance(message, m.UserMessage) and message.files:
+             logger.debug(
+                 "processing files in message",
+                 num_files=len(message.files),
+                 message_index=message_idx,
+             )
+             for file_idx, file in enumerate(message.files):
+                 logger.debug(
+                     "processing file",
+                     file_index=file_idx,
+                     file_id=file.id,
+                     content_type=file.content_type,
+                 )
+
+                 file_content_id = await openai_try_upload_file(model, file)
+                 # TODO: add more `try_upload_file` implementations for other providers that support
+                 if file_content_id is not None:
+                     file_dict[str(file.id)] = file_content_id
+                     continue
+
+                 # For now we are not using uploaded files with Gemini, so convert all to base64
+                 if file.content_type.startswith(
+                     ("image/", "audio/", "video/", "application/pdf")
+                 ):
+                     logger.debug(
+                         "encoding file to base64",
+                         filename=file.filename,
+                         content_type=file.content_type,
+                     )
+                     file_dict[str(file.id)] = m.Base64Content(
+                         content=base64.b64encode(await file.get_content()).decode(
+                             "utf-8"
+                         ),
+                         content_type=file.content_type,
+                     )
+                 else:
+                     raise ValueError(f"Unsupported file type: {file.content_type}")
+
+     return m.FileMap(mapping=file_dict)
+
+
+ async def prepare_messages(
+     model: KnownModelName | Model, messages: list[m.ModelMessage]
+ ) -> list[Any]:
+     """Prepare messages from Planar representations into the format expected by PydanticAI.
+
+     Args:
+         messages: List of structured messages.
+         file_map: Optional file map for file content.
+
+     Returns:
+         List of messages in PydanticAI format
+     """
+     pydantic_messages: list[ModelMessage] = []
+     file_map = await build_file_map(model, messages)
+
+     def append_request_part(part: ModelRequestPart):
+         last = (
+             pydantic_messages[-1]
+             if pydantic_messages and isinstance(pydantic_messages[-1], ModelRequest)
+             else None
+         )
+         if not last:
+             last = ModelRequest(parts=[])
+             pydantic_messages.append(last)
+         last.parts.append(part)
+
+     def append_response_part(part: ModelResponsePart):
+         last = (
+             pydantic_messages[-1]
+             if pydantic_messages and isinstance(pydantic_messages[-1], ModelResponse)
+             else None
+         )
+         if not last:
+             last = ModelResponse(parts=[])
+             pydantic_messages.append(last)
+         last.parts.append(part)
+
+     for message in messages:
+         if isinstance(message, m.SystemMessage):
+             append_request_part(SystemPromptPart(content=message.content or ""))
+         elif isinstance(message, m.UserMessage):
+             user_content: list[UserContent] = []
+             files: list[m.FileContent] = []
+             if message.files:
+                 if not file_map:
+                     raise ValueError("File map empty while user message has files.")
+                 for file in message.files:
+                     if str(file.id) not in file_map.mapping:
+                         raise ValueError(
+                             f"File {file} not found in file map {file_map}."
+                         )
+                     files.append(file_map.mapping[str(file.id)])
+             for file in files:
+                 match file:
+                     case m.Base64Content():
+                         user_content.append(
+                             BinaryContent(
+                                 data=base64.b64decode(file.content),
+                                 media_type=file.content_type,
+                             )
+                         )
+                     case m.FileIdContent():
+                         raise Exception(
+                             "file id handling not implemented yet for PydanticAI"
+                         )
+             if message.content is not None:
+                 user_content.append(message.content)
+             append_request_part(UserPromptPart(content=user_content))
+         elif isinstance(message, m.ToolMessage):
+             append_request_part(
+                 ToolReturnPart(
+                     tool_name="unknown",  # FIXME: Planar's ToolMessage doesn't include tool name
+                     content=message.content,
+                     tool_call_id=message.tool_call_id,
+                 )
+             )
+         elif isinstance(message, m.AssistantMessage):
+             if message.content:
+                 append_response_part(TextPart(content=message.content or ""))
+             if message.tool_calls:
+                 for tc in message.tool_calls:
+                     append_response_part(
+                         ToolCallPart(
+                             tool_call_id=str(tc.id),
+                             tool_name=tc.name,
+                             args=tc.arguments,
+                         )
+                     )
+
+     return pydantic_messages
+
+
+ class StreamEventHandler(Protocol):
+     def emit(self, event: Literal["text", "think"], data: str) -> None: ...
+
+
+ def setup_native_structured_output(
+     request_params: ModelRequestParameters,
+     output_type: Type[BaseModel],
+ ):
+     schema_name = output_type.__name__
+     if not re.match(r"^[a-zA-Z0-9_-]+$", output_type.__name__):
+         schema_name = re.sub(r"[^a-zA-Z0-9_-]", "_", output_type.__name__)
+     json_schema = output_type.model_json_schema()
+     request_params.output_object = OutputObjectDefinition(
+         name=schema_name,
+         description=output_type.__doc__ or "",
+         json_schema=json_schema,
+     )
+     request_params.output_mode = "native"
+
+
+ def setup_tool_structured_output(
+     request_params: ModelRequestParameters,
+     output_type: Type[BaseModel],
+     messages: list[ModelMessage],
+ ):
+     request_params.output_mode = "tool"
+     toolset = OutputToolset.build(
+         [output_type],
+         name=OUTPUT_TOOL_NAME,
+         description=OUTPUT_TOOL_DESCRIPTION,
+     )
+     assert toolset
+     output_tool_defs = toolset._tool_defs
+     assert len(output_tool_defs) == 1, "Only one output tool is expected"
+     output_tool_defs[0].strict = True
+     request_params.output_tools = output_tool_defs
+
+     if not len(messages):
+         return
+
+     # Some weaker models might not understand that they need to call a function
+     # to return the final response. Add a reminder to the end of the system
+     # prompt.
+     first_request = messages[0]
+     first_part = first_request.parts[0]
+     if not isinstance(first_part, SystemPromptPart):
+         return
+     extra_system = textwrap.dedent(
+         f"""\n
+         WHEN you have a final JSON response, you MUST call the "{OUTPUT_TOOL_NAME}" function/tool with the response to return it. DO NOT RETURN the JSON response directly!!!
+         """
+     )
+     first_part.content += extra_system
+
+
+ def return_native_structured_output[TOutput: BaseModel](
+     output_type: Type[TOutput],
+     final_tool_calls: list[m.ToolCall],
+     content: str,
+     thinking: str | None = None,
+ ) -> m.CompletionResponse[TOutput]:
+     try:
+         result = m.CompletionResponse(
+             content=output_type.model_validate_json(content),
+             tool_calls=final_tool_calls,
+             reasoning_content=thinking,
+         )
+         logger.info(
+             "model run completed with structured output",
+             content=result.content,
+             reasoning_content=result.reasoning_content,
+             tool_calls=result.tool_calls,
+         )
+         return result
+     except Exception:
+         logger.exception(
+             "model output parse failure",
+             content=content,
+             output_model=output_type,
+         )
+         raise
+
+
+ def return_tool_structured_output[TOutput: BaseModel](
+     output_type: Type[TOutput],
+     tool_calls: list[m.ToolCall],
+     final_result_tc: m.ToolCall,
+     content: str,
+     thinking: str | None = None,
+ ) -> m.CompletionResponse[TOutput]:
+     try:
+         result = m.CompletionResponse(
+             content=output_type.model_validate(final_result_tc.arguments),
+             tool_calls=tool_calls,
+             reasoning_content=thinking,
+         )
+         logger.info(
+             "model run completed with structured output",
+             content=result.content,
+             reasoning_content=result.reasoning_content,
+             tool_calls=result.tool_calls,
+         )
+         return result
+     except Exception:
+         logger.exception(
+             "model output parse failure",
+             content=content,
+             output_model=output_type,
+         )
+         raise
+
+
+ class ModelRunResponse[TOutput: BaseModel | str](BaseModel):
+     response: m.CompletionResponse[TOutput]
+     extra_turns_used: int
+
+
+ async def model_run[TOutput: BaseModel | str](
+     model: Model | KnownModelName,
+     max_extra_turns: int,
+     model_settings: dict[str, Any] | None = None,
+     messages: list[m.ModelMessage] = [],
+     tools: list[m.ToolDefinition] = [],
+     event_handler: StreamEventHandler | None = None,
+     output_type: Type[TOutput] = str,
+ ) -> ModelRunResponse[TOutput]:
+     # assert that the caller doesn't provide a tool called "final_result"
+     if any(tool.name == OUTPUT_TOOL_NAME for tool in tools):
+         raise ValueError(
+             f'Tool named "{OUTPUT_TOOL_NAME}" is reserved and should not be provided.'
+         )
+
+     extra_turns_used = 0
+     model_name = model.model_name if isinstance(model, Model) else model
+     # Only enable native structured output for models that support it
+     supports_native_structured_output = bool(
+         NATIVE_STRUCTURED_OUTPUT_MODELS.search(model_name)
+     )
+
+     request_params = ModelRequestParameters(
+         function_tools=[
+             ToolDefinition(
+                 name=tool.name,
+                 description=tool.description,
+                 parameters_json_schema=tool.parameters,
+                 strict=True,
+             )
+             for tool in tools
+         ]
+     )
+
+     structured_output = issubclass(output_type, BaseModel)
+
+     def emit(event_type: Literal["text", "think"], content: str):
+         if event_handler:
+             event_handler.emit(event_type, content)
+
+     history = await prepare_messages(model, messages=messages)
+
+     if structured_output:
+         if supports_native_structured_output:
+             setup_native_structured_output(request_params, output_type)
+         else:
+             setup_tool_structured_output(request_params, output_type, history)
+
+     while True:
+         think_buffer = []
+         text_buffer = []
+         current_tool_call = None
+         current_tool_args_buffer = []
+         current_tool_call_id = None
+         tool_calls = []
+
+         response_parts: list[ModelResponsePart] = []
+
+         async with model_request_stream(
+             model=model,
+             messages=history,
+             model_request_parameters=request_params,
+             model_settings=cast(ModelSettings, model_settings),
+         ) as stream:
+             async for event in stream:
+                 match event:
+                     case PartStartEvent(part=part):
+                         response_parts.append(part)
+                         if isinstance(part, TextPart):
+                             emit("text", part.content)
+                             text_buffer.append(part.content)
+                         elif isinstance(part, ThinkingPart):
+                             emit("think", part.content)
+                             think_buffer.append(part.content)
+                         elif isinstance(part, ToolCallPart):
+                             if current_tool_call is not None:
+                                 # If we already have a tool call, emit the previous one
+                                 tool_calls.append(
+                                     dict(
+                                         name=current_tool_call,
+                                         arg_str="".join(current_tool_args_buffer),
+                                         id=current_tool_call_id,
+                                     )
+                                 )
+                             current_tool_call = part.tool_name
+                             current_tool_call_id = part.tool_call_id
+                             current_tool_args_buffer = []
+                             if part.args:
+                                 if isinstance(part.args, dict):
+                                     current_tool_args_buffer.append(
+                                         json.dumps(part.args)
+                                     )
+                                 else:
+                                     current_tool_args_buffer.append(part.args)
+                     case PartDeltaEvent(delta=delta):
+                         current = response_parts[-1]
+                         if isinstance(delta, TextPartDelta):
+                             assert isinstance(current, TextPart)
+                             emit("text", delta.content_delta)
+                             text_buffer.append(delta.content_delta)
+                             current.content += delta.content_delta
+                         elif (
+                             isinstance(delta, ThinkingPartDelta) and delta.content_delta
+                         ):
+                             assert isinstance(current, ThinkingPart)
+                             emit("think", delta.content_delta)
+                             think_buffer.append(delta.content_delta)
+                             current.content += delta.content_delta
+                         elif isinstance(delta, ToolCallPartDelta):
+                             assert isinstance(current, ToolCallPart)
+                             assert current_tool_call is not None
+                             assert current_tool_call_id == delta.tool_call_id
+                             current_tool_args_buffer.append(delta.args_delta)
+                             if delta.tool_name_delta:
+                                 current.tool_name += delta.tool_name_delta
+                             if isinstance(delta.args_delta, str):
+                                 if current.args is None:
+                                     current.args = ""
+                                 assert isinstance(current.args, str)
+                                 current.args += delta.args_delta
+
+         if current_tool_call is not None:
+             tool_calls.append(
+                 dict(
+                     name=current_tool_call,
+                     arg_str="".join(current_tool_args_buffer),
+                     id=current_tool_call_id,
+                 )
+             )
+
+         content = "".join(text_buffer)
+         thinking = "".join(think_buffer)
+
+         logger.debug(
+             "model run completed",
+             content=content,
+             thinking=thinking,
+             tool_calls=tool_calls,
+         )
+
+         try:
+             calls = [
+                 m.ToolCall(
+                     id=tc["id"],
+                     name=tc["name"],
+                     arguments=json.loads(tc["arg_str"]),
+                 )
+                 for tc in tool_calls
+             ]
+
+             def is_output_tool(tc):
+                 return tc.name == OUTPUT_TOOL_NAME
+
+             final_tool_calls, final_result_tool_calls = partition(is_output_tool, calls)
+         except json.JSONDecodeError:
+             logger.exception(
+                 "tool call json parse failure",
+                 tool_calls=tool_calls,
+             )
+             raise
+
+         if final_tool_calls:
+             return ModelRunResponse(
+                 response=m.CompletionResponse(
+                     tool_calls=final_tool_calls,
+                     reasoning_content=thinking,
+                 ),
+                 extra_turns_used=extra_turns_used,
+             )
+
+         if final_result_tool_calls:
+             # only 1 final result tool call is expected
+             assert len(final_result_tool_calls) == 1
+
+         if structured_output:
+             try:
+                 if supports_native_structured_output:
+                     return ModelRunResponse(
+                         response=return_native_structured_output(
+                             output_type, final_tool_calls, content, thinking
+                         ),
+                         extra_turns_used=extra_turns_used,
+                     )
+                 elif final_result_tool_calls:
+                     return ModelRunResponse(
+                         response=return_tool_structured_output(
+                             output_type,
+                             final_tool_calls,
+                             final_result_tool_calls[0],
+                             content,
+                             thinking,
+                         ),
+                         extra_turns_used=extra_turns_used,
+                     )
+             except ValidationError as e:
+                 if extra_turns_used >= max_extra_turns:
+                     raise
+                 # retry passing the validation error to the LLM
+                 # first, append the collected response parts to the history
+                 history.append(ModelResponse(parts=response_parts))
+                 # now append the ToolResponse with the validation errors
+
+                 retry_part = RetryPromptPart(
+                     content=format_validation_errors(
+                         e.errors(), function=len(final_result_tool_calls) > 0
+                     )
+                 )
+                 if final_result_tool_calls:
+                     retry_part.tool_name = OUTPUT_TOOL_NAME
+                     retry_part.tool_call_id = cast(str, final_result_tool_calls[0].id)
+
+                 request_parts: list[ModelRequestPart] = [retry_part]
+                 history.append(ModelRequest(parts=request_parts))
+                 extra_turns_used += 1
+                 continue
+
+         if output_type is not str:
+             if extra_turns_used >= max_extra_turns:
+                 raise ValueError(
+                     "Model did not return structured output, and no turns left to retry."
+                 )
+             # We can only reach this point if the model did not call send_final_response
+             # To return structured output. Report the error back to the LLM and retry
+             history.append(ModelResponse(parts=response_parts))
+             history.append(
+                 ModelRequest(
+                     parts=[
+                         UserPromptPart(
+                             content=f'Error processing response. You MUST pass the final JSON response to the "{OUTPUT_TOOL_NAME}" tool/function. DO NOT RETURN the JSON directly!!!'
+                         )
+                     ]
+                 )
+             )
+             extra_turns_used += 1
+             continue
+
+         result = cast(
+             m.CompletionResponse[TOutput],
+             m.CompletionResponse(
+                 content=content,
+                 tool_calls=final_tool_calls,
+                 reasoning_content=thinking,
+             ),
+         )
+         logger.info(
+             "model run completed with string output",
+             content=result.content,
+             reasoning_content=result.reasoning_content,
+             tool_calls=result.tool_calls,
+         )
+         return ModelRunResponse(
+             response=result,
+             extra_turns_used=extra_turns_used,
+         )
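
For orientation, here is a minimal, hypothetical sketch (not part of the package) of how the new model_run helper above might be driven from application code. The Invoice model and PrintHandler class are illustrative only, and it assumes the planar.ai.models message classes accept a plain content keyword, as their usage in this file suggests.

import asyncio

from pydantic import BaseModel

from planar.ai import models as m
from planar.ai.pydantic_ai import model_run


class Invoice(BaseModel):
    vendor: str
    total: float


class PrintHandler:
    # Satisfies the StreamEventHandler protocol: receives streamed text/thinking chunks.
    def emit(self, event, data):
        print(f"[{event}] {data}", end="", flush=True)


async def main():
    run = await model_run(
        model="openai:gpt-4o",  # matches NATIVE_STRUCTURED_OUTPUT_MODELS, so the native output path is used
        max_extra_turns=2,      # retry budget when the model's output fails validation
        messages=[
            m.SystemMessage(content="Extract the invoice fields."),      # assumed constructor
            m.UserMessage(content="Invoice from ACME Corp, total due $1,250.00"),
        ],
        event_handler=PrintHandler(),
        output_type=Invoice,    # structured output parsed into Invoice
    )
    print(run.response.content, run.extra_turns_used)


asyncio.run(main())

With a model outside the gpt-4o family, the same call would instead go through the tool-based structured-output path, asking the model to call send_final_response and retrying with the validation errors if parsing fails.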
planar/ai/test_agent_serialization.py
@@ -56,7 +56,7 @@ def test_agent_with_tools():
          name="test_agent_with_tools",
          system_prompt="System with tools",
          user_prompt="User: {input}",
-         model="anthropic:claude-3-sonnet",
+         model="anthropic:claude-3-5-sonnet-latest",
          max_turns=5,
          tools=[test_tool],
      )