planar-0.5.0-py3-none-any.whl → planar-0.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +67 -30
  3. planar/ai/pydantic_ai.py +570 -0
  4. planar/ai/pydantic_ai_agent.py +329 -0
  5. planar/ai/test_agent.py +2 -2
  6. planar/app.py +64 -20
  7. planar/cli.py +39 -27
  8. planar/config.py +45 -36
  9. planar/db/db.py +2 -1
  10. planar/files/storage/azure_blob.py +343 -0
  11. planar/files/storage/base.py +7 -0
  12. planar/files/storage/config.py +70 -7
  13. planar/files/storage/s3.py +6 -6
  14. planar/files/storage/test_azure_blob.py +435 -0
  15. planar/logging/formatter.py +17 -4
  16. planar/logging/test_formatter.py +327 -0
  17. planar/registry_items.py +2 -1
  18. planar/routers/agents_router.py +3 -1
  19. planar/routers/files.py +11 -2
  20. planar/routers/models.py +14 -1
  21. planar/routers/test_files_router.py +49 -0
  22. planar/routers/test_routes_security.py +5 -7
  23. planar/routers/test_workflow_router.py +270 -3
  24. planar/routers/workflow.py +95 -36
  25. planar/rules/models.py +36 -39
  26. planar/rules/test_data/account_dormancy_management.json +223 -0
  27. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  28. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  29. planar/rules/test_data/booking_fraud_detection.json +407 -0
  30. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  31. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  32. planar/rules/test_data/customer_lifetime_value.json +143 -0
  33. planar/rules/test_data/import_duties_calculator.json +289 -0
  34. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  35. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  36. planar/rules/test_data/order_consolidation_system.json +375 -0
  37. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  38. planar/rules/test_data/supply_chain_risk.json +253 -0
  39. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  40. planar/rules/test_rules.py +750 -6
  41. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  42. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  43. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  44. planar/security/auth_context.py +21 -0
  45. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  46. planar/security/authorization.py +9 -15
  47. planar/security/tests/test_auth_middleware.py +162 -0
  48. planar/sse/proxy.py +4 -9
  49. planar/test_app.py +92 -1
  50. planar/test_cli.py +81 -59
  51. planar/test_config.py +17 -14
  52. planar/testing/fixtures.py +325 -0
  53. planar/testing/planar_test_client.py +5 -2
  54. planar/utils.py +41 -1
  55. planar/workflows/execution.py +1 -1
  56. planar/workflows/orchestrator.py +5 -0
  57. planar/workflows/serialization.py +12 -6
  58. planar/workflows/step_core.py +3 -1
  59. planar/workflows/test_serialization.py +9 -1
  60. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/METADATA +30 -5
  61. planar-0.7.0.dist-info/RECORD +169 -0
  62. planar/.__init__.py.un~ +0 -0
  63. planar/._version.py.un~ +0 -0
  64. planar/.app.py.un~ +0 -0
  65. planar/.cli.py.un~ +0 -0
  66. planar/.config.py.un~ +0 -0
  67. planar/.context.py.un~ +0 -0
  68. planar/.db.py.un~ +0 -0
  69. planar/.di.py.un~ +0 -0
  70. planar/.engine.py.un~ +0 -0
  71. planar/.files.py.un~ +0 -0
  72. planar/.log_context.py.un~ +0 -0
  73. planar/.log_metadata.py.un~ +0 -0
  74. planar/.logging.py.un~ +0 -0
  75. planar/.object_registry.py.un~ +0 -0
  76. planar/.otel.py.un~ +0 -0
  77. planar/.server.py.un~ +0 -0
  78. planar/.session.py.un~ +0 -0
  79. planar/.sqlalchemy.py.un~ +0 -0
  80. planar/.task_local.py.un~ +0 -0
  81. planar/.test_app.py.un~ +0 -0
  82. planar/.test_config.py.un~ +0 -0
  83. planar/.test_object_config.py.un~ +0 -0
  84. planar/.test_sqlalchemy.py.un~ +0 -0
  85. planar/.test_utils.py.un~ +0 -0
  86. planar/.util.py.un~ +0 -0
  87. planar/.utils.py.un~ +0 -0
  88. planar/ai/.__init__.py.un~ +0 -0
  89. planar/ai/._models.py.un~ +0 -0
  90. planar/ai/.agent.py.un~ +0 -0
  91. planar/ai/.agent_utils.py.un~ +0 -0
  92. planar/ai/.events.py.un~ +0 -0
  93. planar/ai/.files.py.un~ +0 -0
  94. planar/ai/.models.py.un~ +0 -0
  95. planar/ai/.providers.py.un~ +0 -0
  96. planar/ai/.pydantic_ai.py.un~ +0 -0
  97. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  98. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  99. planar/ai/.step.py.un~ +0 -0
  100. planar/ai/.test_agent.py.un~ +0 -0
  101. planar/ai/.test_agent_serialization.py.un~ +0 -0
  102. planar/ai/.test_providers.py.un~ +0 -0
  103. planar/ai/.utils.py.un~ +0 -0
  104. planar/db/.db.py.un~ +0 -0
  105. planar/files/.config.py.un~ +0 -0
  106. planar/files/.local.py.un~ +0 -0
  107. planar/files/.local_filesystem.py.un~ +0 -0
  108. planar/files/.model.py.un~ +0 -0
  109. planar/files/.models.py.un~ +0 -0
  110. planar/files/.s3.py.un~ +0 -0
  111. planar/files/.storage.py.un~ +0 -0
  112. planar/files/.test_files.py.un~ +0 -0
  113. planar/files/storage/.__init__.py.un~ +0 -0
  114. planar/files/storage/.base.py.un~ +0 -0
  115. planar/files/storage/.config.py.un~ +0 -0
  116. planar/files/storage/.context.py.un~ +0 -0
  117. planar/files/storage/.local_directory.py.un~ +0 -0
  118. planar/files/storage/.test_local_directory.py.un~ +0 -0
  119. planar/files/storage/.test_s3.py.un~ +0 -0
  120. planar/human/.human.py.un~ +0 -0
  121. planar/human/.test_human.py.un~ +0 -0
  122. planar/logging/.__init__.py.un~ +0 -0
  123. planar/logging/.attributes.py.un~ +0 -0
  124. planar/logging/.formatter.py.un~ +0 -0
  125. planar/logging/.logger.py.un~ +0 -0
  126. planar/logging/.otel.py.un~ +0 -0
  127. planar/logging/.tracer.py.un~ +0 -0
  128. planar/modeling/.mixin.py.un~ +0 -0
  129. planar/modeling/.storage.py.un~ +0 -0
  130. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  131. planar/object_config/.object_config.py.un~ +0 -0
  132. planar/routers/.__init__.py.un~ +0 -0
  133. planar/routers/.agents_router.py.un~ +0 -0
  134. planar/routers/.crud.py.un~ +0 -0
  135. planar/routers/.decision.py.un~ +0 -0
  136. planar/routers/.event.py.un~ +0 -0
  137. planar/routers/.file_attachment.py.un~ +0 -0
  138. planar/routers/.files.py.un~ +0 -0
  139. planar/routers/.files_router.py.un~ +0 -0
  140. planar/routers/.human.py.un~ +0 -0
  141. planar/routers/.info.py.un~ +0 -0
  142. planar/routers/.models.py.un~ +0 -0
  143. planar/routers/.object_config_router.py.un~ +0 -0
  144. planar/routers/.rule.py.un~ +0 -0
  145. planar/routers/.test_object_config_router.py.un~ +0 -0
  146. planar/routers/.test_workflow_router.py.un~ +0 -0
  147. planar/routers/.workflow.py.un~ +0 -0
  148. planar/rules/.decorator.py.un~ +0 -0
  149. planar/rules/.runner.py.un~ +0 -0
  150. planar/rules/.test_rules.py.un~ +0 -0
  151. planar/security/.jwt_middleware.py.un~ +0 -0
  152. planar/sse/.constants.py.un~ +0 -0
  153. planar/sse/.example.html.un~ +0 -0
  154. planar/sse/.hub.py.un~ +0 -0
  155. planar/sse/.model.py.un~ +0 -0
  156. planar/sse/.proxy.py.un~ +0 -0
  157. planar/testing/.client.py.un~ +0 -0
  158. planar/testing/.memory_storage.py.un~ +0 -0
  159. planar/testing/.planar_test_client.py.un~ +0 -0
  160. planar/testing/.predictable_tracer.py.un~ +0 -0
  161. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  162. planar/testing/.test_memory_storage.py.un~ +0 -0
  163. planar/testing/.workflow_observer.py.un~ +0 -0
  164. planar/workflows/.__init__.py.un~ +0 -0
  165. planar/workflows/.builtin_steps.py.un~ +0 -0
  166. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  167. planar/workflows/.context.py.un~ +0 -0
  168. planar/workflows/.contrib.py.un~ +0 -0
  169. planar/workflows/.decorators.py.un~ +0 -0
  170. planar/workflows/.durable_test.py.un~ +0 -0
  171. planar/workflows/.errors.py.un~ +0 -0
  172. planar/workflows/.events.py.un~ +0 -0
  173. planar/workflows/.exceptions.py.un~ +0 -0
  174. planar/workflows/.execution.py.un~ +0 -0
  175. planar/workflows/.human.py.un~ +0 -0
  176. planar/workflows/.lock.py.un~ +0 -0
  177. planar/workflows/.misc.py.un~ +0 -0
  178. planar/workflows/.model.py.un~ +0 -0
  179. planar/workflows/.models.py.un~ +0 -0
  180. planar/workflows/.notifications.py.un~ +0 -0
  181. planar/workflows/.orchestrator.py.un~ +0 -0
  182. planar/workflows/.runtime.py.un~ +0 -0
  183. planar/workflows/.serialization.py.un~ +0 -0
  184. planar/workflows/.step.py.un~ +0 -0
  185. planar/workflows/.step_core.py.un~ +0 -0
  186. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  187. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  188. planar/workflows/.test_concurrency.py.un~ +0 -0
  189. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  190. planar/workflows/.test_human.py.un~ +0 -0
  191. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  192. planar/workflows/.test_orchestrator.py.un~ +0 -0
  193. planar/workflows/.test_race_conditions.py.un~ +0 -0
  194. planar/workflows/.test_serialization.py.un~ +0 -0
  195. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  196. planar/workflows/.test_workflow.py.un~ +0 -0
  197. planar/workflows/.tracing.py.un~ +0 -0
  198. planar/workflows/.types.py.un~ +0 -0
  199. planar/workflows/.util.py.un~ +0 -0
  200. planar/workflows/.utils.py.un~ +0 -0
  201. planar/workflows/.workflow.py.un~ +0 -0
  202. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  203. planar/workflows/.wrappers.py.un~ +0 -0
  204. planar-0.5.0.dist-info/RECORD +0 -289
  205. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/WHEEL +0 -0
  206. {planar-0.5.0.dist-info → planar-0.7.0.dist-info}/entry_points.txt +0 -0
planar/ai/pydantic_ai.py (new file, +570 -0)
@@ -0,0 +1,570 @@
+ import base64
+ import json
+ import re
+ import textwrap
+ from typing import Any, Literal, Protocol, Type, cast
+
+ from pydantic import BaseModel, ValidationError
+ from pydantic_ai import BinaryContent
+ from pydantic_ai._output import OutputObjectDefinition, OutputToolset
+ from pydantic_ai.direct import model_request_stream
+ from pydantic_ai.messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelRequestPart,
+     ModelResponse,
+     ModelResponsePart,
+     PartDeltaEvent,
+     PartStartEvent,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     TextPartDelta,
+     ThinkingPart,
+     ThinkingPartDelta,
+     ToolCallPart,
+     ToolCallPartDelta,
+     ToolReturnPart,
+     UserContent,
+     UserPromptPart,
+ )
+ from pydantic_ai.models import Model, ModelRequestParameters
+ from pydantic_ai.settings import ModelSettings
+ from pydantic_ai.tools import ToolDefinition
+ from pydantic_core import ErrorDetails
+
+ from planar.ai import models as m
+ from planar.logging import get_logger
+ from planar.utils import partition
+
+ logger = get_logger(__name__)
+
+ OUTPUT_TOOL_NAME = "send_final_response"
+ OUTPUT_TOOL_DESCRIPTION = """Called to provide the final response which ends this conversation.
+ Call it with the final JSON response!"""
+
+ NATIVE_STRUCTURED_OUTPUT_MODELS = re.compile(
+     r"""
+     gpt-4o
+     """,
+     re.VERBOSE | re.IGNORECASE,
+ )
+
+
+ def format_validation_errors(errors: list[ErrorDetails], function: bool) -> str:
+     lines = [
+         f"You called {OUTPUT_TOOL_NAME} with JSON that doesn't pass validation:"
+         if function
+         else "You returned JSON that did not pass validation:",
+         "",
+     ]
+     for error in errors:
+         msg = error["msg"]
+         field_path = ".".join([str(loc) for loc in error["loc"]])
+         input = error["input"]
+         lines.append(f"- {field_path}: {msg} (input: {json.dumps(input)})")
+
+     return "\n".join(lines)
+
+
+ async def build_file_map(messages: list[m.ModelMessage]) -> m.FileMap:
+     logger.debug("building file map", num_messages=len(messages))
+     file_dict = {}
+
+     for message_idx, message in enumerate(messages):
+         if isinstance(message, m.UserMessage) and message.files:
+             logger.debug(
+                 "processing files in message",
+                 num_files=len(message.files),
+                 message_index=message_idx,
+             )
+             for file_idx, file in enumerate(message.files):
+                 logger.debug(
+                     "processing file",
+                     file_index=file_idx,
+                     file_id=file.id,
+                     content_type=file.content_type,
+                 )
+
+                 # For now we are not using uploaded files with Gemini, so convert all to base64
+                 if file.content_type.startswith(
+                     ("image/", "audio/", "video/", "application/pdf")
+                 ):
+                     logger.debug(
+                         "encoding file to base64",
+                         filename=file.filename,
+                         content_type=file.content_type,
+                     )
+                     file_dict[str(file.id)] = m.Base64Content(
+                         content=base64.b64encode(await file.get_content()).decode(
+                             "utf-8"
+                         ),
+                         content_type=file.content_type,
+                     )
+                 else:
+                     raise ValueError(f"Unsupported file type: {file.content_type}")
+
+     return m.FileMap(mapping=file_dict)
+
+
+ async def prepare_messages(messages: list[m.ModelMessage]) -> list[Any]:
+     """Prepare messages from Planar representations into the format expected by PydanticAI.
+
+     Args:
+         messages: List of structured messages.
+         file_map: Optional file map for file content.
+
+     Returns:
+         List of messages in PydanticAI format
+     """
+     pydantic_messages: list[ModelMessage] = []
+     file_map = await build_file_map(messages)
+
+     def append_request_part(part: ModelRequestPart):
+         last = (
+             pydantic_messages[-1]
+             if pydantic_messages and isinstance(pydantic_messages[-1], ModelRequest)
+             else None
+         )
+         if not last:
+             last = ModelRequest(parts=[])
+             pydantic_messages.append(last)
+         last.parts.append(part)
+
+     def append_response_part(part: ModelResponsePart):
+         last = (
+             pydantic_messages[-1]
+             if pydantic_messages and isinstance(pydantic_messages[-1], ModelResponse)
+             else None
+         )
+         if not last:
+             last = ModelResponse(parts=[])
+             pydantic_messages.append(last)
+         last.parts.append(part)
+
+     for message in messages:
+         if isinstance(message, m.SystemMessage):
+             append_request_part(SystemPromptPart(content=message.content or ""))
+         elif isinstance(message, m.UserMessage):
+             user_content: list[UserContent] = []
+             files: list[m.FileContent] = []
+             if message.files:
+                 if not file_map:
+                     raise ValueError("File map empty while user message has files.")
+                 for file in message.files:
+                     if str(file.id) not in file_map.mapping:
+                         raise ValueError(
+                             f"File {file} not found in file map {file_map}."
+                         )
+                     files.append(file_map.mapping[str(file.id)])
+             for file in files:
+                 match file:
+                     case m.Base64Content():
+                         user_content.append(
+                             BinaryContent(
+                                 data=base64.b64decode(file.content),
+                                 media_type=file.content_type,
+                             )
+                         )
+                     case m.FileIdContent():
+                         raise Exception(
+                             "file id handling not implemented yet for PydanticAI"
+                         )
+             if message.content is not None:
+                 user_content.append(message.content)
+             append_request_part(UserPromptPart(content=user_content))
+         elif isinstance(message, m.ToolMessage):
+             append_request_part(
+                 ToolReturnPart(
+                     tool_name="unknown",  # FIXME: Planar's ToolMessage doesn't include tool name
+                     content=message.content,
+                     tool_call_id=message.tool_call_id,
+                 )
+             )
+         elif isinstance(message, m.AssistantMessage):
+             if message.content:
+                 append_response_part(TextPart(content=message.content or ""))
+             if message.tool_calls:
+                 for tc in message.tool_calls:
+                     append_response_part(
+                         ToolCallPart(
+                             tool_call_id=str(tc.id),
+                             tool_name=tc.name,
+                             args=tc.arguments,
+                         )
+                     )
+
+     return pydantic_messages
+
+
+ class StreamEventHandler(Protocol):
+     def emit(self, event: Literal["text", "think"], data: str) -> None: ...
+
+
+ def setup_native_structured_output(
+     request_params: ModelRequestParameters,
+     output_type: Type[BaseModel],
+ ):
+     schema_name = output_type.__name__
+     if not re.match(r"^[a-zA-Z0-9_-]+$", output_type.__name__):
+         schema_name = re.sub(r"[^a-zA-Z0-9_-]", "_", output_type.__name__)
+     json_schema = output_type.model_json_schema()
+     request_params.output_object = OutputObjectDefinition(
+         name=schema_name,
+         description=output_type.__doc__ or "",
+         json_schema=json_schema,
+     )
+     request_params.output_mode = "native"
+
+
+ def setup_tool_structured_output(
+     request_params: ModelRequestParameters,
+     output_type: Type[BaseModel],
+     messages: list[ModelMessage],
+ ):
+     request_params.output_mode = "tool"
+     toolset = OutputToolset.build(
+         [output_type],
+         name=OUTPUT_TOOL_NAME,
+         description=OUTPUT_TOOL_DESCRIPTION,
+     )
+     assert toolset
+     output_tool_defs = toolset._tool_defs
+     assert len(output_tool_defs) == 1, "Only one output tool is expected"
+     output_tool_defs[0].strict = True
+     request_params.output_tools = output_tool_defs
+
+     if not len(messages):
+         return
+
+     # Some weaker models might not understand that they need to call a function
+     # to return the final response. Add a reminder to the end of the system
+     # prompt.
+     first_request = messages[0]
+     first_part = first_request.parts[0]
+     if not isinstance(first_part, SystemPromptPart):
+         return
+     extra_system = textwrap.dedent(
+         f"""\n
+         WHEN you have a final JSON response, you MUST call the "{OUTPUT_TOOL_NAME}" function/tool with the response to return it. DO NOT RETURN the JSON response directly!!!
+         """
+     )
+     first_part.content += extra_system
+
+
+ def return_native_structured_output[TOutput: BaseModel](
+     output_type: Type[TOutput],
+     final_tool_calls: list[m.ToolCall],
+     content: str,
+     thinking: str | None = None,
+ ) -> m.CompletionResponse[TOutput]:
+     try:
+         result = m.CompletionResponse(
+             content=output_type.model_validate_json(content),
+             tool_calls=final_tool_calls,
+             reasoning_content=thinking,
+         )
+         logger.info(
+             "model run completed with structured output",
+             content=result.content,
+             reasoning_content=result.reasoning_content,
+             tool_calls=result.tool_calls,
+         )
+         return result
+     except Exception:
+         logger.exception(
+             "model output parse failure",
+             content=content,
+             output_model=output_type,
+         )
+         raise
+
+
+ def return_tool_structured_output[TOutput: BaseModel](
+     output_type: Type[TOutput],
+     tool_calls: list[m.ToolCall],
+     final_result_tc: m.ToolCall,
+     content: str,
+     thinking: str | None = None,
+ ) -> m.CompletionResponse[TOutput]:
+     try:
+         result = m.CompletionResponse(
+             content=output_type.model_validate(final_result_tc.arguments),
+             tool_calls=tool_calls,
+             reasoning_content=thinking,
+         )
+         logger.info(
+             "model run completed with structured output",
+             content=result.content,
+             reasoning_content=result.reasoning_content,
+             tool_calls=result.tool_calls,
+         )
+         return result
+     except Exception:
+         logger.exception(
+             "model output parse failure",
+             content=content,
+             output_model=output_type,
+         )
+         raise
+
+
+ class ModelRunResponse[TOutput: BaseModel | str](BaseModel):
+     response: m.CompletionResponse[TOutput]
+     extra_turns_used: int
+
+
+ async def model_run[TOutput: BaseModel | str](
+     model: Model | str,
+     max_extra_turns: int,
+     model_settings: dict[str, Any] | None = None,
+     messages: list[m.ModelMessage] = [],
+     tools: list[m.ToolDefinition] = [],
+     event_handler: StreamEventHandler | None = None,
+     output_type: Type[TOutput] = str,
+ ) -> ModelRunResponse[TOutput]:
+     # assert that the caller doesn't provide a tool called "final_result"
+     if any(tool.name == OUTPUT_TOOL_NAME for tool in tools):
+         raise ValueError(
+             f'Tool named "{OUTPUT_TOOL_NAME}" is reserved and should not be provided.'
+         )
+
+     extra_turns_used = 0
+     model_name = model.model_name if isinstance(model, Model) else model
+     # Only enable native structured output for models that support it
+     supports_native_structured_output = bool(
+         NATIVE_STRUCTURED_OUTPUT_MODELS.search(model_name)
+     )
+
+     request_params = ModelRequestParameters(
+         function_tools=[
+             ToolDefinition(
+                 name=tool.name,
+                 description=tool.description,
+                 parameters_json_schema=tool.parameters,
+                 strict=True,
+             )
+             for tool in tools
+         ]
+     )
+
+     structured_output = issubclass(output_type, BaseModel)
+
+     def emit(event_type: Literal["text", "think"], content: str):
+         if event_handler:
+             event_handler.emit(event_type, content)
+
+     history = await prepare_messages(messages=messages)
+
+     if structured_output:
+         if supports_native_structured_output:
+             setup_native_structured_output(request_params, output_type)
+         else:
+             setup_tool_structured_output(request_params, output_type, history)
+
+     while True:
+         think_buffer = []
+         text_buffer = []
+         current_tool_call = None
+         current_tool_args_buffer = []
+         current_tool_call_id = None
+         tool_calls = []
+
+         response_parts: list[ModelResponsePart] = []
+
+         async with model_request_stream(
+             model=model,
+             messages=history,
+             model_request_parameters=request_params,
+             model_settings=cast(ModelSettings, model_settings),
+         ) as stream:
+             async for event in stream:
+                 match event:
+                     case PartStartEvent(part=part):
+                         response_parts.append(part)
+                         if isinstance(part, TextPart):
+                             emit("text", part.content)
+                             text_buffer.append(part.content)
+                         elif isinstance(part, ThinkingPart):
+                             emit("think", part.content)
+                             think_buffer.append(part.content)
+                         elif isinstance(part, ToolCallPart):
+                             if current_tool_call is not None:
+                                 # If we already have a tool call, emit the previous one
+                                 tool_calls.append(
+                                     dict(
+                                         name=current_tool_call,
+                                         arg_str="".join(current_tool_args_buffer),
+                                         id=current_tool_call_id,
+                                     )
+                                 )
+                             current_tool_call = part.tool_name
+                             current_tool_call_id = part.tool_call_id
+                             current_tool_args_buffer = []
+                             if part.args:
+                                 if isinstance(part.args, dict):
+                                     current_tool_args_buffer.append(
+                                         json.dumps(part.args)
+                                     )
+                                 else:
+                                     current_tool_args_buffer.append(part.args)
+                     case PartDeltaEvent(delta=delta):
+                         current = response_parts[-1]
+                         if isinstance(delta, TextPartDelta):
+                             assert isinstance(current, TextPart)
+                             emit("text", delta.content_delta)
+                             text_buffer.append(delta.content_delta)
+                             current.content += delta.content_delta
+                         elif (
+                             isinstance(delta, ThinkingPartDelta) and delta.content_delta
+                         ):
+                             assert isinstance(current, ThinkingPart)
+                             emit("think", delta.content_delta)
+                             think_buffer.append(delta.content_delta)
+                             current.content += delta.content_delta
+                         elif isinstance(delta, ToolCallPartDelta):
+                             assert isinstance(current, ToolCallPart)
+                             assert current_tool_call is not None
+                             assert current_tool_call_id == delta.tool_call_id
+                             current_tool_args_buffer.append(delta.args_delta)
+                             if delta.tool_name_delta:
+                                 current.tool_name += delta.tool_name_delta
+                             if isinstance(delta.args_delta, str):
+                                 if current.args is None:
+                                     current.args = ""
+                                 assert isinstance(current.args, str)
+                                 current.args += delta.args_delta
+
+         if current_tool_call is not None:
+             tool_calls.append(
+                 dict(
+                     name=current_tool_call,
+                     arg_str="".join(current_tool_args_buffer),
+                     id=current_tool_call_id,
+                 )
+             )
+
+         content = "".join(text_buffer)
+         thinking = "".join(think_buffer)
+
+         logger.debug(
+             "model run completed",
+             content=content,
+             thinking=thinking,
+             tool_calls=tool_calls,
+         )
+
+         try:
+             calls = [
+                 m.ToolCall(
+                     id=tc["id"],
+                     name=tc["name"],
+                     arguments=json.loads(tc["arg_str"]),
+                 )
+                 for tc in tool_calls
+             ]
+
+             def is_output_tool(tc):
+                 return tc.name == OUTPUT_TOOL_NAME
+
+             final_tool_calls, final_result_tool_calls = partition(is_output_tool, calls)
+         except json.JSONDecodeError:
+             logger.exception(
+                 "tool call json parse failure",
+                 tool_calls=tool_calls,
+             )
+             raise
+
+         if final_tool_calls:
+             return ModelRunResponse(
+                 response=m.CompletionResponse(
+                     tool_calls=final_tool_calls,
+                     reasoning_content=thinking,
+                 ),
+                 extra_turns_used=extra_turns_used,
+             )
+
+         if final_result_tool_calls:
+             # only 1 final result tool call is expected
+             assert len(final_result_tool_calls) == 1
+
+         if structured_output:
+             try:
+                 if supports_native_structured_output:
+                     return ModelRunResponse(
+                         response=return_native_structured_output(
+                             output_type, final_tool_calls, content, thinking
+                         ),
+                         extra_turns_used=extra_turns_used,
+                     )
+                 elif final_result_tool_calls:
+                     return ModelRunResponse(
+                         response=return_tool_structured_output(
+                             output_type,
+                             final_tool_calls,
+                             final_result_tool_calls[0],
+                             content,
+                             thinking,
+                         ),
+                         extra_turns_used=extra_turns_used,
+                     )
+             except ValidationError as e:
+                 if extra_turns_used >= max_extra_turns:
+                     raise
+                 # retry passing the validation error to the LLM
+                 # first, append the collected response parts to the history
+                 history.append(ModelResponse(parts=response_parts))
+                 # now append the ToolResponse with the validation errors
+
+                 retry_part = RetryPromptPart(
+                     content=format_validation_errors(
+                         e.errors(), function=len(final_result_tool_calls) > 0
+                     )
+                 )
+                 if final_result_tool_calls:
+                     retry_part.tool_name = OUTPUT_TOOL_NAME
+                     retry_part.tool_call_id = cast(str, final_result_tool_calls[0].id)
+
+                 request_parts: list[ModelRequestPart] = [retry_part]
+                 history.append(ModelRequest(parts=request_parts))
+                 extra_turns_used += 1
+                 continue
+
+         if output_type is not str:
+             if extra_turns_used >= max_extra_turns:
+                 raise ValueError(
+                     "Model did not return structured output, and no turns left to retry."
+                 )
+             # We can only reach this point if the model did not call send_final_response
+             # to return structured output. Report the error back to the LLM and retry
+             history.append(ModelResponse(parts=response_parts))
+             history.append(
+                 ModelRequest(
+                     parts=[
+                         UserPromptPart(
+                             content=f'Error processing response. You MUST pass the final JSON response to the "{OUTPUT_TOOL_NAME}" tool/function. DO NOT RETURN the JSON directly!!!'
+                         )
+                     ]
+                 )
+             )
+             extra_turns_used += 1
+             continue
+
+         result = cast(
+             m.CompletionResponse[TOutput],
+             m.CompletionResponse(
+                 content=content,
+                 tool_calls=final_tool_calls,
+                 reasoning_content=thinking,
+             ),
+         )
+         logger.info(
+             "model run completed with string output",
+             content=result.content,
+             reasoning_content=result.reasoning_content,
+             tool_calls=result.tool_calls,
+         )
+         return ModelRunResponse(
+             response=result,
+             extra_turns_used=extra_turns_used,
+         )
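
For orientation, below is a sketch of how the new model_run entry point added in planar/ai/pydantic_ai.py might be driven once 0.7.0 is installed. This is a hypothetical usage example, not code from the package: the Invoice model, PrintHandler class, and extract_invoice function are invented for illustration, and the message constructors are assumed to accept a content keyword, as suggested by prepare_messages in the diff above.

    from pydantic import BaseModel

    from planar.ai import models as m
    from planar.ai.pydantic_ai import model_run


    class Invoice(BaseModel):
        vendor: str
        total: float


    class PrintHandler:
        # Structurally satisfies the StreamEventHandler protocol from the diff.
        def emit(self, event, data):
            print(f"[{event}] {data}", end="", flush=True)


    async def extract_invoice(text: str) -> Invoice:
        # "gpt-4o" matches NATIVE_STRUCTURED_OUTPUT_MODELS, so this run would use
        # native structured output; other model names fall back to the
        # "send_final_response" output tool, with validation errors fed back to
        # the model for up to max_extra_turns retries.
        run = await model_run(
            model="gpt-4o",
            max_extra_turns=2,  # retry budget when structured output fails validation
            messages=[
                m.SystemMessage(content="Extract the invoice fields."),
                m.UserMessage(content=text),
            ],
            event_handler=PrintHandler(),
            output_type=Invoice,
        )
        return run.response.content  # an Invoice instance, per ModelRunResponse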