weakincentives 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of weakincentives might be problematic. Click here for more details.

Files changed (36)
  1. weakincentives/__init__.py +26 -2
  2. weakincentives/adapters/__init__.py +6 -5
  3. weakincentives/adapters/core.py +7 -17
  4. weakincentives/adapters/litellm.py +594 -0
  5. weakincentives/adapters/openai.py +286 -57
  6. weakincentives/events.py +103 -0
  7. weakincentives/examples/__init__.py +67 -0
  8. weakincentives/examples/code_review_prompt.py +118 -0
  9. weakincentives/examples/code_review_session.py +171 -0
  10. weakincentives/examples/code_review_tools.py +376 -0
  11. weakincentives/{prompts → prompt}/__init__.py +6 -8
  12. weakincentives/{prompts → prompt}/_types.py +1 -1
  13. weakincentives/{prompts/text.py → prompt/markdown.py} +19 -9
  14. weakincentives/{prompts → prompt}/prompt.py +216 -66
  15. weakincentives/{prompts → prompt}/response_format.py +9 -6
  16. weakincentives/{prompts → prompt}/section.py +25 -4
  17. weakincentives/{prompts/structured.py → prompt/structured_output.py} +16 -5
  18. weakincentives/{prompts → prompt}/tool.py +6 -6
  19. weakincentives/prompt/versioning.py +144 -0
  20. weakincentives/serde/__init__.py +0 -14
  21. weakincentives/serde/dataclass_serde.py +3 -17
  22. weakincentives/session/__init__.py +31 -0
  23. weakincentives/session/reducers.py +60 -0
  24. weakincentives/session/selectors.py +45 -0
  25. weakincentives/session/session.py +168 -0
  26. weakincentives/tools/__init__.py +69 -0
  27. weakincentives/tools/errors.py +22 -0
  28. weakincentives/tools/planning.py +538 -0
  29. weakincentives/tools/vfs.py +590 -0
  30. weakincentives-0.3.0.dist-info/METADATA +231 -0
  31. weakincentives-0.3.0.dist-info/RECORD +35 -0
  32. weakincentives-0.2.0.dist-info/METADATA +0 -173
  33. weakincentives-0.2.0.dist-info/RECORD +0 -20
  34. /weakincentives/{prompts → prompt}/errors.py +0 -0
  35. {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/WHEEL +0 -0
  36. {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,594 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Optional LiteLLM adapter utilities."""
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import re
19
+ from collections.abc import Mapping, Sequence
20
+ from importlib import import_module
21
+ from typing import TYPE_CHECKING, Any, Final, Literal, Protocol, cast
22
+
23
+ from ..events import EventBus, PromptExecuted, ToolInvoked
24
+ from ..prompt._types import SupportsDataclass
25
+ from ..prompt.prompt import Prompt, RenderedPrompt
26
+ from ..prompt.structured_output import (
27
+ ARRAY_WRAPPER_KEY,
28
+ OutputParseError,
29
+ parse_structured_output,
30
+ )
31
+ from ..prompt.tool import Tool, ToolResult
32
+ from ..serde import parse, schema
33
+ from .core import PromptEvaluationError, PromptResponse
34
+
35
# Message used for the RuntimeError raised by ``_load_litellm_module`` when
# the optional ``litellm`` package cannot be imported.
_ERROR_MESSAGE: Final[str] = (
    "LiteLLM support requires the optional 'litellm' dependency. "
    "Install it with `uv sync --extra litellm` or `pip install weakincentives[litellm]`."
)

# Best-effort import at module load: binds ``_optional_litellm`` to the module,
# or to ``None`` when the package is absent. Nothing else in this file reads
# ``_optional_litellm``; runtime loading goes through ``_load_litellm_module``.
try:  # pragma: no cover - optional dependency import for tooling
    import litellm as _optional_litellm  # type: ignore[import]
except ModuleNotFoundError:  # pragma: no cover - handled lazily in loader
    _optional_litellm = None  # type: ignore[assignment]

# Only evaluated by static type checkers; ``TYPE_CHECKING`` is False at runtime.
if TYPE_CHECKING:  # pragma: no cover - optional dependency for typing only
    import litellm

    # NOTE(review): presumably references the import so linters do not flag
    # it as unused — confirm.
    _ = litellm.__name__
49
+
50
+
51
class _CompletionFunctionCall(Protocol):
    """Structural view of the ``function`` portion of a provider tool call."""

    # Tool name the provider asks the adapter to run.
    name: str
    # JSON text holding the call arguments; may be absent/empty.
    arguments: str | None
54
+
55
+
56
class _ToolCall(Protocol):
    """Structural view of a single tool call emitted by the provider."""

    # Provider-assigned identifier echoed back in the ``tool`` reply message.
    id: str
    # Name and JSON arguments of the requested function.
    function: _CompletionFunctionCall
59
+
60
+
61
class _Message(Protocol):
    """Structural view of the assistant message inside a completion choice."""

    # Plain text, a sequence of content parts, or nothing at all.
    content: str | Sequence[object] | None
    # Tool calls requested by the model, if any.
    tool_calls: Sequence[_ToolCall] | None
64
+
65
+
66
class _CompletionChoice(Protocol):
    """Structural view of one entry in a completion response's ``choices``."""

    # Assistant message produced for this choice.
    message: _Message
68
+
69
+
70
class _CompletionResponse(Protocol):
    """Structural view of the object returned by a completion call."""

    # Candidate completions; the adapter only consumes the first one.
    choices: Sequence[_CompletionChoice]
72
+
73
+
74
class _CompletionCallable(Protocol):
    """Anything callable like ``litellm.completion``."""

    def __call__(self, *args: object, **kwargs: object) -> _CompletionResponse:
        """Issue a completion request and return the provider response."""
76
+
77
+
78
class _LiteLLMModule(Protocol):
    """The subset of the ``litellm`` module surface this adapter relies on."""

    def completion(self, *args: object, **kwargs: object) -> _CompletionResponse:
        """Mirror of ``litellm.completion``."""
80
+
81
+
82
class _LiteLLMCompletionFactory(Protocol):
    """Builds a completion callable from keyword configuration."""

    def __call__(self, **kwargs: object) -> _CompletionCallable:
        """Return a completion callable configured with ``kwargs``."""
84
+
85
+
86
# Public alias for the completion-callable protocol: the type accepted by
# ``LiteLLMAdapter(completion=...)`` and returned by
# ``create_litellm_completion``.
LiteLLMCompletion = _CompletionCallable
87
+
88
+
89
def _load_litellm_module() -> _LiteLLMModule:
    """Import ``litellm`` on demand and return it.

    Raises:
        RuntimeError: When the package is not installed; the underlying
            ``ModuleNotFoundError`` is chained as the cause.
    """
    try:
        litellm_module = import_module("litellm")
    except ModuleNotFoundError as missing:  # pragma: no cover - dependency guard
        raise RuntimeError(_ERROR_MESSAGE) from missing
    else:
        return cast(_LiteLLMModule, litellm_module)
95
+
96
+
97
def create_litellm_completion(**kwargs: object) -> LiteLLMCompletion:
    """Return a LiteLLM completion callable, guarding the optional dependency.

    Without keyword arguments, ``litellm.completion`` itself is returned.
    Otherwise the keywords act as per-request defaults: they are merged into
    every call, and keywords supplied at call time override them.

    Raises:
        RuntimeError: If the ``litellm`` package is not installed.
    """
    litellm_module = _load_litellm_module()
    if kwargs:

        def _completion_with_defaults(
            *args: object, **overrides: object
        ) -> _CompletionResponse:
            # Later entries win, so call-time keywords replace the defaults.
            return litellm_module.completion(*args, **{**kwargs, **overrides})

        return _completion_with_defaults
    return litellm_module.completion
112
+
113
+
114
# Accepted ``tool_choice`` values: the string "auto", a provider mapping (for
# example one with ``"type": "function"``, which the adapter downgrades to
# "auto" after the first tool round), or ``None`` to omit the field entirely.
ToolChoice = Literal["auto"] | Mapping[str, Any] | None
"""Supported tool choice directives for provider APIs."""
116
+
117
+
118
+ class LiteLLMAdapter:
119
+ """Adapter that evaluates prompts via LiteLLM's completion helper."""
120
+
121
+ def __init__(
122
+ self,
123
+ *,
124
+ model: str,
125
+ tool_choice: ToolChoice = "auto",
126
+ completion: LiteLLMCompletion | None = None,
127
+ completion_factory: _LiteLLMCompletionFactory | None = None,
128
+ completion_kwargs: Mapping[str, object] | None = None,
129
+ ) -> None:
130
+ if completion is not None:
131
+ if completion_factory is not None:
132
+ raise ValueError(
133
+ "completion_factory cannot be provided when an explicit completion is supplied.",
134
+ )
135
+ if completion_kwargs:
136
+ raise ValueError(
137
+ "completion_kwargs cannot be provided when an explicit completion is supplied.",
138
+ )
139
+ else:
140
+ factory = completion_factory or create_litellm_completion
141
+ completion = factory(**dict(completion_kwargs or {}))
142
+
143
+ self._completion = completion
144
+ self._model = model
145
+ self._tool_choice: ToolChoice = tool_choice
146
+
147
+ def evaluate[OutputT](
148
+ self,
149
+ prompt: Prompt[OutputT],
150
+ *params: SupportsDataclass,
151
+ parse_output: bool = True,
152
+ bus: EventBus,
153
+ ) -> PromptResponse[OutputT]:
154
+ prompt_name = prompt.name or prompt.__class__.__name__
155
+ has_structured_output = (
156
+ getattr(prompt, "_output_type", None) is not None
157
+ and getattr(prompt, "_output_container", None) is not None
158
+ )
159
+ should_disable_instructions = (
160
+ parse_output
161
+ and has_structured_output
162
+ and getattr(prompt, "inject_output_instructions", False)
163
+ )
164
+
165
+ if should_disable_instructions:
166
+ rendered = prompt.render(
167
+ *params,
168
+ inject_output_instructions=False,
169
+ ) # type: ignore[reportArgumentType]
170
+ else:
171
+ rendered = prompt.render(*params) # type: ignore[reportArgumentType]
172
+ messages: list[dict[str, Any]] = [
173
+ {"role": "system", "content": rendered.text},
174
+ ]
175
+
176
+ should_parse_structured_output = (
177
+ parse_output
178
+ and rendered.output_type is not None
179
+ and rendered.container is not None
180
+ )
181
+ response_format: dict[str, Any] | None = None
182
+ if should_parse_structured_output:
183
+ response_format = _build_json_schema_response_format(rendered, prompt_name)
184
+
185
+ tools = list(rendered.tools)
186
+ tool_specs = [_tool_to_litellm_spec(tool) for tool in tools]
187
+ tool_registry = {tool.name: tool for tool in tools}
188
+ tool_events: list[ToolInvoked] = []
189
+ provider_payload: dict[str, Any] | None = None
190
+ next_tool_choice: ToolChoice = self._tool_choice
191
+
192
+ while True:
193
+ request_payload: dict[str, Any] = {
194
+ "model": self._model,
195
+ "messages": messages,
196
+ }
197
+ if tool_specs:
198
+ request_payload["tools"] = tool_specs
199
+ if next_tool_choice is not None:
200
+ request_payload["tool_choice"] = next_tool_choice
201
+ if response_format is not None:
202
+ request_payload["response_format"] = response_format
203
+
204
+ try:
205
+ response = self._completion(**request_payload)
206
+ except Exception as error: # pragma: no cover - network/SDK failure
207
+ raise PromptEvaluationError(
208
+ "LiteLLM request failed.",
209
+ prompt_name=prompt_name,
210
+ phase="request",
211
+ ) from error
212
+
213
+ provider_payload = _extract_payload(response)
214
+ choice = _first_choice(response, prompt_name=prompt_name)
215
+ message = choice.message
216
+ tool_calls = list(message.tool_calls or [])
217
+
218
+ if not tool_calls:
219
+ final_text = _message_text_content(message.content)
220
+ output: OutputT | None = None
221
+ text_value: str | None = final_text or None
222
+
223
+ if should_parse_structured_output:
224
+ parsed_payload = _extract_parsed_content(message)
225
+ if parsed_payload is not None:
226
+ try:
227
+ output = cast(
228
+ OutputT,
229
+ _parse_schema_constrained_payload(
230
+ parsed_payload, rendered
231
+ ),
232
+ )
233
+ except TypeError as error:
234
+ raise PromptEvaluationError(
235
+ str(error),
236
+ prompt_name=prompt_name,
237
+ phase="response",
238
+ provider_payload=provider_payload,
239
+ ) from error
240
+ elif final_text:
241
+ try:
242
+ output = parse_structured_output(final_text, rendered)
243
+ except OutputParseError as error:
244
+ raise PromptEvaluationError(
245
+ error.message,
246
+ prompt_name=prompt_name,
247
+ phase="response",
248
+ provider_payload=provider_payload,
249
+ ) from error
250
+ else:
251
+ raise PromptEvaluationError(
252
+ "Provider response did not include structured output.",
253
+ prompt_name=prompt_name,
254
+ phase="response",
255
+ provider_payload=provider_payload,
256
+ )
257
+ text_value = None
258
+
259
+ response = PromptResponse(
260
+ prompt_name=prompt_name,
261
+ text=text_value,
262
+ output=output,
263
+ tool_results=tuple(tool_events),
264
+ provider_payload=provider_payload,
265
+ )
266
+ bus.publish(
267
+ PromptExecuted(
268
+ prompt_name=prompt_name,
269
+ adapter="litellm",
270
+ result=cast(PromptResponse[object], response),
271
+ )
272
+ )
273
+ return response
274
+
275
+ assistant_tool_calls = [_serialize_tool_call(call) for call in tool_calls]
276
+ messages.append(
277
+ {
278
+ "role": "assistant",
279
+ "content": message.content or "",
280
+ "tool_calls": assistant_tool_calls,
281
+ }
282
+ )
283
+
284
+ for tool_call in tool_calls:
285
+ function = tool_call.function
286
+ tool_name = function.name
287
+ tool = tool_registry.get(tool_name)
288
+ if tool is None:
289
+ raise PromptEvaluationError(
290
+ f"Unknown tool '{tool_name}' requested by provider.",
291
+ prompt_name=prompt_name,
292
+ phase="tool",
293
+ provider_payload=provider_payload,
294
+ )
295
+ if tool.handler is None:
296
+ raise PromptEvaluationError(
297
+ f"Tool '{tool_name}' does not have a registered handler.",
298
+ prompt_name=prompt_name,
299
+ phase="tool",
300
+ provider_payload=provider_payload,
301
+ )
302
+
303
+ arguments_mapping = _parse_tool_arguments(
304
+ function.arguments,
305
+ prompt_name=prompt_name,
306
+ provider_payload=provider_payload,
307
+ )
308
+
309
+ try:
310
+ tool_params = parse(
311
+ tool.params_type,
312
+ arguments_mapping,
313
+ extra="forbid",
314
+ )
315
+ except (TypeError, ValueError) as error:
316
+ raise PromptEvaluationError(
317
+ f"Failed to parse params for tool '{tool_name}'.",
318
+ prompt_name=prompt_name,
319
+ phase="tool",
320
+ provider_payload=provider_payload,
321
+ ) from error
322
+
323
+ try:
324
+ tool_result = tool.handler(tool_params)
325
+ except Exception as error: # pragma: no cover - handler bug
326
+ raise PromptEvaluationError(
327
+ f"Tool '{tool_name}' raised an exception.",
328
+ prompt_name=prompt_name,
329
+ phase="tool",
330
+ provider_payload=provider_payload,
331
+ ) from error
332
+
333
+ invocation = ToolInvoked(
334
+ prompt_name=prompt_name,
335
+ adapter="litellm",
336
+ name=tool_name,
337
+ params=tool_params,
338
+ result=cast(ToolResult[object], tool_result),
339
+ call_id=getattr(tool_call, "id", None),
340
+ )
341
+ tool_events.append(invocation)
342
+ bus.publish(invocation)
343
+
344
+ messages.append(
345
+ {
346
+ "role": "tool",
347
+ "tool_call_id": getattr(tool_call, "id", None),
348
+ "content": tool_result.message,
349
+ }
350
+ )
351
+
352
+ if isinstance(next_tool_choice, Mapping):
353
+ tool_choice_mapping = cast(Mapping[str, object], next_tool_choice)
354
+ if tool_choice_mapping.get("type") == "function":
355
+ next_tool_choice = "auto"
356
+
357
+
358
def _build_json_schema_response_format(
    rendered: RenderedPrompt[Any], prompt_name: str
) -> dict[str, Any] | None:
    """Build a ``json_schema`` ``response_format`` for the rendered prompt.

    Object containers use the dataclass schema directly. Array containers wrap
    the item schema in an object under ``ARRAY_WRAPPER_KEY``. Extra keys are
    forbidden unless the prompt allows them.

    Returns:
        The response-format mapping, or ``None`` when the prompt declares no
        output type or container.
    """
    output_type = rendered.output_type
    container = rendered.container
    allow_extra = rendered.allow_extra_keys

    if container is None or output_type is None:
        return None

    mode: Literal["ignore", "forbid"] = "forbid"
    if allow_extra:
        mode = "ignore"
    item_schema = schema(output_type, extra=mode)
    item_schema.pop("title", None)

    if container != "array":
        top_level = item_schema
    else:
        wrapped_items: dict[str, Any] = {"type": "array", "items": item_schema}
        top_level = cast(
            dict[str, Any],
            {
                "type": "object",
                "properties": {ARRAY_WRAPPER_KEY: wrapped_items},
                "required": [ARRAY_WRAPPER_KEY],
            },
        )
        if not allow_extra:
            top_level["additionalProperties"] = False

    name = _schema_name(prompt_name)
    return {"type": "json_schema", "json_schema": {"name": name, "schema": top_level}}
399
+
400
+
401
def _schema_name(prompt_name: str, *, max_length: int = 64) -> str:
    """Derive a provider-safe JSON schema name from ``prompt_name``.

    Runs of characters outside ``[a-zA-Z0-9_-]`` collapse to ``_``. Leading and
    trailing underscores are stripped, falling back to ``"prompt"`` when
    nothing remains, and ``_schema`` is appended. The result is capped at
    ``max_length`` characters, suffix included. The default of 64 matches the
    limit OpenAI documents for ``response_format.json_schema.name``.

    Args:
        prompt_name: Human-readable prompt name, possibly with spaces or
            punctuation.
        max_length: Maximum length of the returned name, including the suffix.
            At least one character of the base name is always kept.

    Returns:
        The sanitized schema name.
    """
    suffix = "_schema"
    sanitized = re.sub(r"[^a-zA-Z0-9_-]+", "_", prompt_name.strip())
    cleaned = sanitized.strip("_") or "prompt"
    budget = max(max_length - len(suffix), 1)
    # ``cleaned`` never starts with "_", so the truncated prefix stays non-empty;
    # strip any separator that truncation leaves dangling at the end.
    cleaned = cleaned[:budget].rstrip("_")
    return f"{cleaned}{suffix}"
405
+
406
+
407
def _tool_to_litellm_spec(tool: Tool[Any, Any]) -> dict[str, Any]:
    """Describe ``tool`` as an OpenAI-style ``function`` tool specification.

    The parameter schema forbids extra keys and has its ``title`` removed.
    """
    params_schema = schema(tool.params_type, extra="forbid")
    params_schema.pop("title", None)
    function_spec: dict[str, Any] = {
        "name": tool.name,
        "description": tool.description,
        "parameters": params_schema,
    }
    return {"type": "function", "function": function_spec}
418
+
419
+
420
def _extract_payload(response: _CompletionResponse) -> dict[str, Any] | None:
    """Return a plain ``dict`` snapshot of the provider response, if possible.

    Objects exposing a callable ``model_dump`` are dumped. A failing dump, or
    one that does not yield a mapping, gives ``None``. Objects without
    ``model_dump`` are copied when they are mappings and ignored otherwise.
    """
    dump = getattr(response, "model_dump", None)
    if not callable(dump):
        if isinstance(response, Mapping):  # pragma: no cover - defensive
            return dict(response)
        return None
    try:
        dumped = dump()
    except Exception:  # pragma: no cover - defensive
        return None
    if not isinstance(dumped, Mapping):
        return None
    return dict(cast(Mapping[str, Any], dumped))
434
+
435
+
436
def _first_choice(
    response: _CompletionResponse, *, prompt_name: str
) -> _CompletionChoice:
    """Return ``response.choices[0]``.

    Raises:
        PromptEvaluationError: With ``phase="response"`` when the response has
            no ``choices`` attribute or the sequence is empty.
    """
    try:
        choices = response.choices
        first = choices[0]
    except (AttributeError, IndexError) as error:  # pragma: no cover - defensive
        raise PromptEvaluationError(
            "Provider response did not include any choices.",
            prompt_name=prompt_name,
            phase="response",
        ) from error
    return first
447
+
448
+
449
def _serialize_tool_call(tool_call: _ToolCall) -> dict[str, Any]:
    """Convert a provider tool call into an assistant-message ``tool_calls`` entry.

    Missing or empty argument strings are normalised to ``"{}"``. A missing
    ``id`` attribute becomes ``None``.
    """
    fn = tool_call.function
    call_id = getattr(tool_call, "id", None)
    function_entry = {"name": fn.name, "arguments": fn.arguments or "{}"}
    return {"id": call_id, "type": "function", "function": function_entry}
459
+
460
+
461
def _parse_tool_arguments(
    arguments_json: str | None,
    *,
    prompt_name: str,
    provider_payload: dict[str, Any] | None,
) -> dict[str, Any]:
    """Decode a tool call's JSON argument string into a ``dict``.

    ``None`` or an empty string yields ``{}``.

    Raises:
        PromptEvaluationError: With ``phase="tool"`` when the text is not valid
            JSON or does not decode to a JSON object.
    """
    if not arguments_json:
        return {}

    try:
        decoded = json.loads(arguments_json)
    except json.JSONDecodeError as error:
        raise PromptEvaluationError(
            "Failed to decode tool call arguments.",
            prompt_name=prompt_name,
            phase="tool",
            provider_payload=provider_payload,
        ) from error

    if isinstance(decoded, Mapping):
        return dict(cast(Mapping[str, Any], decoded))

    raise PromptEvaluationError(
        "Tool call arguments must be a JSON object.",
        prompt_name=prompt_name,
        phase="tool",
        provider_payload=provider_payload,
    )
486
+
487
+
488
def _message_text_content(content: object) -> str:
    """Flatten a provider message ``content`` value into plain text.

    ``None`` becomes ``""`` and strings pass through unchanged. Sequences other
    than ``bytes``/``bytearray`` are treated as content parts, whose text is
    concatenated via ``_content_part_text``. Anything else goes through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    is_part_list = isinstance(content, Sequence) and not isinstance(
        content, (bytes, bytearray)
    )
    if not is_part_list:
        return str(content)
    parts = cast(Sequence[object], content)  # pyright: ignore[reportUnnecessaryCast]
    return "".join(_content_part_text(part) for part in parts)
500
+
501
+
502
def _content_part_text(part: object) -> str:
    """Return the text carried by one message content part, or ``""``.

    Mapping parts and attribute-style objects are both accepted. Only parts
    whose ``type`` is ``"output_text"`` or ``"text"`` and whose ``text`` is a
    string contribute; everything else (including ``None``) yields ``""``.
    """
    # A tuple is used instead of a set so membership relies on equality alone:
    # a malformed, unhashable ``type`` value (list/dict) must not raise
    # ``TypeError``.
    text_types = ("output_text", "text")
    if part is None:
        return ""
    if isinstance(part, Mapping):
        mapping_part = cast(Mapping[str, object], part)
        if mapping_part.get("type") not in text_types:
            return ""
        text_value = mapping_part.get("text")
    else:
        if getattr(part, "type", None) not in text_types:
            return ""
        text_value = getattr(part, "text", None)
    return text_value if isinstance(text_value, str) else ""
519
+
520
+
521
def _extract_parsed_content(message: _Message) -> object | None:
    """Find an already-parsed structured payload on ``message``.

    A non-``None`` ``message.parsed`` attribute wins. Otherwise, when the
    content is a sequence of parts (not text/bytes), the first non-``None``
    payload reported by ``_parsed_payload_from_part`` is returned. Returns
    ``None`` when nothing is found.
    """
    direct = getattr(message, "parsed", None)
    if direct is not None:
        return direct

    content = message.content
    if isinstance(content, (str, bytes, bytearray)) or not isinstance(
        content, Sequence
    ):
        return None
    parts = cast(Sequence[object], content)  # pyright: ignore[reportUnnecessaryCast]
    # Lazily evaluated, so scanning stops at the first hit.
    candidates = (_parsed_payload_from_part(part) for part in parts)
    return next((found for found in candidates if found is not None), None)
536
+
537
+
538
def _parsed_payload_from_part(part: object) -> object | None:
    """Return the ``json`` payload of an ``output_json`` content part.

    Works for mapping parts and attribute-style objects alike. Any other part
    yields ``None``.
    """
    if isinstance(part, Mapping):
        mapping_part = cast(Mapping[str, object], part)
        is_json_part = mapping_part.get("type") == "output_json"
        return mapping_part.get("json") if is_json_part else None
    if getattr(part, "type", None) != "output_json":
        return None
    return getattr(part, "json", None)
548
+
549
+
550
+ def _parse_schema_constrained_payload(
551
+ payload: object, rendered: RenderedPrompt[Any]
552
+ ) -> object:
553
+ dataclass_type = rendered.output_type
554
+ container = rendered.container
555
+ allow_extra_keys = rendered.allow_extra_keys
556
+
557
+ if dataclass_type is None or container is None:
558
+ raise TypeError("Prompt does not declare structured output.")
559
+
560
+ extra_mode: Literal["ignore", "forbid"] = "ignore" if allow_extra_keys else "forbid"
561
+
562
+ if container == "object":
563
+ if not isinstance(payload, Mapping):
564
+ raise TypeError("Expected provider payload to be a JSON object.")
565
+ return parse(
566
+ dataclass_type, cast(Mapping[str, object], payload), extra=extra_mode
567
+ )
568
+
569
+ if container == "array":
570
+ if isinstance(payload, Mapping):
571
+ if ARRAY_WRAPPER_KEY not in payload:
572
+ raise TypeError("Expected provider payload to be a JSON array.")
573
+ payload = cast(Mapping[str, object], payload)[ARRAY_WRAPPER_KEY]
574
+ if not isinstance(payload, Sequence) or isinstance(
575
+ payload, (str, bytes, bytearray)
576
+ ):
577
+ raise TypeError("Expected provider payload to be a JSON array.")
578
+ parsed_items: list[object] = []
579
+ sequence_payload = cast(Sequence[object], payload) # pyright: ignore[reportUnnecessaryCast]
580
+ for index, item in enumerate(sequence_payload):
581
+ if not isinstance(item, Mapping):
582
+ raise TypeError(f"Array item at index {index} is not an object.")
583
+ parsed_item = parse(
584
+ dataclass_type,
585
+ cast(Mapping[str, object], item),
586
+ extra=extra_mode,
587
+ )
588
+ parsed_items.append(parsed_item)
589
+ return parsed_items
590
+
591
+ raise TypeError("Unknown output container declared.")
592
+
593
+
594
# Public API of this module; everything else here is an implementation detail.
__all__ = ["LiteLLMAdapter", "LiteLLMCompletion", "create_litellm_completion"]