weakincentives 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of weakincentives might be problematic. Click here for more details.
- weakincentives/__init__.py +26 -2
- weakincentives/adapters/__init__.py +6 -5
- weakincentives/adapters/core.py +7 -17
- weakincentives/adapters/litellm.py +594 -0
- weakincentives/adapters/openai.py +286 -57
- weakincentives/events.py +103 -0
- weakincentives/examples/__init__.py +67 -0
- weakincentives/examples/code_review_prompt.py +118 -0
- weakincentives/examples/code_review_session.py +171 -0
- weakincentives/examples/code_review_tools.py +376 -0
- weakincentives/{prompts → prompt}/__init__.py +6 -8
- weakincentives/{prompts → prompt}/_types.py +1 -1
- weakincentives/{prompts/text.py → prompt/markdown.py} +19 -9
- weakincentives/{prompts → prompt}/prompt.py +216 -66
- weakincentives/{prompts → prompt}/response_format.py +9 -6
- weakincentives/{prompts → prompt}/section.py +25 -4
- weakincentives/{prompts/structured.py → prompt/structured_output.py} +16 -5
- weakincentives/{prompts → prompt}/tool.py +6 -6
- weakincentives/prompt/versioning.py +144 -0
- weakincentives/serde/__init__.py +0 -14
- weakincentives/serde/dataclass_serde.py +3 -17
- weakincentives/session/__init__.py +31 -0
- weakincentives/session/reducers.py +60 -0
- weakincentives/session/selectors.py +45 -0
- weakincentives/session/session.py +168 -0
- weakincentives/tools/__init__.py +69 -0
- weakincentives/tools/errors.py +22 -0
- weakincentives/tools/planning.py +538 -0
- weakincentives/tools/vfs.py +590 -0
- weakincentives-0.3.0.dist-info/METADATA +231 -0
- weakincentives-0.3.0.dist-info/RECORD +35 -0
- weakincentives-0.2.0.dist-info/METADATA +0 -173
- weakincentives-0.2.0.dist-info/RECORD +0 -20
- /weakincentives/{prompts → prompt}/errors.py +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,19 +15,25 @@
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
17
|
import json
|
|
18
|
+
import re
|
|
18
19
|
from collections.abc import Mapping, Sequence
|
|
19
20
|
from importlib import import_module
|
|
20
|
-
from typing import Any, Protocol, cast
|
|
21
|
-
|
|
22
|
-
from ..
|
|
23
|
-
from ..
|
|
24
|
-
from ..
|
|
25
|
-
from ..
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
21
|
+
from typing import Any, Final, Literal, Protocol, cast
|
|
22
|
+
|
|
23
|
+
from ..events import EventBus, PromptExecuted, ToolInvoked
|
|
24
|
+
from ..prompt._types import SupportsDataclass
|
|
25
|
+
from ..prompt.prompt import Prompt, RenderedPrompt
|
|
26
|
+
from ..prompt.structured_output import (
|
|
27
|
+
ARRAY_WRAPPER_KEY,
|
|
28
|
+
OutputParseError,
|
|
29
|
+
parse_structured_output,
|
|
30
|
+
)
|
|
31
|
+
from ..prompt.tool import Tool, ToolResult
|
|
32
|
+
from ..serde import parse, schema
|
|
33
|
+
from ..tools.errors import ToolValidationError
|
|
34
|
+
from .core import PromptEvaluationError, PromptResponse
|
|
29
35
|
|
|
30
|
-
_ERROR_MESSAGE = (
|
|
36
|
+
_ERROR_MESSAGE: Final[str] = (
|
|
31
37
|
"OpenAI support requires the optional 'openai' dependency. "
|
|
32
38
|
"Install it with `uv sync --extra openai` or `pip install weakincentives[openai]`."
|
|
33
39
|
)
|
|
@@ -44,7 +50,7 @@ class _ToolCall(Protocol):
|
|
|
44
50
|
|
|
45
51
|
|
|
46
52
|
class _Message(Protocol):
|
|
47
|
-
content: str | None
|
|
53
|
+
content: str | Sequence[object] | None
|
|
48
54
|
tool_calls: Sequence[_ToolCall] | None
|
|
49
55
|
|
|
50
56
|
|
|
@@ -96,6 +102,10 @@ def create_openai_client(**kwargs: object) -> _OpenAIProtocol:
|
|
|
96
102
|
return openai_module.OpenAI(**kwargs)
|
|
97
103
|
|
|
98
104
|
|
|
105
|
+
ToolChoice = Literal["auto"] | Mapping[str, Any] | None
|
|
106
|
+
"""Supported tool choice directives for provider APIs."""
|
|
107
|
+
|
|
108
|
+
|
|
99
109
|
class OpenAIAdapter:
|
|
100
110
|
"""Adapter that evaluates prompts against OpenAI's Responses API."""
|
|
101
111
|
|
|
@@ -103,7 +113,8 @@ class OpenAIAdapter:
|
|
|
103
113
|
self,
|
|
104
114
|
*,
|
|
105
115
|
model: str,
|
|
106
|
-
tool_choice:
|
|
116
|
+
tool_choice: ToolChoice = "auto",
|
|
117
|
+
use_native_response_format: bool = True,
|
|
107
118
|
client: _OpenAIProtocol | None = None,
|
|
108
119
|
client_factory: _OpenAIClientFactory | None = None,
|
|
109
120
|
client_kwargs: Mapping[str, object] | None = None,
|
|
@@ -123,25 +134,56 @@ class OpenAIAdapter:
|
|
|
123
134
|
|
|
124
135
|
self._client = client
|
|
125
136
|
self._model = model
|
|
126
|
-
self._tool_choice = tool_choice
|
|
137
|
+
self._tool_choice: ToolChoice = tool_choice
|
|
138
|
+
self._use_native_response_format = use_native_response_format
|
|
127
139
|
|
|
128
140
|
def evaluate[OutputT](
|
|
129
141
|
self,
|
|
130
142
|
prompt: Prompt[OutputT],
|
|
131
143
|
*params: SupportsDataclass,
|
|
132
144
|
parse_output: bool = True,
|
|
145
|
+
bus: EventBus,
|
|
133
146
|
) -> PromptResponse[OutputT]:
|
|
134
147
|
prompt_name = prompt.name or prompt.__class__.__name__
|
|
135
|
-
|
|
148
|
+
|
|
149
|
+
has_structured_output = (
|
|
150
|
+
getattr(prompt, "_output_type", None) is not None
|
|
151
|
+
and getattr(prompt, "_output_container", None) is not None
|
|
152
|
+
)
|
|
153
|
+
should_disable_instructions = (
|
|
154
|
+
parse_output
|
|
155
|
+
and has_structured_output
|
|
156
|
+
and self._use_native_response_format
|
|
157
|
+
and getattr(prompt, "inject_output_instructions", False)
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
if should_disable_instructions:
|
|
161
|
+
rendered = prompt.render(
|
|
162
|
+
*params,
|
|
163
|
+
inject_output_instructions=False,
|
|
164
|
+
) # type: ignore[reportArgumentType]
|
|
165
|
+
else:
|
|
166
|
+
rendered = prompt.render(*params) # type: ignore[reportArgumentType]
|
|
136
167
|
messages: list[dict[str, Any]] = [
|
|
137
168
|
{"role": "system", "content": rendered.text},
|
|
138
169
|
]
|
|
139
170
|
|
|
171
|
+
should_parse_structured_output = (
|
|
172
|
+
parse_output
|
|
173
|
+
and rendered.output_type is not None
|
|
174
|
+
and rendered.container is not None
|
|
175
|
+
)
|
|
176
|
+
response_format: dict[str, Any] | None = None
|
|
177
|
+
if should_parse_structured_output and self._use_native_response_format:
|
|
178
|
+
response_format = _build_json_schema_response_format(rendered, prompt_name)
|
|
179
|
+
|
|
140
180
|
tools = list(rendered.tools)
|
|
141
181
|
tool_specs = [_tool_to_openai_spec(tool) for tool in tools]
|
|
142
182
|
tool_registry = {tool.name: tool for tool in tools}
|
|
143
|
-
|
|
183
|
+
tool_events: list[ToolInvoked] = []
|
|
144
184
|
provider_payload: dict[str, Any] | None = None
|
|
185
|
+
# Allow forcing a specific tool once, then fall back to provider defaults.
|
|
186
|
+
next_tool_choice: ToolChoice = self._tool_choice
|
|
145
187
|
|
|
146
188
|
while True:
|
|
147
189
|
request_payload: dict[str, Any] = {
|
|
@@ -150,8 +192,10 @@ class OpenAIAdapter:
|
|
|
150
192
|
}
|
|
151
193
|
if tool_specs:
|
|
152
194
|
request_payload["tools"] = tool_specs
|
|
153
|
-
if
|
|
154
|
-
request_payload["tool_choice"] =
|
|
195
|
+
if next_tool_choice is not None:
|
|
196
|
+
request_payload["tool_choice"] = next_tool_choice
|
|
197
|
+
if response_format is not None:
|
|
198
|
+
request_payload["response_format"] = response_format
|
|
155
199
|
|
|
156
200
|
try:
|
|
157
201
|
response = self._client.chat.completions.create(**request_payload)
|
|
@@ -159,7 +203,7 @@ class OpenAIAdapter:
|
|
|
159
203
|
raise PromptEvaluationError(
|
|
160
204
|
"OpenAI request failed.",
|
|
161
205
|
prompt_name=prompt_name,
|
|
162
|
-
|
|
206
|
+
phase="request",
|
|
163
207
|
) from error
|
|
164
208
|
|
|
165
209
|
provider_payload = _extract_payload(response)
|
|
@@ -168,33 +212,55 @@ class OpenAIAdapter:
|
|
|
168
212
|
tool_calls = list(message.tool_calls or [])
|
|
169
213
|
|
|
170
214
|
if not tool_calls:
|
|
171
|
-
final_text = message.content
|
|
215
|
+
final_text = _message_text_content(message.content)
|
|
172
216
|
output: OutputT | None = None
|
|
173
217
|
text_value: str | None = final_text or None
|
|
174
218
|
|
|
175
|
-
if
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
219
|
+
if should_parse_structured_output:
|
|
220
|
+
parsed_payload = _extract_parsed_content(message)
|
|
221
|
+
if parsed_payload is not None:
|
|
222
|
+
try:
|
|
223
|
+
output = cast(
|
|
224
|
+
OutputT,
|
|
225
|
+
_parse_schema_constrained_payload(
|
|
226
|
+
parsed_payload, rendered
|
|
227
|
+
),
|
|
228
|
+
)
|
|
229
|
+
except (TypeError, ValueError) as error:
|
|
230
|
+
raise PromptEvaluationError(
|
|
231
|
+
str(error),
|
|
232
|
+
prompt_name=prompt_name,
|
|
233
|
+
phase="response",
|
|
234
|
+
provider_payload=provider_payload,
|
|
235
|
+
) from error
|
|
236
|
+
else:
|
|
237
|
+
try:
|
|
238
|
+
output = parse_structured_output(final_text, rendered)
|
|
239
|
+
except OutputParseError as error:
|
|
240
|
+
raise PromptEvaluationError(
|
|
241
|
+
error.message,
|
|
242
|
+
prompt_name=prompt_name,
|
|
243
|
+
phase="response",
|
|
244
|
+
provider_payload=provider_payload,
|
|
245
|
+
) from error
|
|
246
|
+
if output is not None:
|
|
247
|
+
text_value = None
|
|
248
|
+
|
|
249
|
+
response = PromptResponse(
|
|
192
250
|
prompt_name=prompt_name,
|
|
193
251
|
text=text_value,
|
|
194
252
|
output=output,
|
|
195
|
-
tool_results=tuple(
|
|
253
|
+
tool_results=tuple(tool_events),
|
|
196
254
|
provider_payload=provider_payload,
|
|
197
255
|
)
|
|
256
|
+
bus.publish(
|
|
257
|
+
PromptExecuted(
|
|
258
|
+
prompt_name=prompt_name,
|
|
259
|
+
adapter="openai",
|
|
260
|
+
result=cast(PromptResponse[object], response),
|
|
261
|
+
)
|
|
262
|
+
)
|
|
263
|
+
return response
|
|
198
264
|
|
|
199
265
|
assistant_tool_calls = [_serialize_tool_call(call) for call in tool_calls]
|
|
200
266
|
messages.append(
|
|
@@ -213,14 +279,14 @@ class OpenAIAdapter:
|
|
|
213
279
|
raise PromptEvaluationError(
|
|
214
280
|
f"Unknown tool '{tool_name}' requested by provider.",
|
|
215
281
|
prompt_name=prompt_name,
|
|
216
|
-
|
|
282
|
+
phase="tool",
|
|
217
283
|
provider_payload=provider_payload,
|
|
218
284
|
)
|
|
219
285
|
if tool.handler is None:
|
|
220
286
|
raise PromptEvaluationError(
|
|
221
287
|
f"Tool '{tool_name}' does not have a registered handler.",
|
|
222
288
|
prompt_name=prompt_name,
|
|
223
|
-
|
|
289
|
+
phase="tool",
|
|
224
290
|
provider_payload=provider_payload,
|
|
225
291
|
)
|
|
226
292
|
|
|
@@ -240,42 +306,50 @@ class OpenAIAdapter:
|
|
|
240
306
|
raise PromptEvaluationError(
|
|
241
307
|
f"Failed to parse params for tool '{tool_name}'.",
|
|
242
308
|
prompt_name=prompt_name,
|
|
243
|
-
|
|
309
|
+
phase="tool",
|
|
244
310
|
provider_payload=provider_payload,
|
|
245
311
|
) from error
|
|
246
312
|
|
|
247
313
|
try:
|
|
248
314
|
tool_result = tool.handler(tool_params)
|
|
315
|
+
except ToolValidationError as error:
|
|
316
|
+
tool_result = ToolResult(
|
|
317
|
+
message=f"Tool validation failed: {error}",
|
|
318
|
+
value=tool_params,
|
|
319
|
+
)
|
|
249
320
|
except Exception as error: # pragma: no cover - handler bug
|
|
250
321
|
raise PromptEvaluationError(
|
|
251
322
|
f"Tool '{tool_name}' raised an exception.",
|
|
252
323
|
prompt_name=prompt_name,
|
|
253
|
-
|
|
324
|
+
phase="tool",
|
|
254
325
|
provider_payload=provider_payload,
|
|
255
326
|
) from error
|
|
256
327
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
)
|
|
328
|
+
invocation = ToolInvoked(
|
|
329
|
+
prompt_name=prompt_name,
|
|
330
|
+
adapter="openai",
|
|
331
|
+
name=tool_name,
|
|
332
|
+
params=tool_params,
|
|
333
|
+
result=cast(ToolResult[object], tool_result),
|
|
334
|
+
call_id=getattr(tool_call, "id", None),
|
|
264
335
|
)
|
|
336
|
+
tool_events.append(invocation)
|
|
337
|
+
bus.publish(invocation)
|
|
265
338
|
|
|
266
|
-
payload = dump(tool_result.payload, exclude_none=True)
|
|
267
|
-
tool_content = {
|
|
268
|
-
"message": tool_result.message,
|
|
269
|
-
"payload": payload,
|
|
270
|
-
}
|
|
271
339
|
messages.append(
|
|
272
340
|
{
|
|
273
341
|
"role": "tool",
|
|
274
342
|
"tool_call_id": getattr(tool_call, "id", None),
|
|
275
|
-
"content":
|
|
343
|
+
"content": tool_result.message,
|
|
276
344
|
}
|
|
277
345
|
)
|
|
278
346
|
|
|
347
|
+
if isinstance(next_tool_choice, Mapping):
|
|
348
|
+
tool_choice_mapping = cast(Mapping[str, object], next_tool_choice)
|
|
349
|
+
if tool_choice_mapping.get("type") == "function":
|
|
350
|
+
# Relax forced single-function choice after the first call.
|
|
351
|
+
next_tool_choice = "auto"
|
|
352
|
+
|
|
279
353
|
|
|
280
354
|
def _tool_to_openai_spec(tool: Tool[Any, Any]) -> dict[str, Any]:
|
|
281
355
|
parameters_schema = schema(tool.params_type, extra="forbid")
|
|
@@ -306,6 +380,161 @@ def _extract_payload(response: _CompletionResponse) -> dict[str, Any] | None:
|
|
|
306
380
|
return None
|
|
307
381
|
|
|
308
382
|
|
|
383
|
+
def _build_json_schema_response_format(
|
|
384
|
+
rendered: RenderedPrompt[Any], prompt_name: str
|
|
385
|
+
) -> dict[str, Any] | None:
|
|
386
|
+
output_type = rendered.output_type
|
|
387
|
+
container = rendered.container
|
|
388
|
+
allow_extra_keys = rendered.allow_extra_keys
|
|
389
|
+
|
|
390
|
+
if output_type is None or container is None:
|
|
391
|
+
return None
|
|
392
|
+
|
|
393
|
+
extra_mode: Literal["ignore", "forbid"] = "ignore" if allow_extra_keys else "forbid"
|
|
394
|
+
base_schema = schema(output_type, extra=extra_mode)
|
|
395
|
+
base_schema.pop("title", None)
|
|
396
|
+
|
|
397
|
+
if container == "array":
|
|
398
|
+
schema_payload = cast(
|
|
399
|
+
dict[str, Any],
|
|
400
|
+
{
|
|
401
|
+
"type": "object",
|
|
402
|
+
"properties": {
|
|
403
|
+
ARRAY_WRAPPER_KEY: {
|
|
404
|
+
"type": "array",
|
|
405
|
+
"items": base_schema,
|
|
406
|
+
}
|
|
407
|
+
},
|
|
408
|
+
"required": [ARRAY_WRAPPER_KEY],
|
|
409
|
+
},
|
|
410
|
+
)
|
|
411
|
+
if not allow_extra_keys:
|
|
412
|
+
schema_payload["additionalProperties"] = False
|
|
413
|
+
else:
|
|
414
|
+
schema_payload = base_schema
|
|
415
|
+
|
|
416
|
+
schema_name = _schema_name(prompt_name)
|
|
417
|
+
return {
|
|
418
|
+
"type": "json_schema",
|
|
419
|
+
"json_schema": {
|
|
420
|
+
"name": schema_name,
|
|
421
|
+
"schema": schema_payload,
|
|
422
|
+
},
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _schema_name(prompt_name: str) -> str:
|
|
427
|
+
sanitized = re.sub(r"[^a-zA-Z0-9_-]+", "_", prompt_name.strip())
|
|
428
|
+
cleaned = sanitized.strip("_") or "prompt"
|
|
429
|
+
return f"{cleaned}_schema"
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _message_text_content(content: object) -> str:
|
|
433
|
+
if isinstance(content, str) or content is None:
|
|
434
|
+
return content or ""
|
|
435
|
+
if isinstance(content, Sequence) and not isinstance(
|
|
436
|
+
content, (str, bytes, bytearray)
|
|
437
|
+
):
|
|
438
|
+
fragments: list[str] = []
|
|
439
|
+
sequence_content = cast(Sequence[object], content) # pyright: ignore[reportUnnecessaryCast]
|
|
440
|
+
for part in sequence_content:
|
|
441
|
+
fragments.append(_content_part_text(part))
|
|
442
|
+
return "".join(fragments)
|
|
443
|
+
return str(content)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def _content_part_text(part: object) -> str:
|
|
447
|
+
if part is None:
|
|
448
|
+
return ""
|
|
449
|
+
if isinstance(part, Mapping):
|
|
450
|
+
mapping_part = cast(Mapping[str, object], part)
|
|
451
|
+
part_type = mapping_part.get("type")
|
|
452
|
+
if part_type in {"output_text", "text"}:
|
|
453
|
+
text_value = mapping_part.get("text")
|
|
454
|
+
if isinstance(text_value, str):
|
|
455
|
+
return text_value
|
|
456
|
+
return ""
|
|
457
|
+
part_type = getattr(part, "type", None)
|
|
458
|
+
if part_type in {"output_text", "text"}:
|
|
459
|
+
text_value = getattr(part, "text", None)
|
|
460
|
+
if isinstance(text_value, str):
|
|
461
|
+
return text_value
|
|
462
|
+
return ""
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def _extract_parsed_content(message: _Message) -> object | None:
|
|
466
|
+
parsed = getattr(message, "parsed", None)
|
|
467
|
+
if parsed is not None:
|
|
468
|
+
return parsed
|
|
469
|
+
|
|
470
|
+
content = message.content
|
|
471
|
+
if isinstance(content, Sequence) and not isinstance(
|
|
472
|
+
content, (str, bytes, bytearray)
|
|
473
|
+
):
|
|
474
|
+
sequence_content = cast(Sequence[object], content) # pyright: ignore[reportUnnecessaryCast]
|
|
475
|
+
for part in sequence_content:
|
|
476
|
+
payload = _parsed_payload_from_part(part)
|
|
477
|
+
if payload is not None:
|
|
478
|
+
return payload
|
|
479
|
+
return None
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def _parsed_payload_from_part(part: object) -> object | None:
|
|
483
|
+
if isinstance(part, Mapping):
|
|
484
|
+
mapping_part = cast(Mapping[str, object], part)
|
|
485
|
+
if mapping_part.get("type") == "output_json":
|
|
486
|
+
return mapping_part.get("json")
|
|
487
|
+
return None
|
|
488
|
+
part_type = getattr(part, "type", None)
|
|
489
|
+
if part_type == "output_json":
|
|
490
|
+
return getattr(part, "json", None)
|
|
491
|
+
return None
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _parse_schema_constrained_payload(
|
|
495
|
+
payload: object, rendered: RenderedPrompt[Any]
|
|
496
|
+
) -> object:
|
|
497
|
+
dataclass_type = rendered.output_type
|
|
498
|
+
container = rendered.container
|
|
499
|
+
allow_extra_keys = rendered.allow_extra_keys
|
|
500
|
+
|
|
501
|
+
if dataclass_type is None or container is None:
|
|
502
|
+
raise TypeError("Prompt does not declare structured output.")
|
|
503
|
+
|
|
504
|
+
extra_mode: Literal["ignore", "forbid"] = "ignore" if allow_extra_keys else "forbid"
|
|
505
|
+
|
|
506
|
+
if container == "object":
|
|
507
|
+
if not isinstance(payload, Mapping):
|
|
508
|
+
raise TypeError("Expected provider payload to be a JSON object.")
|
|
509
|
+
return parse(
|
|
510
|
+
dataclass_type, cast(Mapping[str, object], payload), extra=extra_mode
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
if container == "array":
|
|
514
|
+
if isinstance(payload, Mapping):
|
|
515
|
+
if ARRAY_WRAPPER_KEY not in payload:
|
|
516
|
+
raise TypeError("Expected provider payload to be a JSON array.")
|
|
517
|
+
payload = cast(Mapping[str, object], payload)[ARRAY_WRAPPER_KEY]
|
|
518
|
+
if not isinstance(payload, Sequence) or isinstance(
|
|
519
|
+
payload, (str, bytes, bytearray)
|
|
520
|
+
):
|
|
521
|
+
raise TypeError("Expected provider payload to be a JSON array.")
|
|
522
|
+
parsed_items: list[object] = []
|
|
523
|
+
sequence_payload = cast(Sequence[object], payload) # pyright: ignore[reportUnnecessaryCast]
|
|
524
|
+
for index, item in enumerate(sequence_payload):
|
|
525
|
+
if not isinstance(item, Mapping):
|
|
526
|
+
raise TypeError(f"Array item at index {index} is not an object.")
|
|
527
|
+
parsed_item = parse(
|
|
528
|
+
dataclass_type,
|
|
529
|
+
cast(Mapping[str, object], item),
|
|
530
|
+
extra=extra_mode,
|
|
531
|
+
)
|
|
532
|
+
parsed_items.append(parsed_item)
|
|
533
|
+
return parsed_items
|
|
534
|
+
|
|
535
|
+
raise TypeError("Unknown output container declared.")
|
|
536
|
+
|
|
537
|
+
|
|
309
538
|
def _first_choice(
|
|
310
539
|
response: _CompletionResponse, *, prompt_name: str
|
|
311
540
|
) -> _CompletionChoice:
|
|
@@ -315,7 +544,7 @@ def _first_choice(
|
|
|
315
544
|
raise PromptEvaluationError(
|
|
316
545
|
"Provider response did not include any choices.",
|
|
317
546
|
prompt_name=prompt_name,
|
|
318
|
-
|
|
547
|
+
phase="response",
|
|
319
548
|
) from error
|
|
320
549
|
|
|
321
550
|
|
|
@@ -345,14 +574,14 @@ def _parse_tool_arguments(
|
|
|
345
574
|
raise PromptEvaluationError(
|
|
346
575
|
"Failed to decode tool call arguments.",
|
|
347
576
|
prompt_name=prompt_name,
|
|
348
|
-
|
|
577
|
+
phase="tool",
|
|
349
578
|
provider_payload=provider_payload,
|
|
350
579
|
) from error
|
|
351
580
|
if not isinstance(parsed, Mapping):
|
|
352
581
|
raise PromptEvaluationError(
|
|
353
582
|
"Tool call arguments must be a JSON object.",
|
|
354
583
|
prompt_name=prompt_name,
|
|
355
|
-
|
|
584
|
+
phase="tool",
|
|
356
585
|
provider_payload=provider_payload,
|
|
357
586
|
)
|
|
358
587
|
return dict(cast(Mapping[str, Any], parsed))
|
weakincentives/events.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""In-process event primitives for adapter telemetry."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import TYPE_CHECKING, Protocol
|
|
21
|
+
|
|
22
|
+
from .prompt._types import SupportsDataclass
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from .adapters.core import PromptResponse
|
|
26
|
+
from .prompt.tool import ToolResult
|
|
27
|
+
|
|
28
|
+
EventHandler = Callable[[object], None]
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class EventBus(Protocol):
|
|
34
|
+
"""Minimal synchronous publish/subscribe abstraction."""
|
|
35
|
+
|
|
36
|
+
def subscribe(self, event_type: type[object], handler: EventHandler) -> None:
|
|
37
|
+
"""Register a handler for the given event type."""
|
|
38
|
+
|
|
39
|
+
def publish(self, event: object) -> None:
|
|
40
|
+
"""Publish an event instance to subscribers."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class NullEventBus:
|
|
44
|
+
"""Event bus implementation that discards all events."""
|
|
45
|
+
|
|
46
|
+
def subscribe(self, event_type: type[object], handler: EventHandler) -> None: # noqa: D401
|
|
47
|
+
"""No-op subscription hook."""
|
|
48
|
+
|
|
49
|
+
def publish(self, event: object) -> None: # noqa: D401
|
|
50
|
+
"""Drop the provided event instance."""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class InProcessEventBus:
|
|
54
|
+
"""Process-local event bus that delivers events synchronously."""
|
|
55
|
+
|
|
56
|
+
def __init__(self) -> None:
|
|
57
|
+
self._handlers: dict[type[object], list[EventHandler]] = {}
|
|
58
|
+
|
|
59
|
+
def subscribe(self, event_type: type[object], handler: EventHandler) -> None:
|
|
60
|
+
handlers = self._handlers.setdefault(event_type, [])
|
|
61
|
+
handlers.append(handler)
|
|
62
|
+
|
|
63
|
+
def publish(self, event: object) -> None:
|
|
64
|
+
handlers = tuple(self._handlers.get(type(event), ()))
|
|
65
|
+
for handler in handlers:
|
|
66
|
+
try:
|
|
67
|
+
handler(event)
|
|
68
|
+
except Exception: # noqa: BLE001
|
|
69
|
+
logger.exception(
|
|
70
|
+
"Error delivering event %s to handler %r",
|
|
71
|
+
type(event).__name__,
|
|
72
|
+
handler,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(slots=True, frozen=True)
|
|
77
|
+
class PromptExecuted:
|
|
78
|
+
"""Event emitted after an adapter finishes evaluating a prompt."""
|
|
79
|
+
|
|
80
|
+
prompt_name: str
|
|
81
|
+
adapter: str
|
|
82
|
+
result: PromptResponse[object]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass(slots=True, frozen=True)
|
|
86
|
+
class ToolInvoked:
|
|
87
|
+
"""Event emitted after an adapter executes a tool handler."""
|
|
88
|
+
|
|
89
|
+
prompt_name: str
|
|
90
|
+
adapter: str
|
|
91
|
+
name: str
|
|
92
|
+
params: SupportsDataclass
|
|
93
|
+
result: ToolResult[object]
|
|
94
|
+
call_id: str | None = None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
__all__ = [
|
|
98
|
+
"EventBus",
|
|
99
|
+
"InProcessEventBus",
|
|
100
|
+
"NullEventBus",
|
|
101
|
+
"PromptExecuted",
|
|
102
|
+
"ToolInvoked",
|
|
103
|
+
]
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Example scripts and shared utilities for the weakincentives project."""
|
|
14
|
+
|
|
15
|
+
from .code_review_prompt import (
|
|
16
|
+
ReviewGuidance,
|
|
17
|
+
ReviewResponse,
|
|
18
|
+
ReviewTurnParams,
|
|
19
|
+
build_code_review_prompt,
|
|
20
|
+
)
|
|
21
|
+
from .code_review_session import (
|
|
22
|
+
CodeReviewSession,
|
|
23
|
+
SupportsReviewEvaluate,
|
|
24
|
+
ToolCallLog,
|
|
25
|
+
)
|
|
26
|
+
from .code_review_tools import (
|
|
27
|
+
MAX_OUTPUT_CHARS,
|
|
28
|
+
REPO_ROOT,
|
|
29
|
+
BranchListParams,
|
|
30
|
+
BranchListResult,
|
|
31
|
+
GitLogParams,
|
|
32
|
+
GitLogResult,
|
|
33
|
+
TagListParams,
|
|
34
|
+
TagListResult,
|
|
35
|
+
TimeQueryParams,
|
|
36
|
+
TimeQueryResult,
|
|
37
|
+
branch_list_handler,
|
|
38
|
+
build_tools,
|
|
39
|
+
current_time_handler,
|
|
40
|
+
git_log_handler,
|
|
41
|
+
tag_list_handler,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
"MAX_OUTPUT_CHARS",
|
|
46
|
+
"REPO_ROOT",
|
|
47
|
+
"BranchListParams",
|
|
48
|
+
"BranchListResult",
|
|
49
|
+
"GitLogParams",
|
|
50
|
+
"GitLogResult",
|
|
51
|
+
"TagListParams",
|
|
52
|
+
"TagListResult",
|
|
53
|
+
"TimeQueryParams",
|
|
54
|
+
"TimeQueryResult",
|
|
55
|
+
"ReviewGuidance",
|
|
56
|
+
"ReviewResponse",
|
|
57
|
+
"ReviewTurnParams",
|
|
58
|
+
"ToolCallLog",
|
|
59
|
+
"SupportsReviewEvaluate",
|
|
60
|
+
"CodeReviewSession",
|
|
61
|
+
"build_tools",
|
|
62
|
+
"build_code_review_prompt",
|
|
63
|
+
"git_log_handler",
|
|
64
|
+
"current_time_handler",
|
|
65
|
+
"branch_list_handler",
|
|
66
|
+
"tag_list_handler",
|
|
67
|
+
]
|