weakincentives-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. weakincentives/__init__.py +67 -0
  2. weakincentives/adapters/__init__.py +37 -0
  3. weakincentives/adapters/_names.py +32 -0
  4. weakincentives/adapters/_provider_protocols.py +69 -0
  5. weakincentives/adapters/_tool_messages.py +80 -0
  6. weakincentives/adapters/core.py +102 -0
  7. weakincentives/adapters/litellm.py +254 -0
  8. weakincentives/adapters/openai.py +254 -0
  9. weakincentives/adapters/shared.py +1021 -0
  10. weakincentives/cli/__init__.py +23 -0
  11. weakincentives/cli/wink.py +58 -0
  12. weakincentives/dbc/__init__.py +412 -0
  13. weakincentives/deadlines.py +58 -0
  14. weakincentives/prompt/__init__.py +105 -0
  15. weakincentives/prompt/_generic_params_specializer.py +64 -0
  16. weakincentives/prompt/_normalization.py +48 -0
  17. weakincentives/prompt/_overrides_protocols.py +33 -0
  18. weakincentives/prompt/_types.py +34 -0
  19. weakincentives/prompt/chapter.py +146 -0
  20. weakincentives/prompt/composition.py +281 -0
  21. weakincentives/prompt/errors.py +57 -0
  22. weakincentives/prompt/markdown.py +108 -0
  23. weakincentives/prompt/overrides/__init__.py +59 -0
  24. weakincentives/prompt/overrides/_fs.py +164 -0
  25. weakincentives/prompt/overrides/inspection.py +141 -0
  26. weakincentives/prompt/overrides/local_store.py +275 -0
  27. weakincentives/prompt/overrides/validation.py +534 -0
  28. weakincentives/prompt/overrides/versioning.py +269 -0
  29. weakincentives/prompt/prompt.py +353 -0
  30. weakincentives/prompt/protocols.py +103 -0
  31. weakincentives/prompt/registry.py +375 -0
  32. weakincentives/prompt/rendering.py +288 -0
  33. weakincentives/prompt/response_format.py +60 -0
  34. weakincentives/prompt/section.py +166 -0
  35. weakincentives/prompt/structured_output.py +179 -0
  36. weakincentives/prompt/tool.py +397 -0
  37. weakincentives/prompt/tool_result.py +30 -0
  38. weakincentives/py.typed +0 -0
  39. weakincentives/runtime/__init__.py +82 -0
  40. weakincentives/runtime/events/__init__.py +126 -0
  41. weakincentives/runtime/events/_types.py +110 -0
  42. weakincentives/runtime/logging.py +284 -0
  43. weakincentives/runtime/session/__init__.py +46 -0
  44. weakincentives/runtime/session/_slice_types.py +24 -0
  45. weakincentives/runtime/session/_types.py +55 -0
  46. weakincentives/runtime/session/dataclasses.py +29 -0
  47. weakincentives/runtime/session/protocols.py +34 -0
  48. weakincentives/runtime/session/reducer_context.py +40 -0
  49. weakincentives/runtime/session/reducers.py +82 -0
  50. weakincentives/runtime/session/selectors.py +56 -0
  51. weakincentives/runtime/session/session.py +387 -0
  52. weakincentives/runtime/session/snapshots.py +310 -0
  53. weakincentives/serde/__init__.py +19 -0
  54. weakincentives/serde/_utils.py +240 -0
  55. weakincentives/serde/dataclass_serde.py +55 -0
  56. weakincentives/serde/dump.py +189 -0
  57. weakincentives/serde/parse.py +417 -0
  58. weakincentives/serde/schema.py +260 -0
  59. weakincentives/tools/__init__.py +154 -0
  60. weakincentives/tools/_context.py +38 -0
  61. weakincentives/tools/asteval.py +853 -0
  62. weakincentives/tools/errors.py +26 -0
  63. weakincentives/tools/planning.py +831 -0
  64. weakincentives/tools/podman.py +1655 -0
  65. weakincentives/tools/subagents.py +346 -0
  66. weakincentives/tools/vfs.py +1390 -0
  67. weakincentives/types/__init__.py +35 -0
  68. weakincentives/types/json.py +45 -0
  69. weakincentives-0.9.0.dist-info/METADATA +775 -0
  70. weakincentives-0.9.0.dist-info/RECORD +73 -0
  71. weakincentives-0.9.0.dist-info/WHEEL +4 -0
  72. weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
  73. weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1021 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Shared helpers for provider adapters."""
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import re
19
+ from collections.abc import Callable, Mapping, Sequence
20
+ from dataclasses import dataclass, field, replace
21
+ from datetime import UTC, datetime, timedelta
22
+ from typing import TYPE_CHECKING, Any, Literal, NoReturn, Protocol, TypeVar, cast
23
+ from uuid import uuid4
24
+
25
+ from ..deadlines import Deadline
26
+ from ..prompt._types import SupportsDataclass, SupportsToolResult
27
+ from ..prompt.prompt import Prompt, RenderedPrompt
28
+ from ..prompt.protocols import PromptProtocol, ProviderAdapterProtocol
29
+ from ..prompt.structured_output import (
30
+ ARRAY_WRAPPER_KEY,
31
+ OutputParseError,
32
+ parse_dataclass_payload,
33
+ parse_structured_output,
34
+ )
35
+ from ..prompt.tool import Tool, ToolContext, ToolResult
36
+ from ..runtime.events import (
37
+ EventBus,
38
+ HandlerFailure,
39
+ PromptExecuted,
40
+ PromptRendered,
41
+ ToolInvoked,
42
+ )
43
+ from ..runtime.logging import StructuredLogger, get_logger
44
+ from ..runtime.session.dataclasses import is_dataclass_instance
45
+ from ..serde import parse, schema
46
+ from ..tools.errors import DeadlineExceededError, ToolValidationError
47
+ from ..types import JSONValue
48
+ from ._names import LITELLM_ADAPTER_NAME, OPENAI_ADAPTER_NAME, AdapterName
49
+ from ._provider_protocols import (
50
+ ProviderChoice,
51
+ ProviderCompletionCallable,
52
+ ProviderCompletionResponse,
53
+ ProviderFunctionCall,
54
+ ProviderMessage,
55
+ ProviderToolCall,
56
+ )
57
+ from .core import (
58
+ PROMPT_EVALUATION_PHASE_REQUEST,
59
+ PROMPT_EVALUATION_PHASE_RESPONSE,
60
+ PROMPT_EVALUATION_PHASE_TOOL,
61
+ PromptEvaluationError,
62
+ PromptEvaluationPhase,
63
+ PromptResponse,
64
+ SessionProtocol,
65
+ )
66
+
67
+ if TYPE_CHECKING:
68
+ from ..adapters.core import ProviderAdapter
69
+
70
+
71
# Module-level structured logger. Every record emitted through it carries a
# ``component`` context value of ``"adapters.shared"``.
logger: StructuredLogger = get_logger(
    __name__,
    context=dict(component="adapters.shared"),
)
74
+
75
+
76
+ @dataclass(slots=True)
77
+ class _RejectedToolParams:
78
+ """Dataclass used when provider arguments fail validation."""
79
+
80
+ raw_arguments: dict[str, Any]
81
+ error: str
82
+
83
+
84
class ToolArgumentsParser(Protocol):
    """Callable protocol for decoding a tool call's raw argument string.

    Implementations receive the provider's argument text (possibly ``None``)
    along with the prompt name and provider payload used for error
    reporting. They return a dict keyed by argument name.
    ``parse_tool_arguments`` in this module satisfies this protocol.
    """

    def __call__(
        self,
        arguments_json: str | None,
        *,
        prompt_name: str,
        provider_payload: dict[str, Any] | None,
    ) -> dict[str, Any]: ...
92
+
93
+
94
+ ToolChoice = Literal["auto"] | Mapping[str, Any] | None
95
+ """Supported tool choice directives for provider APIs."""
96
+
97
+
98
def deadline_provider_payload(deadline: Deadline | None) -> dict[str, Any] | None:
    """Describe ``deadline`` as a provider payload fragment.

    Returns ``None`` when no deadline is given. Otherwise returns a
    single-key dict that maps ``"deadline_expires_at"`` to
    ``deadline.expires_at.isoformat()``.
    """

    return (
        None
        if deadline is None
        else {"deadline_expires_at": deadline.expires_at.isoformat()}
    )
104
+
105
+
106
def _raise_tool_deadline_error(
    *, prompt_name: str, tool_name: str, deadline: Deadline
) -> NoReturn:
    """Raise the error used when a tool's deadline has already expired.

    Always raises ``PromptEvaluationError`` in the tool phase. The error
    names the tool and attaches the deadline's expiry as the provider
    payload.
    """

    message = f"Deadline expired before executing tool '{tool_name}'."
    payload = deadline_provider_payload(deadline)
    raise PromptEvaluationError(
        message,
        prompt_name=prompt_name,
        phase=PROMPT_EVALUATION_PHASE_TOOL,
        provider_payload=payload,
    )
115
+
116
+
117
def format_publish_failures(failures: Sequence[HandlerFailure]) -> str:
    """Build a human-readable summary of event handler failures.

    Each failure contributes the stripped ``str()`` of its ``error``. When
    that text is empty, the error's class name is used instead. The pieces
    are joined with ``"; "`` after a fixed prefix. With no failures, only
    the prefix sentence is returned.
    """

    prefix = "Reducer errors prevented applying tool result"

    def describe(failure: HandlerFailure) -> str:
        error = failure.error
        text = str(error).strip()
        return text if text else error.__class__.__name__

    details = [describe(failure) for failure in failures]
    if details:
        return f"{prefix}: {'; '.join(details)}"
    return f"{prefix}."
133
+
134
+
135
def tool_to_spec(tool: Tool[SupportsDataclass, SupportsToolResult]) -> dict[str, Any]:
    """Describe ``tool`` as a function-style tool specification.

    The parameter schema is generated from ``tool.params_type`` with extra
    keys forbidden. Its top-level ``title`` entry is dropped, and the schema
    is embedded under ``function.parameters`` next to the tool's name and
    description.
    """

    parameters = schema(tool.params_type, extra="forbid")
    _ = parameters.pop("title", None)
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
148
+
149
+
150
def extract_payload(response: object) -> dict[str, Any] | None:
    """Convert a provider response into a plain string-keyed dict, if possible.

    Objects exposing a callable ``model_dump`` are dumped. If the dump
    raises, or does not produce a string-keyed mapping, the result is
    ``None``. Objects without ``model_dump`` are accepted directly when they
    are string-keyed mappings. Anything else yields ``None``.
    """

    dumper = getattr(response, "model_dump", None)
    if not callable(dumper):
        if isinstance(response, Mapping):  # pragma: no cover - defensive
            return _mapping_to_str_dict(cast(Mapping[Any, Any], response))
        return None

    try:
        dumped = dumper()
    except Exception:  # pragma: no cover - defensive
        return None
    if not isinstance(dumped, Mapping):
        return None
    return _mapping_to_str_dict(cast(Mapping[Any, Any], dumped))
169
+
170
+
171
def first_choice(response: object, *, prompt_name: str) -> object:
    """Return the first element of ``response.choices``.

    Args:
        response: Provider response whose ``choices`` attribute is read.
        prompt_name: Prompt name attached to any raised error.

    Returns:
        The first choice, unmodified.

    Raises:
        PromptEvaluationError: In the response phase, when ``choices`` is
            missing, is not a sequence, is ``str``/``bytes``/``bytearray``,
            or is empty.
    """

    def no_choices() -> PromptEvaluationError:
        return PromptEvaluationError(
            "Provider response did not include any choices.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_RESPONSE,
        )

    choices = getattr(response, "choices", None)
    # ``str``/``bytes`` satisfy ``Sequence`` but would be indexed per
    # character; they never represent a list of choices.
    if not isinstance(choices, Sequence) or isinstance(
        choices, (str, bytes, bytearray)
    ):
        raise no_choices()
    try:
        return cast(Sequence[object], choices)[0]
    except IndexError as error:  # pragma: no cover - defensive
        raise no_choices() from error
190
+
191
+
192
def serialize_tool_call(tool_call: ProviderToolCall) -> dict[str, Any]:
    """Render ``tool_call`` in the assistant-message ``tool_calls`` shape.

    An absent ``id`` attribute becomes ``None``. A falsy argument string is
    replaced by ``"{}"``.
    """

    call_id = getattr(tool_call, "id", None)
    name = tool_call.function.name
    arguments = tool_call.function.arguments or "{}"
    return {
        "id": call_id,
        "type": "function",
        "function": {"name": name, "arguments": arguments},
    }
204
+
205
+
206
def parse_tool_arguments(
    arguments_json: str | None,
    *,
    prompt_name: str,
    provider_payload: dict[str, Any] | None,
) -> dict[str, Any]:
    """Decode the JSON argument string of a provider tool call.

    ``None`` or an empty string yields ``{}``.

    Raises:
        PromptEvaluationError: In the tool phase, when the text is not valid
            JSON, decodes to something other than an object, or contains a
            non-string key.
    """

    def fail(message: str) -> PromptEvaluationError:
        return PromptEvaluationError(
            message,
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=provider_payload,
        )

    if not arguments_json:
        return {}

    try:
        decoded = json.loads(arguments_json)
    except json.JSONDecodeError as error:
        raise fail("Failed to decode tool call arguments.") from error

    if not isinstance(decoded, Mapping):
        raise fail("Tool call arguments must be a JSON object.")

    items = cast(Mapping[Any, Any], decoded).items()
    if any(not isinstance(key, str) for key, _ in items):
        raise fail("Tool call arguments must use string keys.")
    return dict(items)
244
+
245
+
246
def execute_tool_call(
    *,
    adapter_name: AdapterName,
    adapter: ProviderAdapter[Any],
    prompt: Prompt[Any],
    rendered_prompt: RenderedPrompt[Any] | None,
    tool_call: ProviderToolCall,
    tool_registry: Mapping[str, Tool[SupportsDataclass, SupportsToolResult]],
    bus: EventBus,
    session: SessionProtocol,
    prompt_name: str,
    provider_payload: dict[str, Any] | None,
    deadline: Deadline | None,
    format_publish_failures: Callable[[Sequence[HandlerFailure]], str],
    parse_arguments: ToolArgumentsParser,
    logger_override: StructuredLogger | None = None,
) -> tuple[ToolInvoked, ToolResult[SupportsToolResult]]:
    """Execute a provider tool call and publish the resulting event.

    Steps:
      1. Resolve ``tool_call.function.name`` in ``tool_registry``.
      2. Decode the arguments with ``parse_arguments`` and parse them into
         ``tool.params_type`` (extra keys forbidden).
      3. Invoke the tool's handler with a ``ToolContext``.
      4. Publish a ``ToolInvoked`` event on ``bus``.

    Validation failures (``TypeError``/``ValueError`` while parsing params,
    or ``ToolValidationError`` from the handler) and unexpected handler
    exceptions do not raise. They become a ``ToolResult`` with
    ``success=False``.

    If publishing reports failures, the session is rolled back to a snapshot
    taken just before publishing. The result's ``message`` is then replaced
    with ``format_publish_failures(...)``.

    Returns:
        ``(invocation, tool_result)``. ``tool_result`` is the same object
        passed as ``invocation.result``.

    Raises:
        PromptEvaluationError: (tool phase) for an unknown tool, a tool
            without a handler, a deadline that expired before the handler
            ran, or a ``DeadlineExceededError`` raised by the handler.
            Errors raised by ``parse_arguments`` propagate unchanged.
    """

    # Resolve the requested tool. Unknown tools and handler-less tools are
    # hard errors rather than failed ToolResults.
    function = tool_call.function
    tool_name = function.name
    tool = tool_registry.get(tool_name)
    if tool is None:
        raise PromptEvaluationError(
            f"Unknown tool '{tool_name}' requested by provider.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=provider_payload,
        )
    handler = tool.handler
    if handler is None:
        raise PromptEvaluationError(
            f"Tool '{tool_name}' does not have a registered handler.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=provider_payload,
        )

    # Decoded outside the try below, so decoding errors propagate as-is.
    arguments_mapping = parse_arguments(
        function.arguments,
        prompt_name=prompt_name,
        provider_payload=provider_payload,
    )

    call_id = getattr(tool_call, "id", None)
    log = (logger_override or logger).bind(
        adapter=adapter_name,
        prompt=prompt_name,
        tool=tool_name,
        call_id=call_id,
    )
    # Set either to the parsed params or to a _RejectedToolParams placeholder,
    # so the emitted event always has a params object.
    tool_params: SupportsDataclass | None = None
    tool_result: ToolResult[SupportsToolResult]
    try:
        try:
            parsed_params = parse(tool.params_type, arguments_mapping, extra="forbid")
        except (TypeError, ValueError) as error:
            # Keep the raw arguments for the event, then take the
            # validation-failure path below.
            tool_params = cast(
                SupportsDataclass,
                _RejectedToolParams(
                    raw_arguments=dict(arguments_mapping),
                    error=str(error),
                ),
            )
            raise ToolValidationError(str(error)) from error

        tool_params = parsed_params
        # Re-check the deadline immediately before invoking the handler.
        if deadline is not None and deadline.remaining() <= timedelta(0):
            _raise_tool_deadline_error(
                prompt_name=prompt_name, tool_name=tool_name, deadline=deadline
            )
        context = ToolContext(
            prompt=cast(PromptProtocol[Any], prompt),
            rendered_prompt=rendered_prompt,
            adapter=cast(ProviderAdapterProtocol[Any], adapter),
            session=session,
            event_bus=bus,
            deadline=deadline,
        )
        tool_result = handler(tool_params, context=context)
    except ToolValidationError as error:
        if tool_params is None:  # pragma: no cover - defensive
            tool_params = cast(
                SupportsDataclass,
                _RejectedToolParams(
                    raw_arguments=dict(arguments_mapping),
                    error=str(error),
                ),
            )
        log.warning(
            "Tool validation failed.",
            event="tool_validation_failed",
            context={"reason": str(error)},
        )
        tool_result = ToolResult(
            message=f"Tool validation failed: {error}",
            value=None,
            success=False,
        )
    except PromptEvaluationError:
        # Deadline errors raised above (and any evaluation error from the
        # handler) must abort the prompt, not become a failed ToolResult.
        raise
    except DeadlineExceededError as error:
        raise PromptEvaluationError(
            str(error) or f"Tool '{tool_name}' exceeded the deadline.",
            prompt_name=prompt_name,
            phase=PROMPT_EVALUATION_PHASE_TOOL,
            provider_payload=deadline_provider_payload(deadline),
        ) from error
    except Exception as error:  # propagate message via ToolResult
        log.exception(
            "Tool handler raised an unexpected exception.",
            event="tool_handler_exception",
            context={"provider_payload": provider_payload},
        )
        tool_result = ToolResult(
            message=f"Tool '{tool_name}' execution failed: {error}",
            value=None,
            success=False,
        )
    else:
        log.info(
            "Tool handler completed.",
            event="tool_handler_completed",
            context={
                "success": tool_result.success,
                "has_value": tool_result.value is not None,
            },
        )

    if tool_params is None:  # pragma: no cover - defensive
        raise RuntimeError("Tool parameters were not parsed.")

    # The snapshot is taken after the handler has run. A rollback therefore
    # presumably undoes only state applied by event handlers during publish.
    snapshot = session.snapshot()
    session_id = getattr(session, "session_id", None)
    tool_value = tool_result.value
    # Only dataclass results are attached as the event's ``value``.
    dataclass_value: SupportsDataclass | None = None
    if is_dataclass_instance(tool_value):
        dataclass_value = cast(SupportsDataclass, tool_value)  # pyright: ignore[reportUnnecessaryCast]

    invocation = ToolInvoked(
        prompt_name=prompt_name,
        adapter=adapter_name,
        name=tool_name,
        params=tool_params,
        result=cast(ToolResult[object], tool_result),
        session_id=session_id,
        created_at=datetime.now(UTC),
        value=dataclass_value,
        call_id=call_id,
        event_id=uuid4(),
    )
    publish_result = bus.publish(invocation)
    if not publish_result.ok:
        session.rollback(snapshot)
        log.warning(
            "Session rollback triggered after publish failure.",
            event="session_rollback_due_to_publish_failure",
        )
        failure_handlers = [
            getattr(failure.handler, "__qualname__", repr(failure.handler))
            for failure in publish_result.errors
        ]
        log.error(
            "Tool event publish failed.",
            event="tool_event_publish_failed",
            context={
                "failure_count": len(publish_result.errors),
                "failed_handlers": failure_handlers,
            },
        )
        # tool_result is the same object as invocation.result, so this
        # rewrite is also visible on the returned event.
        tool_result.message = format_publish_failures(publish_result.errors)
    else:
        log.debug(
            "Tool event published.",
            event="tool_event_published",
            context={"handler_count": publish_result.handled_count},
        )
    return invocation, tool_result
424
+
425
+
426
def build_json_schema_response_format(
    rendered: RenderedPrompt[Any], prompt_name: str
) -> dict[str, JSONValue] | None:
    """Return a ``json_schema`` response-format payload for ``rendered``.

    Returns ``None`` when the prompt declares no output type or no
    container. For array containers, the item schema is wrapped in an
    object whose single required property (``ARRAY_WRAPPER_KEY``) holds the
    array. Extra keys are forbidden unless the prompt allows them.
    """

    output_type = rendered.output_type
    container = rendered.container
    allow_extra_keys = bool(rendered.allow_extra_keys)
    if output_type is None or container is None:
        return None

    extra_mode: Literal["ignore", "forbid"] = (
        "ignore" if allow_extra_keys else "forbid"
    )
    item_schema = schema(output_type, extra=extra_mode)
    _ = item_schema.pop("title", None)

    if container != "array":
        schema_payload: dict[str, JSONValue] = item_schema
    else:
        schema_payload = {
            "type": "object",
            "properties": {
                ARRAY_WRAPPER_KEY: {"type": "array", "items": item_schema},
            },
            "required": [ARRAY_WRAPPER_KEY],
        }
        if not allow_extra_keys:
            schema_payload["additionalProperties"] = False

    return {
        "type": "json_schema",
        "json_schema": {
            "name": _schema_name(prompt_name),
            "schema": schema_payload,
        },
    }
466
+
467
+
468
def parse_schema_constrained_payload(
    payload: JSONValue, rendered: RenderedPrompt[Any]
) -> object:
    """Parse schema-constrained provider JSON into the prompt's output type.

    Delegates to ``parse_dataclass_payload`` with provider-oriented error
    messages and the prompt's ``allow_extra_keys`` setting.

    Raises:
        TypeError: If ``rendered`` declares no output type or no container.
    """

    output_type = rendered.output_type
    container = rendered.container
    extras_setting = rendered.allow_extra_keys
    if output_type is None or container is None:
        raise TypeError("Prompt does not declare structured output.")
    return parse_dataclass_payload(
        output_type,
        container,
        payload,
        allow_extra_keys=bool(extras_setting),
        object_error="Expected provider payload to be a JSON object.",
        array_error="Expected provider payload to be a JSON array.",
        array_item_error="Array item at index {index} is not an object.",
    )
489
+
490
+
491
+ def message_text_content(content: object) -> str:
492
+ """Extract text content from provider message payloads."""
493
+
494
+ if isinstance(content, str) or content is None:
495
+ return content or ""
496
+ if isinstance(content, Sequence) and not isinstance(
497
+ content, (str, bytes, bytearray)
498
+ ):
499
+ sequence_content = cast(
500
+ Sequence[object],
501
+ content,
502
+ )
503
+ fragments = [_content_part_text(part) for part in sequence_content]
504
+ return "".join(fragments)
505
+ return str(content)
506
+
507
+
508
def extract_parsed_content(message: object) -> object | None:
    """Return a structured payload the provider already decoded, if any.

    A non-``None`` ``message.parsed`` attribute wins. Otherwise, when
    ``message.content`` is a non-text sequence, the first non-``None``
    payload extracted from its parts is returned. If neither source yields
    a payload, the result is ``None``.
    """

    parsed = getattr(message, "parsed", None)
    if parsed is not None:
        return parsed

    content = getattr(message, "content", None)
    if isinstance(content, (str, bytes, bytearray)) or not isinstance(
        content, Sequence
    ):
        return None
    candidates = (
        _parsed_payload_from_part(part)
        for part in cast(Sequence[object], content)
    )
    return next((payload for payload in candidates if payload is not None), None)
528
+
529
+
530
+ def _schema_name(prompt_name: str) -> str:
531
+ sanitized = re.sub(r"[^a-zA-Z0-9_-]+", "_", prompt_name.strip())
532
+ cleaned = sanitized.strip("_") or "prompt"
533
+ return f"{cleaned}_schema"
534
+
535
+
536
+ def _content_part_text(part: object) -> str:
537
+ if part is None:
538
+ return ""
539
+ if isinstance(part, Mapping):
540
+ mapping_part = cast(Mapping[str, object], part)
541
+ part_type = mapping_part.get("type")
542
+ if part_type in {"output_text", "text"}:
543
+ text_value = mapping_part.get("text")
544
+ if isinstance(text_value, str):
545
+ return text_value
546
+ return ""
547
+ part_type = getattr(part, "type", None)
548
+ if part_type in {"output_text", "text"}:
549
+ text_value = getattr(part, "text", None)
550
+ if isinstance(text_value, str):
551
+ return text_value
552
+ return ""
553
+
554
+
555
+ def _parsed_payload_from_part(part: object) -> object | None:
556
+ if isinstance(part, Mapping):
557
+ mapping_part = cast(Mapping[str, object], part)
558
+ if mapping_part.get("type") == "output_json":
559
+ return mapping_part.get("json")
560
+ return None
561
+ part_type = getattr(part, "type", None)
562
+ if part_type == "output_json":
563
+ return getattr(part, "json", None)
564
+ return None
565
+
566
+
567
+ def _mapping_to_str_dict(mapping: Mapping[Any, Any]) -> dict[str, Any] | None:
568
+ str_mapping: dict[str, Any] = {}
569
+ for key, value in mapping.items():
570
+ if not isinstance(key, str):
571
+ return None
572
+ str_mapping[key] = value
573
+ return str_mapping
574
+
575
+
576
+ __all__ = [
577
+ "LITELLM_ADAPTER_NAME",
578
+ "OPENAI_ADAPTER_NAME",
579
+ "AdapterName",
580
+ "ChoiceSelector",
581
+ "ConversationRequest",
582
+ "ConversationRunner",
583
+ "ProviderChoice",
584
+ "ProviderCompletionCallable",
585
+ "ProviderCompletionResponse",
586
+ "ProviderFunctionCall",
587
+ "ProviderMessage",
588
+ "ProviderToolCall",
589
+ "ToolArgumentsParser",
590
+ "ToolChoice",
591
+ "ToolMessageSerializer",
592
+ "_content_part_text",
593
+ "_parsed_payload_from_part",
594
+ "build_json_schema_response_format",
595
+ "deadline_provider_payload",
596
+ "execute_tool_call",
597
+ "extract_parsed_content",
598
+ "extract_payload",
599
+ "first_choice",
600
+ "format_publish_failures",
601
+ "message_text_content",
602
+ "parse_schema_constrained_payload",
603
+ "parse_tool_arguments",
604
+ "run_conversation",
605
+ "serialize_tool_call",
606
+ "tool_to_spec",
607
+ ]
608
+
609
+
610
# Module-level type variable for prompt output types. Note that
# ``ConversationRunner`` declares its own ``OutputT`` through PEP 695
# class syntax.
OutputT = TypeVar("OutputT")
611
+
612
+
613
# The callable receives four positional arguments, in order:
#   1. the message list;
#   2. the tool specifications;
#   3. the tool choice directive (the runner passes ``None`` when it has no
#      tool specs);
#   4. the optional response format.
# It returns the raw provider response.
ConversationRequest = Callable[
    [
        list[dict[str, Any]],
        Sequence[Mapping[str, Any]],
        ToolChoice | None,
        Mapping[str, Any] | None,
    ],
    object,
]
"""Provider invocation hook fed the assembled request payloads."""
623
+
624
+
625
# Given the raw provider response, return the choice whose message the
# runner should process next.
ChoiceSelector = Callable[[object], ProviderChoice]
"""Hook that picks the relevant choice out of a provider response."""
627
+
628
+
629
class ToolMessageSerializer(Protocol):
    """Callable protocol that turns a tool result into tool-message content.

    ``ConversationRunner`` calls it with the ``ToolResult`` alone and stores
    the return value as the ``content`` of a ``"tool"`` role message.
    ``payload`` is an optional keyword-only argument whose meaning is
    defined by each implementation.
    """

    def __call__(
        self,
        result: ToolResult[SupportsToolResult],
        *,
        payload: object | None = ...,
    ) -> object: ...
636
+
637
+
638
+ @dataclass(slots=True)
639
+ class ConversationRunner[OutputT]:
640
+ """Coordinate a conversational exchange with a provider."""
641
+
642
+ adapter_name: AdapterName
643
+ adapter: ProviderAdapter[OutputT]
644
+ prompt: Prompt[OutputT]
645
+ prompt_name: str
646
+ rendered: RenderedPrompt[OutputT]
647
+ render_inputs: tuple[SupportsDataclass, ...]
648
+ initial_messages: list[dict[str, Any]]
649
+ parse_output: bool
650
+ bus: EventBus
651
+ session: SessionProtocol
652
+ tool_choice: ToolChoice
653
+ response_format: Mapping[str, Any] | None
654
+ require_structured_output_text: bool
655
+ call_provider: ConversationRequest
656
+ select_choice: ChoiceSelector
657
+ serialize_tool_message_fn: ToolMessageSerializer
658
+ format_publish_failures: Callable[[Sequence[HandlerFailure]], str] = (
659
+ format_publish_failures
660
+ )
661
+ parse_arguments: ToolArgumentsParser = parse_tool_arguments
662
+ logger_override: StructuredLogger | None = None
663
+ deadline: Deadline | None = None
664
+ _log: StructuredLogger = field(init=False)
665
+ _messages: list[dict[str, Any]] = field(init=False)
666
+ _tool_specs: list[dict[str, Any]] = field(init=False)
667
+ _tool_registry: dict[str, Tool[SupportsDataclass, SupportsToolResult]] = field(
668
+ init=False
669
+ )
670
+ _tool_events: list[ToolInvoked] = field(init=False)
671
+ _tool_message_records: list[
672
+ tuple[ToolResult[SupportsToolResult], dict[str, Any]]
673
+ ] = field(init=False)
674
+ _provider_payload: dict[str, Any] | None = field(init=False, default=None)
675
+ _next_tool_choice: ToolChoice = field(init=False)
676
+ _should_parse_structured_output: bool = field(init=False)
677
+
678
+ def _raise_deadline_error(
679
+ self, message: str, *, phase: PromptEvaluationPhase
680
+ ) -> NoReturn:
681
+ raise PromptEvaluationError(
682
+ message,
683
+ prompt_name=self.prompt_name,
684
+ phase=phase,
685
+ provider_payload=deadline_provider_payload(self.deadline),
686
+ )
687
+
688
+ def _ensure_deadline_remaining(
689
+ self, message: str, *, phase: PromptEvaluationPhase
690
+ ) -> None:
691
+ if self.deadline is None:
692
+ return
693
+ if self.deadline.remaining() <= timedelta(0):
694
+ self._raise_deadline_error(message, phase=phase)
695
+
696
+ def run(self) -> PromptResponse[OutputT]:
697
+ """Execute the conversation loop and return the final response."""
698
+
699
+ self._prepare_payload()
700
+
701
+ while True:
702
+ self._ensure_deadline_remaining(
703
+ "Deadline expired before provider request.",
704
+ phase=PROMPT_EVALUATION_PHASE_REQUEST,
705
+ )
706
+ response = self.call_provider(
707
+ self._messages,
708
+ self._tool_specs,
709
+ self._next_tool_choice if self._tool_specs else None,
710
+ self.response_format,
711
+ )
712
+
713
+ self._provider_payload = extract_payload(response)
714
+ choice = self.select_choice(response)
715
+ message = getattr(choice, "message", None)
716
+ if message is None:
717
+ raise PromptEvaluationError(
718
+ "Provider response did not include a message payload.",
719
+ prompt_name=self.prompt_name,
720
+ phase=PROMPT_EVALUATION_PHASE_RESPONSE,
721
+ provider_payload=self._provider_payload,
722
+ )
723
+
724
+ tool_calls_sequence = getattr(message, "tool_calls", None)
725
+ tool_calls = list(tool_calls_sequence or [])
726
+
727
+ if not tool_calls:
728
+ return self._finalize_response(message)
729
+
730
+ self._handle_tool_calls(message, tool_calls)
731
+
732
+ def _prepare_payload(self) -> None:
733
+ """Initialize execution state prior to the provider loop."""
734
+
735
+ self._messages = list(self.initial_messages)
736
+ self._log = (self.logger_override or logger).bind(
737
+ adapter=self.adapter_name,
738
+ prompt=self.prompt_name,
739
+ )
740
+ self._log.info(
741
+ "Prompt execution started.",
742
+ event="prompt_execution_started",
743
+ context={
744
+ "tool_count": len(self.rendered.tools),
745
+ "parse_output": self.parse_output,
746
+ },
747
+ )
748
+
749
+ tools = list(self.rendered.tools)
750
+ self._tool_specs = [tool_to_spec(tool) for tool in tools]
751
+ self._tool_registry = {tool.name: tool for tool in tools}
752
+ self._tool_events = []
753
+ self._tool_message_records = []
754
+ self._provider_payload = None
755
+ self._next_tool_choice = self.tool_choice
756
+ self._should_parse_structured_output = (
757
+ self.parse_output
758
+ and self.rendered.output_type is not None
759
+ and self.rendered.container is not None
760
+ )
761
+
762
+ publish_result = self.bus.publish(
763
+ PromptRendered(
764
+ prompt_ns=self.prompt.ns,
765
+ prompt_key=self.prompt.key,
766
+ prompt_name=self.prompt.name,
767
+ adapter=self.adapter_name,
768
+ session_id=getattr(self.session, "session_id", None),
769
+ render_inputs=self.render_inputs,
770
+ rendered_prompt=self.rendered.text,
771
+ created_at=datetime.now(UTC),
772
+ event_id=uuid4(),
773
+ )
774
+ )
775
+ if not publish_result.ok:
776
+ failure_handlers = [
777
+ getattr(failure.handler, "__qualname__", repr(failure.handler))
778
+ for failure in publish_result.errors
779
+ ]
780
+ self._log.error(
781
+ "Prompt rendered publish failed.",
782
+ event="prompt_rendered_publish_failed",
783
+ context={
784
+ "failure_count": len(publish_result.errors),
785
+ "failed_handlers": failure_handlers,
786
+ },
787
+ )
788
+ else:
789
+ self._log.debug(
790
+ "Prompt rendered event published.",
791
+ event="prompt_rendered_published",
792
+ context={"handler_count": publish_result.handled_count},
793
+ )
794
+
795
def _handle_tool_calls(
    self,
    message: object,
    tool_calls: Sequence[ProviderToolCall],
) -> None:
    """Run every tool call the provider requested and record the transcript.

    The assistant turn that requested the tools is appended to the running
    message list first. Then each call is executed in order and followed by
    one ``"tool"`` message carrying its serialized result. Before each call,
    the remaining deadline is checked.

    Args:
        message: The provider's assistant message. Only its ``content``
            attribute is read, when present.
        tool_calls: Tool calls emitted by the provider, executed in order.
    """

    # Serialize the requested calls up front so the assistant turn can be
    # recorded before any tool actually runs.
    serialized_calls = [serialize_tool_call(call) for call in tool_calls]
    assistant_turn = {
        "role": "assistant",
        "content": getattr(message, "content", None) or "",
        "tool_calls": serialized_calls,
    }
    self._messages.append(assistant_turn)

    self._log.debug(
        "Processing tool calls.",
        event="prompt_tool_calls_detected",
        context={"count": len(tool_calls)},
    )

    for call in tool_calls:
        requested_name = getattr(call.function, "name", "tool")
        # Abort before running this tool if earlier work exhausted the deadline.
        self._ensure_deadline_remaining(
            f"Deadline expired before executing tool '{requested_name}'.",
            phase=PROMPT_EVALUATION_PHASE_TOOL,
        )
        invocation, result = execute_tool_call(
            adapter_name=self.adapter_name,
            adapter=self.adapter,
            prompt=self.prompt,
            rendered_prompt=self.rendered,
            tool_call=call,
            tool_registry=self._tool_registry,
            bus=self.bus,
            session=self.session,
            prompt_name=self.prompt_name,
            provider_payload=self._provider_payload,
            deadline=self.deadline,
            format_publish_failures=self.format_publish_failures,
            parse_arguments=self.parse_arguments,
            logger_override=self.logger_override,
        )
        self._tool_events.append(invocation)

        reply = {
            "role": "tool",
            "tool_call_id": getattr(call, "id", None),
            "content": self.serialize_tool_message_fn(result),
        }
        self._messages.append(reply)
        # Keep each result paired with its message dict. _finalize_response
        # may later rewrite the last tool message's content in place.
        self._tool_message_records.append((result, reply))

    # After tools have run, a tool choice of type "function" is downgraded to
    # "auto". Presumably this stops the next request from forcing the same
    # function again (confirm against where _next_tool_choice is consumed).
    pending_choice = self._next_tool_choice
    if isinstance(pending_choice, Mapping):
        choice_mapping = cast(Mapping[str, object], pending_choice)
        if choice_mapping.get("type") == "function":
            self._next_tool_choice = "auto"
853
+
854
def _finalize_response(self, message: object) -> PromptResponse[OutputT]:
    """Build the final ``PromptResponse``, publish it on the bus, and return it.

    Steps:
      1. Check the deadline.
      2. Extract the text content of ``message``.
      3. When structured output is expected, parse it. A provider-parsed
         payload is preferred; otherwise the response text is parsed.
      4. Publish a ``PromptExecuted`` event.

    When structured output is produced, the returned response carries no
    text. If the most recent tool call succeeded, its recorded tool message
    is re-serialized with the output as payload.

    Args:
        message: The provider's final assistant message.

    Returns:
        The assembled ``PromptResponse`` for this conversation.

    Raises:
        PromptEvaluationError: When the structured output cannot be parsed,
            or when text is required for structured output but the response
            has none. Deadline expiry and bus-handler failures raise whatever
            ``_ensure_deadline_remaining`` and ``raise_if_errors`` raise.
    """

    self._ensure_deadline_remaining(
        "Deadline expired while finalizing provider response.",
        phase=PROMPT_EVALUATION_PHASE_RESPONSE,
    )
    final_text = message_text_content(getattr(message, "content", None))
    output: OutputT | None = None

    if self._should_parse_structured_output:
        parsed_payload = extract_parsed_content(message)
        # Guard: nothing was pre-parsed, there is no text to fall back on,
        # and the caller insists on text.
        if (
            parsed_payload is None
            and not final_text
            and self.require_structured_output_text
        ):
            raise PromptEvaluationError(
                "Provider response did not include structured output.",
                prompt_name=self.prompt_name,
                phase=PROMPT_EVALUATION_PHASE_RESPONSE,
                provider_payload=self._provider_payload,
            )
        if parsed_payload is not None:
            try:
                schema_value = parse_schema_constrained_payload(
                    cast(JSONValue, parsed_payload), self.rendered
                )
            except (TypeError, ValueError) as error:
                raise PromptEvaluationError(
                    str(error),
                    prompt_name=self.prompt_name,
                    phase=PROMPT_EVALUATION_PHASE_RESPONSE,
                    provider_payload=self._provider_payload,
                ) from error
            output = cast(OutputT, schema_value)
        else:
            try:
                output = parse_structured_output(final_text or "", self.rendered)
            except OutputParseError as error:
                raise PromptEvaluationError(
                    error.message,
                    prompt_name=self.prompt_name,
                    phase=PROMPT_EVALUATION_PHASE_RESPONSE,
                    provider_payload=self._provider_payload,
                ) from error

    # Structured output replaces the raw text in the response.
    text_value: str | None = None if output is not None else (final_text or None)

    if output is not None and self._tool_message_records:
        last_result, last_message = self._tool_message_records[-1]
        if last_result.success:
            # Mutates the dict already stored in self._messages.
            last_message["content"] = self.serialize_tool_message_fn(
                last_result, payload=output
            )

    response_payload = PromptResponse(
        prompt_name=self.prompt_name,
        text=text_value,
        output=output,
        tool_results=tuple(self._tool_events),
        provider_payload=self._provider_payload,
    )
    prompt_value: SupportsDataclass | None = (
        cast(SupportsDataclass, output)  # pyright: ignore[reportUnnecessaryCast]
        if is_dataclass_instance(output)
        else None
    )

    executed_event = PromptExecuted(
        prompt_name=self.prompt_name,
        adapter=self.adapter_name,
        result=cast(PromptResponse[object], response_payload),
        session_id=getattr(self.session, "session_id", None),
        created_at=datetime.now(UTC),
        value=prompt_value,
        event_id=uuid4(),
    )
    publish_result = self.bus.publish(executed_event)

    if not publish_result.ok:
        failed_handlers = [
            getattr(failure.handler, "__qualname__", repr(failure.handler))
            for failure in publish_result.errors
        ]
        self._log.error(
            "Prompt execution publish failed.",
            event="prompt_execution_publish_failed",
            context={
                "failure_count": len(publish_result.errors),
                "failed_handlers": failed_handlers,
            },
        )
        publish_result.raise_if_errors()

    self._log.info(
        "Prompt execution completed.",
        event="prompt_execution_succeeded",
        context={
            "tool_count": len(self._tool_events),
            "has_output": output is not None,
            "text_length": len(text_value or ""),
            "structured_output": self._should_parse_structured_output,
            "handler_count": publish_result.handled_count,
        },
    )
    return response_payload
963
+
964
+
965
+ def run_conversation[
966
+ OutputT,
967
+ ](
968
+ *,
969
+ adapter_name: AdapterName,
970
+ adapter: ProviderAdapter[OutputT],
971
+ prompt: Prompt[OutputT],
972
+ prompt_name: str,
973
+ rendered: RenderedPrompt[OutputT],
974
+ render_inputs: tuple[SupportsDataclass, ...],
975
+ initial_messages: list[dict[str, Any]],
976
+ parse_output: bool,
977
+ bus: EventBus,
978
+ session: SessionProtocol,
979
+ tool_choice: ToolChoice,
980
+ response_format: Mapping[str, Any] | None,
981
+ require_structured_output_text: bool,
982
+ call_provider: ConversationRequest,
983
+ select_choice: ChoiceSelector,
984
+ serialize_tool_message_fn: ToolMessageSerializer,
985
+ format_publish_failures: Callable[
986
+ [Sequence[HandlerFailure]], str
987
+ ] = format_publish_failures,
988
+ parse_arguments: ToolArgumentsParser = parse_tool_arguments,
989
+ logger_override: StructuredLogger | None = None,
990
+ deadline: Deadline | None = None,
991
+ ) -> PromptResponse[OutputT]:
992
+ """Execute a conversational exchange with a provider and return the result."""
993
+
994
+ effective_deadline = deadline or rendered.deadline
995
+ rendered_with_deadline = rendered
996
+ if effective_deadline is not None and rendered.deadline is not effective_deadline:
997
+ rendered_with_deadline = replace(rendered, deadline=effective_deadline)
998
+
999
+ runner = ConversationRunner[OutputT](
1000
+ adapter_name=adapter_name,
1001
+ adapter=adapter,
1002
+ prompt=prompt,
1003
+ prompt_name=prompt_name,
1004
+ rendered=rendered_with_deadline,
1005
+ render_inputs=render_inputs,
1006
+ initial_messages=initial_messages,
1007
+ parse_output=parse_output,
1008
+ bus=bus,
1009
+ session=session,
1010
+ tool_choice=tool_choice,
1011
+ response_format=response_format,
1012
+ require_structured_output_text=require_structured_output_text,
1013
+ call_provider=call_provider,
1014
+ select_choice=select_choice,
1015
+ serialize_tool_message_fn=serialize_tool_message_fn,
1016
+ format_publish_failures=format_publish_failures,
1017
+ parse_arguments=parse_arguments,
1018
+ logger_override=logger_override,
1019
+ deadline=effective_deadline,
1020
+ )
1021
+ return runner.run()