weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1021 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Shared helpers for provider adapters."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import re
|
|
19
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
20
|
+
from dataclasses import dataclass, field, replace
|
|
21
|
+
from datetime import UTC, datetime, timedelta
|
|
22
|
+
from typing import TYPE_CHECKING, Any, Literal, NoReturn, Protocol, TypeVar, cast
|
|
23
|
+
from uuid import uuid4
|
|
24
|
+
|
|
25
|
+
from ..deadlines import Deadline
|
|
26
|
+
from ..prompt._types import SupportsDataclass, SupportsToolResult
|
|
27
|
+
from ..prompt.prompt import Prompt, RenderedPrompt
|
|
28
|
+
from ..prompt.protocols import PromptProtocol, ProviderAdapterProtocol
|
|
29
|
+
from ..prompt.structured_output import (
|
|
30
|
+
ARRAY_WRAPPER_KEY,
|
|
31
|
+
OutputParseError,
|
|
32
|
+
parse_dataclass_payload,
|
|
33
|
+
parse_structured_output,
|
|
34
|
+
)
|
|
35
|
+
from ..prompt.tool import Tool, ToolContext, ToolResult
|
|
36
|
+
from ..runtime.events import (
|
|
37
|
+
EventBus,
|
|
38
|
+
HandlerFailure,
|
|
39
|
+
PromptExecuted,
|
|
40
|
+
PromptRendered,
|
|
41
|
+
ToolInvoked,
|
|
42
|
+
)
|
|
43
|
+
from ..runtime.logging import StructuredLogger, get_logger
|
|
44
|
+
from ..runtime.session.dataclasses import is_dataclass_instance
|
|
45
|
+
from ..serde import parse, schema
|
|
46
|
+
from ..tools.errors import DeadlineExceededError, ToolValidationError
|
|
47
|
+
from ..types import JSONValue
|
|
48
|
+
from ._names import LITELLM_ADAPTER_NAME, OPENAI_ADAPTER_NAME, AdapterName
|
|
49
|
+
from ._provider_protocols import (
|
|
50
|
+
ProviderChoice,
|
|
51
|
+
ProviderCompletionCallable,
|
|
52
|
+
ProviderCompletionResponse,
|
|
53
|
+
ProviderFunctionCall,
|
|
54
|
+
ProviderMessage,
|
|
55
|
+
ProviderToolCall,
|
|
56
|
+
)
|
|
57
|
+
from .core import (
|
|
58
|
+
PROMPT_EVALUATION_PHASE_REQUEST,
|
|
59
|
+
PROMPT_EVALUATION_PHASE_RESPONSE,
|
|
60
|
+
PROMPT_EVALUATION_PHASE_TOOL,
|
|
61
|
+
PromptEvaluationError,
|
|
62
|
+
PromptEvaluationPhase,
|
|
63
|
+
PromptResponse,
|
|
64
|
+
SessionProtocol,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
if TYPE_CHECKING:
|
|
68
|
+
from ..adapters.core import ProviderAdapter
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Structured logger shared by every helper in this module; it is created with
# a fixed ``{"component": "adapters.shared"}`` context.
logger: StructuredLogger = get_logger(
    __name__,
    context={"component": "adapters.shared"},
)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(slots=True)
|
|
77
|
+
class _RejectedToolParams:
|
|
78
|
+
"""Dataclass used when provider arguments fail validation."""
|
|
79
|
+
|
|
80
|
+
raw_arguments: dict[str, Any]
|
|
81
|
+
error: str
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ToolArgumentsParser(Protocol):
    """Callable contract for decoding a provider's raw tool-call arguments.

    Implementations receive the raw JSON argument string (possibly ``None``)
    plus context used for error reporting, and return the decoded keyword
    mapping. ``parse_tool_arguments`` in this module matches this contract.
    """

    def __call__(
        self,
        arguments_json: str | None,
        *,
        prompt_name: str,
        provider_payload: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Decode ``arguments_json`` into a string-keyed argument mapping."""
        ...
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
ToolChoice = Literal["auto"] | Mapping[str, Any] | None
|
|
95
|
+
"""Supported tool choice directives for provider APIs."""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def deadline_provider_payload(deadline: Deadline | None) -> dict[str, Any] | None:
    """Describe ``deadline`` as a provider-payload fragment.

    Returns ``None`` when no deadline is set; otherwise a single-key dict
    holding the deadline's ``expires_at`` timestamp in ISO-8601 form.
    """

    return (
        None
        if deadline is None
        else {"deadline_expires_at": deadline.expires_at.isoformat()}
    )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _raise_tool_deadline_error(
    *, prompt_name: str, tool_name: str, deadline: Deadline
) -> NoReturn:
    """Abort tool execution because ``deadline`` has already expired.

    Raises:
        PromptEvaluationError: always, tagged with the tool phase and the
            payload produced by ``deadline_provider_payload(deadline)``.
    """

    message = f"Deadline expired before executing tool '{tool_name}'."
    payload = deadline_provider_payload(deadline)
    raise PromptEvaluationError(
        message,
        prompt_name=prompt_name,
        phase=PROMPT_EVALUATION_PHASE_TOOL,
        provider_payload=payload,
    )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def format_publish_failures(failures: Sequence[HandlerFailure]) -> str:
    """Build a single message summarising handler failures from a publish.

    Each failure contributes its error's stripped ``str()``; an error with an
    empty message falls back to its class name. Messages are joined with
    ``"; "``. When there are no failures a generic sentence is returned.
    """

    def describe(failure: HandlerFailure) -> str:
        error = failure.error
        return str(error).strip() or error.__class__.__name__

    details = [describe(failure) for failure in failures]
    if details:
        return f"Reducer errors prevented applying tool result: {'; '.join(details)}"
    return "Reducer errors prevented applying tool result."
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def tool_to_spec(tool: Tool[SupportsDataclass, SupportsToolResult]) -> dict[str, Any]:
    """Describe ``tool`` as a provider-neutral function-calling spec.

    The parameter schema is generated from ``tool.params_type`` with extra
    keys forbidden, and its top-level ``title`` entry is removed.
    """

    parameters = schema(tool.params_type, extra="forbid")
    _ = parameters.pop("title", None)
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def extract_payload(response: object) -> dict[str, Any] | None:
    """Best-effort conversion of an SDK response into a plain ``dict``.

    Objects exposing a callable ``model_dump()`` are dumped first; a dump
    that raises, or that does not return a mapping, yields ``None``. Objects
    without a callable ``model_dump`` are used as-is when they are mappings.
    Mappings containing any non-string key also yield ``None``.
    """

    model_dump = getattr(response, "model_dump", None)
    source: object
    if callable(model_dump):
        try:
            source = model_dump()
        except Exception:  # pragma: no cover - defensive
            return None
    else:
        source = response
    if not isinstance(source, Mapping):
        return None
    return _mapping_to_str_dict(cast(Mapping[Any, Any], source))
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def first_choice(response: object, *, prompt_name: str) -> object:
    """Return ``response.choices[0]``.

    Raises:
        PromptEvaluationError: in the response phase when ``choices`` is
            missing, is not a sequence, or is empty.
    """

    def missing_choices() -> PromptEvaluationError:
        return PromptEvaluationError(
            "Provider response did not include any choices.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_RESPONSE,
        )

    choices = getattr(response, "choices", None)
    if not isinstance(choices, Sequence):
        raise missing_choices()
    try:
        return cast(Sequence[object], choices)[0]
    except IndexError as error:  # pragma: no cover - defensive
        raise missing_choices() from error
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def serialize_tool_call(tool_call: ProviderToolCall) -> dict[str, Any]:
    """Render ``tool_call`` as an entry for an assistant message's tool calls.

    A missing ``id`` attribute becomes ``None``, and falsy arguments (for
    example ``None`` or ``""``) are normalised to ``"{}"``.
    """

    function = tool_call.function
    call_id = getattr(tool_call, "id", None)
    name = function.name
    arguments = function.arguments or "{}"
    return {
        "id": call_id,
        "type": "function",
        "function": {"name": name, "arguments": arguments},
    }
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def parse_tool_arguments(
    arguments_json: str | None,
    *,
    prompt_name: str,
    provider_payload: dict[str, Any] | None,
) -> dict[str, Any]:
    """Decode a provider's JSON tool-call arguments into a keyword mapping.

    ``None`` or an empty string decodes to ``{}``.

    Raises:
        PromptEvaluationError: in the tool phase, carrying ``provider_payload``,
            when the text is not valid JSON, does not decode to a JSON object,
            or contains a non-string key.
    """

    def failure(message: str) -> PromptEvaluationError:
        return PromptEvaluationError(
            message,
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=provider_payload,
        )

    if not arguments_json:
        return {}
    try:
        decoded = json.loads(arguments_json)
    except json.JSONDecodeError as error:
        raise failure("Failed to decode tool call arguments.") from error
    if not isinstance(decoded, Mapping):
        raise failure("Tool call arguments must be a JSON object.")
    mapping = cast(Mapping[Any, Any], decoded)
    if any(not isinstance(key, str) for key in mapping):
        raise failure("Tool call arguments must use string keys.")
    return dict(mapping)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def execute_tool_call(
    *,
    adapter_name: AdapterName,
    adapter: ProviderAdapter[Any],
    prompt: Prompt[Any],
    rendered_prompt: RenderedPrompt[Any] | None,
    tool_call: ProviderToolCall,
    tool_registry: Mapping[str, Tool[SupportsDataclass, SupportsToolResult]],
    bus: EventBus,
    session: SessionProtocol,
    prompt_name: str,
    provider_payload: dict[str, Any] | None,
    deadline: Deadline | None,
    format_publish_failures: Callable[[Sequence[HandlerFailure]], str],
    parse_arguments: ToolArgumentsParser,
    logger_override: StructuredLogger | None = None,
) -> tuple[ToolInvoked, ToolResult[SupportsToolResult]]:
    """Execute a provider tool call and publish the resulting event.

    Steps:
    1. Resolve the tool by name in ``tool_registry`` and decode its raw
       arguments with ``parse_arguments``.
    2. Parse the arguments into ``tool.params_type`` (extra keys forbidden),
       check ``deadline`` and invoke the tool handler with a ``ToolContext``.
    3. Publish a ``ToolInvoked`` event on ``bus``. If publishing reports
       failures, the session is rolled back to a snapshot taken just before
       publishing and the result message is replaced with
       ``format_publish_failures(...)``.

    Argument validation failures and unexpected exceptions are turned into an
    unsuccessful ``ToolResult`` instead of propagating. When no params
    dataclass could be built, a ``_RejectedToolParams`` placeholder (raw
    arguments plus error text) is recorded on the event.

    Returns:
        The published ``ToolInvoked`` event and the (possibly updated)
        ``ToolResult``.

    Raises:
        PromptEvaluationError: for an unknown tool, a tool without a handler,
            an expired deadline, or a handler raising ``DeadlineExceededError``.
            Errors raised by ``parse_arguments`` also propagate unchanged.
    """

    function = tool_call.function
    tool_name = function.name
    tool = tool_registry.get(tool_name)
    if tool is None:
        raise PromptEvaluationError(
            f"Unknown tool '{tool_name}' requested by provider.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=provider_payload,
        )
    handler = tool.handler
    if handler is None:
        raise PromptEvaluationError(
            f"Tool '{tool_name}' does not have a registered handler.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=provider_payload,
        )

    arguments_mapping = parse_arguments(
        function.arguments,
        prompt_name=prompt_name,
        provider_payload=provider_payload,
    )

    call_id = getattr(tool_call, "id", None)
    log = (logger_override or logger).bind(
        adapter=adapter_name,
        prompt=prompt_name,
        tool=tool_name,
        call_id=call_id,
    )

    def rejected_params(reason: BaseException) -> SupportsDataclass:
        # Stand-in params so the ToolInvoked event can still be published when
        # the real params dataclass could not be constructed.
        return cast(
            SupportsDataclass,
            _RejectedToolParams(
                raw_arguments=dict(arguments_mapping),
                error=str(reason),
            ),
        )

    tool_params: SupportsDataclass | None = None
    tool_result: ToolResult[SupportsToolResult]
    try:
        try:
            parsed_params = parse(tool.params_type, arguments_mapping, extra="forbid")
        except (TypeError, ValueError) as error:
            tool_params = rejected_params(error)
            raise ToolValidationError(str(error)) from error

        tool_params = parsed_params
        # Refuse to start the handler once the deadline has already passed.
        if deadline is not None and deadline.remaining() <= timedelta(0):
            _raise_tool_deadline_error(
                prompt_name=prompt_name, tool_name=tool_name, deadline=deadline
            )
        context = ToolContext(
            prompt=cast(PromptProtocol[Any], prompt),
            rendered_prompt=rendered_prompt,
            adapter=cast(ProviderAdapterProtocol[Any], adapter),
            session=session,
            event_bus=bus,
            deadline=deadline,
        )
        tool_result = handler(tool_params, context=context)
    except ToolValidationError as error:
        if tool_params is None:  # pragma: no cover - defensive
            tool_params = rejected_params(error)
        log.warning(
            "Tool validation failed.",
            event="tool_validation_failed",
            context={"reason": str(error)},
        )
        tool_result = ToolResult(
            message=f"Tool validation failed: {error}",
            value=None,
            success=False,
        )
    except PromptEvaluationError:
        raise
    except DeadlineExceededError as error:
        raise PromptEvaluationError(
            str(error) or f"Tool '{tool_name}' exceeded the deadline.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=deadline_provider_payload(deadline),
        ) from error
    except Exception as error:  # propagate message via ToolResult
        if tool_params is None:
            # ``parse`` failed with something other than TypeError/ValueError,
            # so no params exist yet. Record the raw arguments instead of
            # letting the ``tool_params is None`` guard below raise a
            # RuntimeError that hides the original error and drops the event.
            tool_params = rejected_params(error)
        log.exception(
            "Tool handler raised an unexpected exception.",
            event="tool_handler_exception",
            context={"provider_payload": provider_payload},
        )
        tool_result = ToolResult(
            message=f"Tool '{tool_name}' execution failed: {error}",
            value=None,
            success=False,
        )
    else:
        log.info(
            "Tool handler completed.",
            event="tool_handler_completed",
            context={
                "success": tool_result.success,
                "has_value": tool_result.value is not None,
            },
        )

    if tool_params is None:  # pragma: no cover - defensive
        raise RuntimeError("Tool parameters were not parsed.")

    # Snapshot before publishing so a failed publish can be undone.
    snapshot = session.snapshot()
    session_id = getattr(session, "session_id", None)
    tool_value = tool_result.value
    dataclass_value: SupportsDataclass | None = None
    if is_dataclass_instance(tool_value):
        dataclass_value = cast(SupportsDataclass, tool_value)  # pyright: ignore[reportUnnecessaryCast]

    invocation = ToolInvoked(
        prompt_name=prompt_name,
        adapter=adapter_name,
        name=tool_name,
        params=tool_params,
        result=cast(ToolResult[object], tool_result),
        session_id=session_id,
        created_at=datetime.now(UTC),
        value=dataclass_value,
        call_id=call_id,
        event_id=uuid4(),
    )
    publish_result = bus.publish(invocation)
    if not publish_result.ok:
        session.rollback(snapshot)
        log.warning(
            "Session rollback triggered after publish failure.",
            event="session_rollback_due_to_publish_failure",
        )
        failure_handlers = [
            getattr(failure.handler, "__qualname__", repr(failure.handler))
            for failure in publish_result.errors
        ]
        log.error(
            "Tool event publish failed.",
            event="tool_event_publish_failed",
            context={
                "failure_count": len(publish_result.errors),
                "failed_handlers": failure_handlers,
            },
        )
        tool_result.message = format_publish_failures(publish_result.errors)
    else:
        log.debug(
            "Tool event published.",
            event="tool_event_published",
            context={"handler_count": publish_result.handled_count},
        )
    return invocation, tool_result
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def build_json_schema_response_format(
    rendered: RenderedPrompt[Any], prompt_name: str
) -> dict[str, JSONValue] | None:
    """Build a ``{"type": "json_schema", ...}`` response format for ``rendered``.

    Returns ``None`` unless the rendered prompt declares both an output type
    and a container. For the ``"array"`` container the item schema is wrapped
    in an object under ``ARRAY_WRAPPER_KEY``; when extra keys are disallowed
    that wrapper also sets ``additionalProperties`` to ``False``. Any other
    container uses the item schema directly. The schema name is derived from
    ``prompt_name``.
    """

    output_type = rendered.output_type
    container = rendered.container
    allow_extra_keys = bool(rendered.allow_extra_keys)
    if output_type is None or container is None:
        return None

    item_schema = schema(output_type, extra="ignore" if allow_extra_keys else "forbid")
    _ = item_schema.pop("title", None)

    schema_payload: dict[str, JSONValue]
    if container != "array":
        schema_payload = item_schema
    else:
        schema_payload = {
            "type": "object",
            "properties": {
                ARRAY_WRAPPER_KEY: {"type": "array", "items": item_schema},
            },
            "required": [ARRAY_WRAPPER_KEY],
        }
        if not allow_extra_keys:
            schema_payload["additionalProperties"] = False

    return {
        "type": "json_schema",
        "json_schema": {
            "name": _schema_name(prompt_name),
            "schema": schema_payload,
        },
    }
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def parse_schema_constrained_payload(
    payload: JSONValue, rendered: RenderedPrompt[Any]
) -> object:
    """Parse a schema-constrained provider payload into the prompt's output type.

    Delegates to ``parse_dataclass_payload`` using the rendered prompt's
    output type, container and extra-key policy; its errors propagate.

    Raises:
        TypeError: if the prompt declares no structured output.
    """

    output_type = rendered.output_type
    container = rendered.container
    extra_keys_allowed = rendered.allow_extra_keys
    if output_type is None or container is None:
        raise TypeError("Prompt does not declare structured output.")

    return parse_dataclass_payload(
        output_type,
        container,
        payload,
        allow_extra_keys=bool(extra_keys_allowed),
        object_error="Expected provider payload to be a JSON object.",
        array_error="Expected provider payload to be a JSON array.",
        array_item_error="Array item at index {index} is not an object.",
    )
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def message_text_content(content: object) -> str:
    """Flatten a provider message ``content`` value into plain text.

    ``None`` becomes ``""`` and strings pass through unchanged. Sequences
    other than ``str``/``bytes``/``bytearray`` are treated as content parts
    whose text (via ``_content_part_text``) is concatenated. Anything else is
    converted with ``str()``.
    """

    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (bytes, bytearray)) or not isinstance(content, Sequence):
        return str(content)
    parts = cast(Sequence[object], content)
    return "".join(_content_part_text(part) for part in parts)
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def extract_parsed_content(message: object) -> object | None:
|
|
509
|
+
"""Extract structured payloads surfaced directly by the provider."""
|
|
510
|
+
|
|
511
|
+
parsed = getattr(message, "parsed", None)
|
|
512
|
+
if parsed is not None:
|
|
513
|
+
return parsed
|
|
514
|
+
|
|
515
|
+
content = getattr(message, "content", None)
|
|
516
|
+
if isinstance(content, Sequence) and not isinstance(
|
|
517
|
+
content, (str, bytes, bytearray)
|
|
518
|
+
):
|
|
519
|
+
sequence_content = cast(
|
|
520
|
+
Sequence[object],
|
|
521
|
+
content,
|
|
522
|
+
)
|
|
523
|
+
for part in sequence_content:
|
|
524
|
+
payload = _parsed_payload_from_part(part)
|
|
525
|
+
if payload is not None:
|
|
526
|
+
return payload
|
|
527
|
+
return None
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _schema_name(prompt_name: str) -> str:
    """Derive a provider-safe JSON schema name from ``prompt_name``.

    Runs of characters outside ``[a-zA-Z0-9_-]`` collapse to one underscore,
    leading/trailing underscores are stripped, and ``"prompt"`` is used when
    nothing remains. The result carries a ``_schema`` suffix and is capped at
    64 characters, the maximum OpenAI accepts for
    ``response_format.json_schema.name``.
    """

    suffix = "_schema"
    max_length = 64
    sanitized = re.sub(r"[^a-zA-Z0-9_-]+", "_", prompt_name.strip())
    cleaned = sanitized.strip("_") or "prompt"
    # Truncate the stem so stem + suffix fits, then re-strip so the cut cannot
    # leave a dangling underscore right before the suffix.
    cleaned = cleaned[: max_length - len(suffix)].rstrip("_") or "prompt"
    return f"{cleaned}{suffix}"
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _content_part_text(part: object) -> str:
    """Return the text of a ``text``/``output_text`` content part, else ``""``.

    Mapping parts are read through their ``"type"``/``"text"`` keys; other
    objects through ``type``/``text`` attributes. Non-string text is ignored.
    """

    if part is None:
        return ""
    if isinstance(part, Mapping):
        mapping_part = cast(Mapping[str, object], part)
        kind = mapping_part.get("type")
        text = mapping_part.get("text") if kind in {"output_text", "text"} else None
    else:
        kind = getattr(part, "type", None)
        text = getattr(part, "text", None) if kind in {"output_text", "text"} else None
    return text if isinstance(text, str) else ""
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def _parsed_payload_from_part(part: object) -> object | None:
|
|
556
|
+
if isinstance(part, Mapping):
|
|
557
|
+
mapping_part = cast(Mapping[str, object], part)
|
|
558
|
+
if mapping_part.get("type") == "output_json":
|
|
559
|
+
return mapping_part.get("json")
|
|
560
|
+
return None
|
|
561
|
+
part_type = getattr(part, "type", None)
|
|
562
|
+
if part_type == "output_json":
|
|
563
|
+
return getattr(part, "json", None)
|
|
564
|
+
return None
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def _mapping_to_str_dict(mapping: Mapping[Any, Any]) -> dict[str, Any] | None:
|
|
568
|
+
str_mapping: dict[str, Any] = {}
|
|
569
|
+
for key, value in mapping.items():
|
|
570
|
+
if not isinstance(key, str):
|
|
571
|
+
return None
|
|
572
|
+
str_mapping[key] = value
|
|
573
|
+
return str_mapping
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
# Public surface of this module. Two underscore-prefixed helpers
# (``_content_part_text`` and ``_parsed_payload_from_part``) are listed
# deliberately so other adapter modules can import them explicitly.
__all__ = [
    "LITELLM_ADAPTER_NAME",
    "OPENAI_ADAPTER_NAME",
    "AdapterName",
    "ChoiceSelector",
    "ConversationRequest",
    "ConversationRunner",
    "ProviderChoice",
    "ProviderCompletionCallable",
    "ProviderCompletionResponse",
    "ProviderFunctionCall",
    "ProviderMessage",
    "ProviderToolCall",
    "ToolArgumentsParser",
    "ToolChoice",
    "ToolMessageSerializer",
    "_content_part_text",
    "_parsed_payload_from_part",
    "build_json_schema_response_format",
    "deadline_provider_payload",
    "execute_tool_call",
    "extract_parsed_content",
    "extract_payload",
    "first_choice",
    "format_publish_failures",
    "message_text_content",
    "parse_schema_constrained_payload",
    "parse_tool_arguments",
    "run_conversation",
    "serialize_tool_call",
    "tool_to_spec",
]
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
OutputT = TypeVar("OutputT")


# Arguments, in order: the message list, the tool specs, the tool-choice
# directive, and the response format. ``ToolChoice`` already admits ``None``,
# so no extra ``| None`` is needed for the third slot. Returns the raw
# provider response.
ConversationRequest = Callable[
    [
        list[dict[str, Any]],
        Sequence[Mapping[str, Any]],
        ToolChoice,
        Mapping[str, Any] | None,
    ],
    object,
]
"""Callable that sends the assembled payloads to the provider and returns its response."""


# Receives the raw provider response and returns the choice to process.
ChoiceSelector = Callable[[object], ProviderChoice]
"""Callable that picks the relevant choice out of a provider response."""
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
class ToolMessageSerializer(Protocol):
    """Callable contract that turns a ``ToolResult`` into a provider-side object.

    Presumably produces the tool-response message content sent back to the
    provider; confirm against each adapter's implementation. ``payload`` is
    an optional keyword-only extra value whose meaning is implementation
    specific.
    """

    def __call__(
        self,
        result: ToolResult[SupportsToolResult],
        *,
        payload: object | None = ...,
    ) -> object:
        """Serialize ``result``, optionally using ``payload``."""
        ...
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
@dataclass(slots=True)
|
|
639
|
+
class ConversationRunner[OutputT]:
|
|
640
|
+
"""Coordinate a conversational exchange with a provider."""
|
|
641
|
+
|
|
642
|
+
adapter_name: AdapterName
|
|
643
|
+
adapter: ProviderAdapter[OutputT]
|
|
644
|
+
prompt: Prompt[OutputT]
|
|
645
|
+
prompt_name: str
|
|
646
|
+
rendered: RenderedPrompt[OutputT]
|
|
647
|
+
render_inputs: tuple[SupportsDataclass, ...]
|
|
648
|
+
initial_messages: list[dict[str, Any]]
|
|
649
|
+
parse_output: bool
|
|
650
|
+
bus: EventBus
|
|
651
|
+
session: SessionProtocol
|
|
652
|
+
tool_choice: ToolChoice
|
|
653
|
+
response_format: Mapping[str, Any] | None
|
|
654
|
+
require_structured_output_text: bool
|
|
655
|
+
call_provider: ConversationRequest
|
|
656
|
+
select_choice: ChoiceSelector
|
|
657
|
+
serialize_tool_message_fn: ToolMessageSerializer
|
|
658
|
+
format_publish_failures: Callable[[Sequence[HandlerFailure]], str] = (
|
|
659
|
+
format_publish_failures
|
|
660
|
+
)
|
|
661
|
+
parse_arguments: ToolArgumentsParser = parse_tool_arguments
|
|
662
|
+
logger_override: StructuredLogger | None = None
|
|
663
|
+
deadline: Deadline | None = None
|
|
664
|
+
_log: StructuredLogger = field(init=False)
|
|
665
|
+
_messages: list[dict[str, Any]] = field(init=False)
|
|
666
|
+
_tool_specs: list[dict[str, Any]] = field(init=False)
|
|
667
|
+
_tool_registry: dict[str, Tool[SupportsDataclass, SupportsToolResult]] = field(
|
|
668
|
+
init=False
|
|
669
|
+
)
|
|
670
|
+
_tool_events: list[ToolInvoked] = field(init=False)
|
|
671
|
+
_tool_message_records: list[
|
|
672
|
+
tuple[ToolResult[SupportsToolResult], dict[str, Any]]
|
|
673
|
+
] = field(init=False)
|
|
674
|
+
_provider_payload: dict[str, Any] | None = field(init=False, default=None)
|
|
675
|
+
_next_tool_choice: ToolChoice = field(init=False)
|
|
676
|
+
_should_parse_structured_output: bool = field(init=False)
|
|
677
|
+
|
|
678
|
+
def _raise_deadline_error(
|
|
679
|
+
self, message: str, *, phase: PromptEvaluationPhase
|
|
680
|
+
) -> NoReturn:
|
|
681
|
+
raise PromptEvaluationError(
|
|
682
|
+
message,
|
|
683
|
+
prompt_name=self.prompt_name,
|
|
684
|
+
phase=phase,
|
|
685
|
+
provider_payload=deadline_provider_payload(self.deadline),
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
def _ensure_deadline_remaining(
|
|
689
|
+
self, message: str, *, phase: PromptEvaluationPhase
|
|
690
|
+
) -> None:
|
|
691
|
+
if self.deadline is None:
|
|
692
|
+
return
|
|
693
|
+
if self.deadline.remaining() <= timedelta(0):
|
|
694
|
+
self._raise_deadline_error(message, phase=phase)
|
|
695
|
+
|
|
696
|
+
def run(self) -> PromptResponse[OutputT]:
|
|
697
|
+
"""Execute the conversation loop and return the final response."""
|
|
698
|
+
|
|
699
|
+
self._prepare_payload()
|
|
700
|
+
|
|
701
|
+
while True:
|
|
702
|
+
self._ensure_deadline_remaining(
|
|
703
|
+
"Deadline expired before provider request.",
|
|
704
|
+
phase=PROMPT_EVALUATION_PHASE_REQUEST,
|
|
705
|
+
)
|
|
706
|
+
response = self.call_provider(
|
|
707
|
+
self._messages,
|
|
708
|
+
self._tool_specs,
|
|
709
|
+
self._next_tool_choice if self._tool_specs else None,
|
|
710
|
+
self.response_format,
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
self._provider_payload = extract_payload(response)
|
|
714
|
+
choice = self.select_choice(response)
|
|
715
|
+
message = getattr(choice, "message", None)
|
|
716
|
+
if message is None:
|
|
717
|
+
raise PromptEvaluationError(
|
|
718
|
+
"Provider response did not include a message payload.",
|
|
719
|
+
prompt_name=self.prompt_name,
|
|
720
|
+
phase=PROMPT_EVALUATION_PHASE_RESPONSE,
|
|
721
|
+
provider_payload=self._provider_payload,
|
|
722
|
+
)
|
|
723
|
+
|
|
724
|
+
tool_calls_sequence = getattr(message, "tool_calls", None)
|
|
725
|
+
tool_calls = list(tool_calls_sequence or [])
|
|
726
|
+
|
|
727
|
+
if not tool_calls:
|
|
728
|
+
return self._finalize_response(message)
|
|
729
|
+
|
|
730
|
+
self._handle_tool_calls(message, tool_calls)
|
|
731
|
+
|
|
732
|
+
def _prepare_payload(self) -> None:
    """Reset per-run state, log the start of execution, and announce the render.

    The steps, in order:

    1. Copy the initial messages into a fresh transcript.
    2. Bind a logger scoped to this adapter and prompt, and log the start.
    3. Build the tool specs and a name-to-tool registry.
    4. Clear the tool bookkeeping.
    5. Publish a ``PromptRendered`` event on the bus.

    A failed publish is logged as an error, not raised.
    """
    self._messages = list(self.initial_messages)
    base_logger = self.logger_override or logger
    self._log = base_logger.bind(adapter=self.adapter_name, prompt=self.prompt_name)

    start_context = {
        "tool_count": len(self.rendered.tools),
        "parse_output": self.parse_output,
    }
    self._log.info(
        "Prompt execution started.",
        event="prompt_execution_started",
        context=start_context,
    )

    rendered_tools = list(self.rendered.tools)
    self._tool_specs = [tool_to_spec(entry) for entry in rendered_tools]
    self._tool_registry = {entry.name: entry for entry in rendered_tools}
    self._tool_events = []
    self._tool_message_records = []
    self._provider_payload = None
    self._next_tool_choice = self.tool_choice
    # Structured parsing needs the caller's opt-in plus an output type and container.
    self._should_parse_structured_output = (
        self.parse_output
        and self.rendered.output_type is not None
        and self.rendered.container is not None
    )

    rendered_event = PromptRendered(
        prompt_ns=self.prompt.ns,
        prompt_key=self.prompt.key,
        prompt_name=self.prompt.name,
        adapter=self.adapter_name,
        session_id=getattr(self.session, "session_id", None),
        render_inputs=self.render_inputs,
        rendered_prompt=self.rendered.text,
        created_at=datetime.now(UTC),
        event_id=uuid4(),
    )
    publish_result = self.bus.publish(rendered_event)

    if publish_result.ok:
        self._log.debug(
            "Prompt rendered event published.",
            event="prompt_rendered_published",
            context={"handler_count": publish_result.handled_count},
        )
        return

    failed_names = [
        getattr(item.handler, "__qualname__", repr(item.handler))
        for item in publish_result.errors
    ]
    self._log.error(
        "Prompt rendered publish failed.",
        event="prompt_rendered_publish_failed",
        context={
            "failure_count": len(publish_result.errors),
            "failed_handlers": failed_names,
        },
    )
|
|
794
|
+
|
|
795
|
+
def _handle_tool_calls(
    self,
    message: object,
    tool_calls: Sequence[ProviderToolCall],
) -> None:
    """Execute requested tools and append the resulting turns to the transcript.

    First, the assistant turn (its text content plus the serialized tool
    calls) is appended. Then, for each call in order:

    - check the deadline;
    - run the tool through ``execute_tool_call``;
    - record the invocation;
    - append a ``tool`` role message keyed by the call's id.

    A forced ``{"type": "function", ...}`` tool choice is relaxed to
    ``"auto"`` so later requests do not force the same tool again.
    """
    serialized_calls = [serialize_tool_call(entry) for entry in tool_calls]
    assistant_turn = {
        "role": "assistant",
        "content": getattr(message, "content", None) or "",
        "tool_calls": serialized_calls,
    }
    self._messages.append(assistant_turn)

    self._log.debug(
        "Processing tool calls.",
        event="prompt_tool_calls_detected",
        context={"count": len(tool_calls)},
    )

    for call in tool_calls:
        requested_name = getattr(call.function, "name", "tool")
        self._ensure_deadline_remaining(
            f"Deadline expired before executing tool '{requested_name}'.",
            phase=PROMPT_EVALUATION_PHASE_TOOL,
        )
        invocation, outcome = execute_tool_call(
            adapter_name=self.adapter_name,
            adapter=self.adapter,
            prompt=self.prompt,
            rendered_prompt=self.rendered,
            tool_call=call,
            tool_registry=self._tool_registry,
            bus=self.bus,
            session=self.session,
            prompt_name=self.prompt_name,
            provider_payload=self._provider_payload,
            deadline=self.deadline,
            format_publish_failures=self.format_publish_failures,
            parse_arguments=self.parse_arguments,
            logger_override=self.logger_override,
        )
        self._tool_events.append(invocation)

        tool_turn = {
            "role": "tool",
            "tool_call_id": getattr(call, "id", None),
            "content": self.serialize_tool_message_fn(outcome),
        }
        self._messages.append(tool_turn)
        # Keep the dict itself so _finalize_response can rewrite its content later.
        self._tool_message_records.append((outcome, tool_turn))

    current_choice = self._next_tool_choice
    if (
        isinstance(current_choice, Mapping)
        and cast(Mapping[str, object], current_choice).get("type") == "function"
    ):
        self._next_tool_choice = "auto"
|
|
853
|
+
|
|
854
|
+
def _finalize_response(self, message: object) -> PromptResponse[OutputT]:
    """Assemble the final ``PromptResponse``, publish it, and return it.

    Steps:

    1. Fail fast if the deadline has already elapsed.
    2. Extract the message text. When structured output is enabled, parse
       it with ``_parse_final_output``. A parsed value replaces the raw text
       in the response.
    3. If a value was parsed and the most recent tool call succeeded,
       re-serialize that tool message with the parsed value as its payload.
    4. Publish a ``PromptExecuted`` event. Publish failures are logged, then
       surfaced through ``publish_result.raise_if_errors()``.

    Raises:
        PromptEvaluationError: if the deadline expired, or if structured
            output was required but missing or could not be parsed.
    """
    self._ensure_deadline_remaining(
        "Deadline expired while finalizing provider response.",
        phase=PROMPT_EVALUATION_PHASE_RESPONSE,
    )
    final_text = message_text_content(getattr(message, "content", None))
    output: OutputT | None = None
    text_value: str | None = final_text or None

    if self._should_parse_structured_output:
        output = self._parse_final_output(message, final_text)
        if output is not None:
            # A parsed value supersedes the raw text in the response.
            text_value = None

    if output is not None and self._tool_message_records:
        last_result, last_message = self._tool_message_records[-1]
        if last_result.success:
            # The tool message dict is the same object stored in the transcript,
            # so updating it here also updates the recorded conversation.
            last_message["content"] = self.serialize_tool_message_fn(
                last_result, payload=output
            )

    response_payload = PromptResponse(
        prompt_name=self.prompt_name,
        text=text_value,
        output=output,
        tool_results=tuple(self._tool_events),
        provider_payload=self._provider_payload,
    )
    prompt_value: SupportsDataclass | None = None
    if is_dataclass_instance(output):
        prompt_value = cast(SupportsDataclass, output)  # pyright: ignore[reportUnnecessaryCast]

    publish_result = self.bus.publish(
        PromptExecuted(
            prompt_name=self.prompt_name,
            adapter=self.adapter_name,
            result=cast(PromptResponse[object], response_payload),
            session_id=getattr(self.session, "session_id", None),
            created_at=datetime.now(UTC),
            value=prompt_value,
            event_id=uuid4(),
        )
    )
    if not publish_result.ok:
        failure_handlers = [
            getattr(failure.handler, "__qualname__", repr(failure.handler))
            for failure in publish_result.errors
        ]
        self._log.error(
            "Prompt execution publish failed.",
            event="prompt_execution_publish_failed",
            context={
                "failure_count": len(publish_result.errors),
                "failed_handlers": failure_handlers,
            },
        )
    publish_result.raise_if_errors()
    self._log.info(
        "Prompt execution completed.",
        event="prompt_execution_succeeded",
        context={
            "tool_count": len(self._tool_events),
            "has_output": output is not None,
            "text_length": len(text_value or ""),
            "structured_output": self._should_parse_structured_output,
            "handler_count": publish_result.handled_count,
        },
    )
    return response_payload

def _parse_final_output(self, message: object, final_text: str) -> OutputT | None:
    """Parse the structured output carried by the final provider message.

    An already-parsed payload on the message (from ``extract_parsed_content``)
    takes precedence. Otherwise the message text is parsed, unless it is empty
    while ``require_structured_output_text`` is set, in which case this raises.

    Raises:
        PromptEvaluationError: if parsing fails or required text is missing.
    """
    parsed_payload = extract_parsed_content(message)
    if parsed_payload is not None:
        try:
            return cast(
                OutputT,
                parse_schema_constrained_payload(
                    cast(JSONValue, parsed_payload), self.rendered
                ),
            )
        except (TypeError, ValueError) as error:
            raise PromptEvaluationError(
                str(error),
                prompt_name=self.prompt_name,
                phase=PROMPT_EVALUATION_PHASE_RESPONSE,
                provider_payload=self._provider_payload,
            ) from error

    if not final_text and self.require_structured_output_text:
        raise PromptEvaluationError(
            "Provider response did not include structured output.",
            prompt_name=self.prompt_name,
            phase=PROMPT_EVALUATION_PHASE_RESPONSE,
            provider_payload=self._provider_payload,
        )

    try:
        return parse_structured_output(final_text or "", self.rendered)
    except OutputParseError as error:
        raise PromptEvaluationError(
            error.message,
            prompt_name=self.prompt_name,
            phase=PROMPT_EVALUATION_PHASE_RESPONSE,
            provider_payload=self._provider_payload,
        ) from error
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
def run_conversation[
|
|
966
|
+
OutputT,
|
|
967
|
+
](
|
|
968
|
+
*,
|
|
969
|
+
adapter_name: AdapterName,
|
|
970
|
+
adapter: ProviderAdapter[OutputT],
|
|
971
|
+
prompt: Prompt[OutputT],
|
|
972
|
+
prompt_name: str,
|
|
973
|
+
rendered: RenderedPrompt[OutputT],
|
|
974
|
+
render_inputs: tuple[SupportsDataclass, ...],
|
|
975
|
+
initial_messages: list[dict[str, Any]],
|
|
976
|
+
parse_output: bool,
|
|
977
|
+
bus: EventBus,
|
|
978
|
+
session: SessionProtocol,
|
|
979
|
+
tool_choice: ToolChoice,
|
|
980
|
+
response_format: Mapping[str, Any] | None,
|
|
981
|
+
require_structured_output_text: bool,
|
|
982
|
+
call_provider: ConversationRequest,
|
|
983
|
+
select_choice: ChoiceSelector,
|
|
984
|
+
serialize_tool_message_fn: ToolMessageSerializer,
|
|
985
|
+
format_publish_failures: Callable[
|
|
986
|
+
[Sequence[HandlerFailure]], str
|
|
987
|
+
] = format_publish_failures,
|
|
988
|
+
parse_arguments: ToolArgumentsParser = parse_tool_arguments,
|
|
989
|
+
logger_override: StructuredLogger | None = None,
|
|
990
|
+
deadline: Deadline | None = None,
|
|
991
|
+
) -> PromptResponse[OutputT]:
|
|
992
|
+
"""Execute a conversational exchange with a provider and return the result."""
|
|
993
|
+
|
|
994
|
+
effective_deadline = deadline or rendered.deadline
|
|
995
|
+
rendered_with_deadline = rendered
|
|
996
|
+
if effective_deadline is not None and rendered.deadline is not effective_deadline:
|
|
997
|
+
rendered_with_deadline = replace(rendered, deadline=effective_deadline)
|
|
998
|
+
|
|
999
|
+
runner = ConversationRunner[OutputT](
|
|
1000
|
+
adapter_name=adapter_name,
|
|
1001
|
+
adapter=adapter,
|
|
1002
|
+
prompt=prompt,
|
|
1003
|
+
prompt_name=prompt_name,
|
|
1004
|
+
rendered=rendered_with_deadline,
|
|
1005
|
+
render_inputs=render_inputs,
|
|
1006
|
+
initial_messages=initial_messages,
|
|
1007
|
+
parse_output=parse_output,
|
|
1008
|
+
bus=bus,
|
|
1009
|
+
session=session,
|
|
1010
|
+
tool_choice=tool_choice,
|
|
1011
|
+
response_format=response_format,
|
|
1012
|
+
require_structured_output_text=require_structured_output_text,
|
|
1013
|
+
call_provider=call_provider,
|
|
1014
|
+
select_choice=select_choice,
|
|
1015
|
+
serialize_tool_message_fn=serialize_tool_message_fn,
|
|
1016
|
+
format_publish_failures=format_publish_failures,
|
|
1017
|
+
parse_arguments=parse_arguments,
|
|
1018
|
+
logger_override=logger_override,
|
|
1019
|
+
deadline=effective_deadline,
|
|
1020
|
+
)
|
|
1021
|
+
return runner.run()
|