planar-0.5.0-py3-none-any.whl → planar-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- planar/_version.py +1 -1
- planar/ai/agent.py +155 -283
- planar/ai/agent_base.py +170 -0
- planar/ai/agent_utils.py +7 -0
- planar/ai/pydantic_ai.py +638 -0
- planar/ai/test_agent_serialization.py +1 -1
- planar/app.py +64 -20
- planar/cli.py +39 -27
- planar/config.py +45 -36
- planar/db/db.py +2 -1
- planar/files/storage/azure_blob.py +343 -0
- planar/files/storage/base.py +7 -0
- planar/files/storage/config.py +70 -7
- planar/files/storage/s3.py +6 -6
- planar/files/storage/test_azure_blob.py +435 -0
- planar/logging/formatter.py +17 -4
- planar/logging/test_formatter.py +327 -0
- planar/registry_items.py +2 -1
- planar/routers/agents_router.py +3 -1
- planar/routers/files.py +11 -2
- planar/routers/models.py +14 -1
- planar/routers/test_agents_router.py +1 -1
- planar/routers/test_files_router.py +49 -0
- planar/routers/test_routes_security.py +5 -7
- planar/routers/test_workflow_router.py +270 -3
- planar/routers/workflow.py +95 -36
- planar/rules/models.py +36 -39
- planar/rules/test_data/account_dormancy_management.json +223 -0
- planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
- planar/rules/test_data/applicant_risk_assessment.json +435 -0
- planar/rules/test_data/booking_fraud_detection.json +407 -0
- planar/rules/test_data/cellular_data_rollover_system.json +258 -0
- planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
- planar/rules/test_data/customer_lifetime_value.json +143 -0
- planar/rules/test_data/import_duties_calculator.json +289 -0
- planar/rules/test_data/insurance_prior_authorization.json +443 -0
- planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
- planar/rules/test_data/order_consolidation_system.json +375 -0
- planar/rules/test_data/portfolio_risk_monitor.json +471 -0
- planar/rules/test_data/supply_chain_risk.json +253 -0
- planar/rules/test_data/warehouse_cross_docking.json +237 -0
- planar/rules/test_rules.py +750 -6
- planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
- planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
- planar/scaffold_templates/pyproject.toml.j2 +1 -1
- planar/security/auth_context.py +21 -0
- planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
- planar/security/authorization.py +9 -15
- planar/security/tests/test_auth_middleware.py +162 -0
- planar/sse/proxy.py +4 -9
- planar/test_app.py +92 -1
- planar/test_cli.py +81 -59
- planar/test_config.py +17 -14
- planar/testing/fixtures.py +325 -0
- planar/testing/planar_test_client.py +5 -2
- planar/utils.py +41 -1
- planar/workflows/execution.py +1 -1
- planar/workflows/orchestrator.py +5 -0
- planar/workflows/serialization.py +12 -6
- planar/workflows/step_core.py +3 -1
- planar/workflows/test_serialization.py +9 -1
- {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/METADATA +30 -5
- planar-0.8.0.dist-info/RECORD +166 -0
- planar/.__init__.py.un~ +0 -0
- planar/._version.py.un~ +0 -0
- planar/.app.py.un~ +0 -0
- planar/.cli.py.un~ +0 -0
- planar/.config.py.un~ +0 -0
- planar/.context.py.un~ +0 -0
- planar/.db.py.un~ +0 -0
- planar/.di.py.un~ +0 -0
- planar/.engine.py.un~ +0 -0
- planar/.files.py.un~ +0 -0
- planar/.log_context.py.un~ +0 -0
- planar/.log_metadata.py.un~ +0 -0
- planar/.logging.py.un~ +0 -0
- planar/.object_registry.py.un~ +0 -0
- planar/.otel.py.un~ +0 -0
- planar/.server.py.un~ +0 -0
- planar/.session.py.un~ +0 -0
- planar/.sqlalchemy.py.un~ +0 -0
- planar/.task_local.py.un~ +0 -0
- planar/.test_app.py.un~ +0 -0
- planar/.test_config.py.un~ +0 -0
- planar/.test_object_config.py.un~ +0 -0
- planar/.test_sqlalchemy.py.un~ +0 -0
- planar/.test_utils.py.un~ +0 -0
- planar/.util.py.un~ +0 -0
- planar/.utils.py.un~ +0 -0
- planar/ai/.__init__.py.un~ +0 -0
- planar/ai/._models.py.un~ +0 -0
- planar/ai/.agent.py.un~ +0 -0
- planar/ai/.agent_utils.py.un~ +0 -0
- planar/ai/.events.py.un~ +0 -0
- planar/ai/.files.py.un~ +0 -0
- planar/ai/.models.py.un~ +0 -0
- planar/ai/.providers.py.un~ +0 -0
- planar/ai/.pydantic_ai.py.un~ +0 -0
- planar/ai/.pydantic_ai_agent.py.un~ +0 -0
- planar/ai/.pydantic_ai_provider.py.un~ +0 -0
- planar/ai/.step.py.un~ +0 -0
- planar/ai/.test_agent.py.un~ +0 -0
- planar/ai/.test_agent_serialization.py.un~ +0 -0
- planar/ai/.test_providers.py.un~ +0 -0
- planar/ai/.utils.py.un~ +0 -0
- planar/ai/providers.py +0 -1088
- planar/ai/test_agent.py +0 -1298
- planar/ai/test_providers.py +0 -463
- planar/db/.db.py.un~ +0 -0
- planar/files/.config.py.un~ +0 -0
- planar/files/.local.py.un~ +0 -0
- planar/files/.local_filesystem.py.un~ +0 -0
- planar/files/.model.py.un~ +0 -0
- planar/files/.models.py.un~ +0 -0
- planar/files/.s3.py.un~ +0 -0
- planar/files/.storage.py.un~ +0 -0
- planar/files/.test_files.py.un~ +0 -0
- planar/files/storage/.__init__.py.un~ +0 -0
- planar/files/storage/.base.py.un~ +0 -0
- planar/files/storage/.config.py.un~ +0 -0
- planar/files/storage/.context.py.un~ +0 -0
- planar/files/storage/.local_directory.py.un~ +0 -0
- planar/files/storage/.test_local_directory.py.un~ +0 -0
- planar/files/storage/.test_s3.py.un~ +0 -0
- planar/human/.human.py.un~ +0 -0
- planar/human/.test_human.py.un~ +0 -0
- planar/logging/.__init__.py.un~ +0 -0
- planar/logging/.attributes.py.un~ +0 -0
- planar/logging/.formatter.py.un~ +0 -0
- planar/logging/.logger.py.un~ +0 -0
- planar/logging/.otel.py.un~ +0 -0
- planar/logging/.tracer.py.un~ +0 -0
- planar/modeling/.mixin.py.un~ +0 -0
- planar/modeling/.storage.py.un~ +0 -0
- planar/modeling/orm/.planar_base_model.py.un~ +0 -0
- planar/object_config/.object_config.py.un~ +0 -0
- planar/routers/.__init__.py.un~ +0 -0
- planar/routers/.agents_router.py.un~ +0 -0
- planar/routers/.crud.py.un~ +0 -0
- planar/routers/.decision.py.un~ +0 -0
- planar/routers/.event.py.un~ +0 -0
- planar/routers/.file_attachment.py.un~ +0 -0
- planar/routers/.files.py.un~ +0 -0
- planar/routers/.files_router.py.un~ +0 -0
- planar/routers/.human.py.un~ +0 -0
- planar/routers/.info.py.un~ +0 -0
- planar/routers/.models.py.un~ +0 -0
- planar/routers/.object_config_router.py.un~ +0 -0
- planar/routers/.rule.py.un~ +0 -0
- planar/routers/.test_object_config_router.py.un~ +0 -0
- planar/routers/.test_workflow_router.py.un~ +0 -0
- planar/routers/.workflow.py.un~ +0 -0
- planar/rules/.decorator.py.un~ +0 -0
- planar/rules/.runner.py.un~ +0 -0
- planar/rules/.test_rules.py.un~ +0 -0
- planar/security/.jwt_middleware.py.un~ +0 -0
- planar/sse/.constants.py.un~ +0 -0
- planar/sse/.example.html.un~ +0 -0
- planar/sse/.hub.py.un~ +0 -0
- planar/sse/.model.py.un~ +0 -0
- planar/sse/.proxy.py.un~ +0 -0
- planar/testing/.client.py.un~ +0 -0
- planar/testing/.memory_storage.py.un~ +0 -0
- planar/testing/.planar_test_client.py.un~ +0 -0
- planar/testing/.predictable_tracer.py.un~ +0 -0
- planar/testing/.synchronizable_tracer.py.un~ +0 -0
- planar/testing/.test_memory_storage.py.un~ +0 -0
- planar/testing/.workflow_observer.py.un~ +0 -0
- planar/workflows/.__init__.py.un~ +0 -0
- planar/workflows/.builtin_steps.py.un~ +0 -0
- planar/workflows/.concurrency_tracing.py.un~ +0 -0
- planar/workflows/.context.py.un~ +0 -0
- planar/workflows/.contrib.py.un~ +0 -0
- planar/workflows/.decorators.py.un~ +0 -0
- planar/workflows/.durable_test.py.un~ +0 -0
- planar/workflows/.errors.py.un~ +0 -0
- planar/workflows/.events.py.un~ +0 -0
- planar/workflows/.exceptions.py.un~ +0 -0
- planar/workflows/.execution.py.un~ +0 -0
- planar/workflows/.human.py.un~ +0 -0
- planar/workflows/.lock.py.un~ +0 -0
- planar/workflows/.misc.py.un~ +0 -0
- planar/workflows/.model.py.un~ +0 -0
- planar/workflows/.models.py.un~ +0 -0
- planar/workflows/.notifications.py.un~ +0 -0
- planar/workflows/.orchestrator.py.un~ +0 -0
- planar/workflows/.runtime.py.un~ +0 -0
- planar/workflows/.serialization.py.un~ +0 -0
- planar/workflows/.step.py.un~ +0 -0
- planar/workflows/.step_core.py.un~ +0 -0
- planar/workflows/.sub_workflow_runner.py.un~ +0 -0
- planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
- planar/workflows/.test_concurrency.py.un~ +0 -0
- planar/workflows/.test_concurrency_detection.py.un~ +0 -0
- planar/workflows/.test_human.py.un~ +0 -0
- planar/workflows/.test_lock_timeout.py.un~ +0 -0
- planar/workflows/.test_orchestrator.py.un~ +0 -0
- planar/workflows/.test_race_conditions.py.un~ +0 -0
- planar/workflows/.test_serialization.py.un~ +0 -0
- planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
- planar/workflows/.test_workflow.py.un~ +0 -0
- planar/workflows/.tracing.py.un~ +0 -0
- planar/workflows/.types.py.un~ +0 -0
- planar/workflows/.util.py.un~ +0 -0
- planar/workflows/.utils.py.un~ +0 -0
- planar/workflows/.workflow.py.un~ +0 -0
- planar/workflows/.workflow_wrapper.py.un~ +0 -0
- planar/workflows/.wrappers.py.un~ +0 -0
- planar-0.5.0.dist-info/RECORD +0 -289
- {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/WHEEL +0 -0
- {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/entry_points.txt +0 -0
planar/ai/pydantic_ai.py
ADDED
@@ -0,0 +1,638 @@
import base64
import json
import os
import re
import textwrap
from typing import Any, Literal, Protocol, Type, cast

from pydantic import BaseModel, ValidationError
from pydantic_ai import BinaryContent
from pydantic_ai._output import OutputObjectDefinition, OutputToolset
from pydantic_ai.direct import model_request_stream
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelRequestPart,
    ModelResponse,
    ModelResponsePart,
    PartDeltaEvent,
    PartStartEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    TextPartDelta,
    ThinkingPart,
    ThinkingPartDelta,
    ToolCallPart,
    ToolCallPartDelta,
    ToolReturnPart,
    UserContent,
    UserPromptPart,
)
from pydantic_ai.models import KnownModelName, Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import ToolDefinition
from pydantic_core import ErrorDetails

from planar.ai import models as m
from planar.files.models import PlanarFile
from planar.logging import get_logger
from planar.utils import partition

logger = get_logger(__name__)

OUTPUT_TOOL_NAME = "send_final_response"
OUTPUT_TOOL_DESCRIPTION = """Called to provide the final response which ends this conversation.
Call it with the final JSON response!"""

NATIVE_STRUCTURED_OUTPUT_MODELS = re.compile(
    r"""
    gpt-4o
    """,
    re.VERBOSE | re.IGNORECASE,
)


def format_validation_errors(errors: list[ErrorDetails], function: bool) -> str:
    lines = [
        f"You called {OUTPUT_TOOL_NAME} with JSON that doesn't pass validation:"
        if function
        else "You returned JSON that did not pass validation:",
        "",
    ]
    for error in errors:
        msg = error["msg"]
        field_path = ".".join([str(loc) for loc in error["loc"]])
        input = error["input"]
        lines.append(f"- {field_path}: {msg} (input: {json.dumps(input)})")

    return "\n".join(lines)


async def openai_try_upload_file(
    model: KnownModelName | Model, file: PlanarFile
) -> m.FileIdContent | None:
    # Currently pydanticAI doesn't support passing file_ids, but leaving the
    # implementation here for when they add support.
    return None

    if file.content_type != "application/pdf":
        # old implementation only does this for pdf files, so keep the behavior for now
        return None

    if isinstance(model, str) and not model.startswith("openai:"):
        # not using openai provider
        return None

    try:
        # make this code work with openai as optional dependency
        from pydantic_ai.models.openai import OpenAIModel
    except ImportError:
        return None

    if os.getenv("OPENAI_BASE_URL", None) is not None:
        # cannot use OpenAI file upload if using a custom base url
        return None

    if (
        isinstance(model, OpenAIModel)
        and model.client.base_url.host != "api.openai.com"
    ):
        # same as above
        return None

    logger.debug("uploading pdf file to openai", filename=file.filename)

    # use a separate AsyncClient instance since the model might be provided as a string
    from openai import AsyncClient

    client = AsyncClient()

    # upload the file to the provider
    openai_file = await client.files.create(
        file=(
            file.filename,
            await file.get_content(),
            file.content_type,
        ),
        purpose="user_data",
    )
    logger.info(
        "uploaded pdf file to openai",
        filename=file.filename,
        openai_file_id=openai_file.id,
    )
    return m.FileIdContent(content=openai_file.id)


async def build_file_map(
    model: KnownModelName | Model, messages: list[m.ModelMessage]
) -> m.FileMap:
    logger.debug("building file map", num_messages=len(messages))
    file_dict = {}

    for message_idx, message in enumerate(messages):
        if isinstance(message, m.UserMessage) and message.files:
            logger.debug(
                "processing files in message",
                num_files=len(message.files),
                message_index=message_idx,
            )
            for file_idx, file in enumerate(message.files):
                logger.debug(
                    "processing file",
                    file_index=file_idx,
                    file_id=file.id,
                    content_type=file.content_type,
                )

                file_content_id = await openai_try_upload_file(model, file)
                # TODO: add more `try_upload_file` implementations for other providers that support
                if file_content_id is not None:
                    file_dict[str(file.id)] = file_content_id
                    continue

                # For now we are not using uploaded files with Gemini, so convert all to base64
                if file.content_type.startswith(
                    ("image/", "audio/", "video/", "application/pdf")
                ):
                    logger.debug(
                        "encoding file to base64",
                        filename=file.filename,
                        content_type=file.content_type,
                    )
                    file_dict[str(file.id)] = m.Base64Content(
                        content=base64.b64encode(await file.get_content()).decode(
                            "utf-8"
                        ),
                        content_type=file.content_type,
                    )
                else:
                    raise ValueError(f"Unsupported file type: {file.content_type}")

    return m.FileMap(mapping=file_dict)


async def prepare_messages(
    model: KnownModelName | Model, messages: list[m.ModelMessage]
) -> list[Any]:
    """Prepare messages from Planar representations into the format expected by PydanticAI.

    Args:
        messages: List of structured messages.
        file_map: Optional file map for file content.

    Returns:
        List of messages in PydanticAI format
    """
    pydantic_messages: list[ModelMessage] = []
    file_map = await build_file_map(model, messages)

    def append_request_part(part: ModelRequestPart):
        last = (
            pydantic_messages[-1]
            if pydantic_messages and isinstance(pydantic_messages[-1], ModelRequest)
            else None
        )
        if not last:
            last = ModelRequest(parts=[])
            pydantic_messages.append(last)
        last.parts.append(part)

    def append_response_part(part: ModelResponsePart):
        last = (
            pydantic_messages[-1]
            if pydantic_messages and isinstance(pydantic_messages[-1], ModelResponse)
            else None
        )
        if not last:
            last = ModelResponse(parts=[])
            pydantic_messages.append(last)
        last.parts.append(part)

    for message in messages:
        if isinstance(message, m.SystemMessage):
            append_request_part(SystemPromptPart(content=message.content or ""))
        elif isinstance(message, m.UserMessage):
            user_content: list[UserContent] = []
            files: list[m.FileContent] = []
            if message.files:
                if not file_map:
                    raise ValueError("File map empty while user message has files.")
                for file in message.files:
                    if str(file.id) not in file_map.mapping:
                        raise ValueError(
                            f"File {file} not found in file map {file_map}."
                        )
                    files.append(file_map.mapping[str(file.id)])
            for file in files:
                match file:
                    case m.Base64Content():
                        user_content.append(
                            BinaryContent(
                                data=base64.b64decode(file.content),
                                media_type=file.content_type,
                            )
                        )
                    case m.FileIdContent():
                        raise Exception(
                            "file id handling not implemented yet for PydanticAI"
                        )
            if message.content is not None:
                user_content.append(message.content)
            append_request_part(UserPromptPart(content=user_content))
        elif isinstance(message, m.ToolMessage):
            append_request_part(
                ToolReturnPart(
                    tool_name="unknown",  # FIXME: Planar's ToolMessage doesn't include tool name
                    content=message.content,
                    tool_call_id=message.tool_call_id,
                )
            )
        elif isinstance(message, m.AssistantMessage):
            if message.content:
                append_response_part(TextPart(content=message.content or ""))
            if message.tool_calls:
                for tc in message.tool_calls:
                    append_response_part(
                        ToolCallPart(
                            tool_call_id=str(tc.id),
                            tool_name=tc.name,
                            args=tc.arguments,
                        )
                    )

    return pydantic_messages


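# StreamEventHandler is a structural typing.Protocol: any object that exposes a
# matching emit() method can receive the "text"/"think" events streamed while
# model_run (below) consumes a model response.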
class StreamEventHandler(Protocol):
    def emit(self, event: Literal["text", "think"], data: str) -> None: ...


def setup_native_structured_output(
    request_params: ModelRequestParameters,
    output_type: Type[BaseModel],
):
    schema_name = output_type.__name__
    if not re.match(r"^[a-zA-Z0-9_-]+$", output_type.__name__):
        schema_name = re.sub(r"[^a-zA-Z0-9_-]", "_", output_type.__name__)
    json_schema = output_type.model_json_schema()
    request_params.output_object = OutputObjectDefinition(
        name=schema_name,
        description=output_type.__doc__ or "",
        json_schema=json_schema,
    )
    request_params.output_mode = "native"


def setup_tool_structured_output(
    request_params: ModelRequestParameters,
    output_type: Type[BaseModel],
    messages: list[ModelMessage],
):
    request_params.output_mode = "tool"
    toolset = OutputToolset.build(
        [output_type],
        name=OUTPUT_TOOL_NAME,
        description=OUTPUT_TOOL_DESCRIPTION,
    )
    assert toolset
    output_tool_defs = toolset._tool_defs
    assert len(output_tool_defs) == 1, "Only one output tool is expected"
    output_tool_defs[0].strict = True
    request_params.output_tools = output_tool_defs

    if not len(messages):
        return

    # Some weaker models might not understand that they need to call a function
    # to return the final response. Add a reminder to the end of the system
    # prompt.
    first_request = messages[0]
    first_part = first_request.parts[0]
    if not isinstance(first_part, SystemPromptPart):
        return
    extra_system = textwrap.dedent(
        f"""\n
        WHEN you have a final JSON response, you MUST call the "{OUTPUT_TOOL_NAME}" function/tool with the response to return it. DO NOT RETURN the JSON response directly!!!
        """
    )
    first_part.content += extra_system


def return_native_structured_output[TOutput: BaseModel](
    output_type: Type[TOutput],
    final_tool_calls: list[m.ToolCall],
    content: str,
    thinking: str | None = None,
) -> m.CompletionResponse[TOutput]:
    try:
        result = m.CompletionResponse(
            content=output_type.model_validate_json(content),
            tool_calls=final_tool_calls,
            reasoning_content=thinking,
        )
        logger.info(
            "model run completed with structured output",
            content=result.content,
            reasoning_content=result.reasoning_content,
            tool_calls=result.tool_calls,
        )
        return result
    except Exception:
        logger.exception(
            "model output parse failure",
            content=content,
            output_model=output_type,
        )
        raise


def return_tool_structured_output[TOutput: BaseModel](
    output_type: Type[TOutput],
    tool_calls: list[m.ToolCall],
    final_result_tc: m.ToolCall,
    content: str,
    thinking: str | None = None,
) -> m.CompletionResponse[TOutput]:
    try:
        result = m.CompletionResponse(
            content=output_type.model_validate(final_result_tc.arguments),
            tool_calls=tool_calls,
            reasoning_content=thinking,
        )
        logger.info(
            "model run completed with structured output",
            content=result.content,
            reasoning_content=result.reasoning_content,
            tool_calls=result.tool_calls,
        )
        return result
    except Exception:
        logger.exception(
            "model output parse failure",
            content=content,
            output_model=output_type,
        )
        raise


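# model_run drives a manual turn loop over pydantic_ai's direct streaming API.
# Each iteration streams one model response while accumulating text, thinking,
# and tool-call argument deltas. When structured output fails validation, the
# errors are sent back to the model as an extra retry turn, bounded by
# max_extra_turns; ModelRunResponse reports how many extra turns were consumed.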
class ModelRunResponse[TOutput: BaseModel | str](BaseModel):
    response: m.CompletionResponse[TOutput]
    extra_turns_used: int


async def model_run[TOutput: BaseModel | str](
    model: Model | KnownModelName,
    max_extra_turns: int,
    model_settings: dict[str, Any] | None = None,
    messages: list[m.ModelMessage] = [],
    tools: list[m.ToolDefinition] = [],
    event_handler: StreamEventHandler | None = None,
    output_type: Type[TOutput] = str,
) -> ModelRunResponse[TOutput]:
    # assert that the caller doesn't provide a tool called "final_result"
    if any(tool.name == OUTPUT_TOOL_NAME for tool in tools):
        raise ValueError(
            f'Tool named "{OUTPUT_TOOL_NAME}" is reserved and should not be provided.'
        )

    extra_turns_used = 0
    model_name = model.model_name if isinstance(model, Model) else model
    # Only enable native structured output for models that support it
    supports_native_structured_output = bool(
        NATIVE_STRUCTURED_OUTPUT_MODELS.search(model_name)
    )

    request_params = ModelRequestParameters(
        function_tools=[
            ToolDefinition(
                name=tool.name,
                description=tool.description,
                parameters_json_schema=tool.parameters,
                strict=True,
            )
            for tool in tools
        ]
    )

    structured_output = issubclass(output_type, BaseModel)

    def emit(event_type: Literal["text", "think"], content: str):
        if event_handler:
            event_handler.emit(event_type, content)

    history = await prepare_messages(model, messages=messages)

    if structured_output:
        if supports_native_structured_output:
            setup_native_structured_output(request_params, output_type)
        else:
            setup_tool_structured_output(request_params, output_type, history)

    while True:
        think_buffer = []
        text_buffer = []
        current_tool_call = None
        current_tool_args_buffer = []
        current_tool_call_id = None
        tool_calls = []

        response_parts: list[ModelResponsePart] = []

        async with model_request_stream(
            model=model,
            messages=history,
            model_request_parameters=request_params,
            model_settings=cast(ModelSettings, model_settings),
        ) as stream:
            async for event in stream:
                match event:
                    case PartStartEvent(part=part):
                        response_parts.append(part)
                        if isinstance(part, TextPart):
                            emit("text", part.content)
                            text_buffer.append(part.content)
                        elif isinstance(part, ThinkingPart):
                            emit("think", part.content)
                            think_buffer.append(part.content)
                        elif isinstance(part, ToolCallPart):
                            if current_tool_call is not None:
                                # If we already have a tool call, emit the previous one
                                tool_calls.append(
                                    dict(
                                        name=current_tool_call,
                                        arg_str="".join(current_tool_args_buffer),
                                        id=current_tool_call_id,
                                    )
                                )
                            current_tool_call = part.tool_name
                            current_tool_call_id = part.tool_call_id
                            current_tool_args_buffer = []
                            if part.args:
                                if isinstance(part.args, dict):
                                    current_tool_args_buffer.append(
                                        json.dumps(part.args)
                                    )
                                else:
                                    current_tool_args_buffer.append(part.args)
                    case PartDeltaEvent(delta=delta):
                        current = response_parts[-1]
                        if isinstance(delta, TextPartDelta):
                            assert isinstance(current, TextPart)
                            emit("text", delta.content_delta)
                            text_buffer.append(delta.content_delta)
                            current.content += delta.content_delta
                        elif (
                            isinstance(delta, ThinkingPartDelta) and delta.content_delta
                        ):
                            assert isinstance(current, ThinkingPart)
                            emit("think", delta.content_delta)
                            think_buffer.append(delta.content_delta)
                            current.content += delta.content_delta
                        elif isinstance(delta, ToolCallPartDelta):
                            assert isinstance(current, ToolCallPart)
                            assert current_tool_call is not None
                            assert current_tool_call_id == delta.tool_call_id
                            current_tool_args_buffer.append(delta.args_delta)
                            if delta.tool_name_delta:
                                current.tool_name += delta.tool_name_delta
                            if isinstance(delta.args_delta, str):
                                if current.args is None:
                                    current.args = ""
                                assert isinstance(current.args, str)
                                current.args += delta.args_delta

        if current_tool_call is not None:
            tool_calls.append(
                dict(
                    name=current_tool_call,
                    arg_str="".join(current_tool_args_buffer),
                    id=current_tool_call_id,
                )
            )

        content = "".join(text_buffer)
        thinking = "".join(think_buffer)

        logger.debug(
            "model run completed",
            content=content,
            thinking=thinking,
            tool_calls=tool_calls,
        )

        try:
            calls = [
                m.ToolCall(
                    id=tc["id"],
                    name=tc["name"],
                    arguments=json.loads(tc["arg_str"]),
                )
                for tc in tool_calls
            ]

            def is_output_tool(tc):
                return tc.name == OUTPUT_TOOL_NAME

            final_tool_calls, final_result_tool_calls = partition(is_output_tool, calls)
        except json.JSONDecodeError:
            logger.exception(
                "tool call json parse failure",
                tool_calls=tool_calls,
            )
            raise

        if final_tool_calls:
            return ModelRunResponse(
                response=m.CompletionResponse(
                    tool_calls=final_tool_calls,
                    reasoning_content=thinking,
                ),
                extra_turns_used=extra_turns_used,
            )

        if final_result_tool_calls:
            # only 1 final result tool call is expected
            assert len(final_result_tool_calls) == 1

        if structured_output:
            try:
                if supports_native_structured_output:
                    return ModelRunResponse(
                        response=return_native_structured_output(
                            output_type, final_tool_calls, content, thinking
                        ),
                        extra_turns_used=extra_turns_used,
                    )
                elif final_result_tool_calls:
                    return ModelRunResponse(
                        response=return_tool_structured_output(
                            output_type,
                            final_tool_calls,
                            final_result_tool_calls[0],
                            content,
                            thinking,
                        ),
                        extra_turns_used=extra_turns_used,
                    )
            except ValidationError as e:
                if extra_turns_used >= max_extra_turns:
                    raise
                # retry passing the validation error to the LLM
                # first, append the collected response parts to the history
                history.append(ModelResponse(parts=response_parts))
                # now append the ToolResponse with the validation errors

                retry_part = RetryPromptPart(
                    content=format_validation_errors(
                        e.errors(), function=len(final_result_tool_calls) > 0
                    )
                )
                if final_result_tool_calls:
                    retry_part.tool_name = OUTPUT_TOOL_NAME
                    retry_part.tool_call_id = cast(str, final_result_tool_calls[0].id)

                request_parts: list[ModelRequestPart] = [retry_part]
                history.append(ModelRequest(parts=request_parts))
                extra_turns_used += 1
                continue

        if output_type is not str:
            if extra_turns_used >= max_extra_turns:
                raise ValueError(
                    "Model did not return structured output, and no turns left to retry."
                )
            # We can only reach this point if the model did not call send_final_response
            # To return structured output. Report the error back to the LLM and retry
            history.append(ModelResponse(parts=response_parts))
            history.append(
                ModelRequest(
                    parts=[
                        UserPromptPart(
                            content=f'Error processing response. You MUST pass the final JSON response to the "{OUTPUT_TOOL_NAME}" tool/function. DO NOT RETURN the JSON directly!!!'
                        )
                    ]
                )
            )
            extra_turns_used += 1
            continue

        result = cast(
            m.CompletionResponse[TOutput],
            m.CompletionResponse(
                content=content,
                tool_calls=final_tool_calls,
                reasoning_content=thinking,
            ),
        )
        logger.info(
            "model run completed with string output",
            content=result.content,
            reasoning_content=result.reasoning_content,
            tool_calls=result.tool_calls,
        )
        return ModelRunResponse(
            response=result,
            extra_turns_used=extra_turns_used,
        )
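
Usage sketch (illustrative, not part of the packaged file): one way the new model_run entry point might be driven with a structured output type, assuming an OpenAI API key is configured. WeatherReport, the prompt strings, and the keyword construction of m.SystemMessage / m.UserMessage are assumptions made for this example; model_run, ModelRunResponse, and planar.ai.models come from the file shown above.

import asyncio

from pydantic import BaseModel

from planar.ai import models as m
from planar.ai.pydantic_ai import model_run


class WeatherReport(BaseModel):  # hypothetical output schema for this sketch
    city: str
    temperature_c: float
    summary: str


async def main() -> None:
    run = await model_run(
        model="openai:gpt-4o",  # matches NATIVE_STRUCTURED_OUTPUT_MODELS, so native mode applies
        max_extra_turns=2,  # retry budget for validation-error turns
        messages=[
            # assumes Planar's message models accept keyword construction
            m.SystemMessage(content="You are a weather assistant."),
            m.UserMessage(content="Summarize today's weather in Lisbon."),
        ],
        output_type=WeatherReport,
    )
    report = run.response.content  # parsed WeatherReport instance
    print(report, run.extra_turns_used)


asyncio.run(main())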