weakincentives 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of weakincentives might be problematic; see the registry's advisory page for more details.

@@ -1,2 +1,15 @@
1
- def hello() -> str:
2
- return "Hello from weakincentives!"
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Top-level package for the weakincentives library."""
14
+
15
+ __all__ = []
@@ -0,0 +1,30 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Integration adapters for optional third-party providers."""
14
+
15
+ from .core import (
16
+ PromptEvaluationError,
17
+ PromptResponse,
18
+ ProviderAdapter,
19
+ ToolCallRecord,
20
+ )
21
+ from .openai import OpenAIAdapter, OpenAIProtocol
22
+
23
+ __all__ = [
24
+ "ProviderAdapter",
25
+ "PromptResponse",
26
+ "PromptEvaluationError",
27
+ "ToolCallRecord",
28
+ "OpenAIAdapter",
29
+ "OpenAIProtocol",
30
+ ]
@@ -0,0 +1,85 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Core adapter interfaces shared across provider integrations."""
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any, Protocol, TypeVar
19
+
20
+ from ..prompts._types import SupportsDataclass
21
+ from ..prompts.prompt import Prompt
22
+ from ..prompts.tool import ToolResult
23
+
24
+ OutputT = TypeVar("OutputT")
25
+
26
+
27
class ProviderAdapter(Protocol[OutputT]):
    """Protocol describing the synchronous adapter contract.

    Concrete adapters (e.g. the OpenAI adapter) implement this structurally;
    no subclassing is required.
    """

    def evaluate(
        self,
        prompt: Prompt[OutputT],
        *params: SupportsDataclass,
        parse_output: bool = True,
    ) -> PromptResponse[OutputT]:
        """Evaluate the prompt and return a structured response.

        Args:
            prompt: Prompt to render and send to the provider.
            *params: Dataclass instances used to fill the prompt's placeholders.
            parse_output: When True, adapters parse the provider's final text
                into the prompt's declared output type (if any).

        Returns:
            A ``PromptResponse`` carrying the text and/or parsed output.
        """

        ...
39
+
40
+
41
@dataclass(slots=True)
class ToolCallRecord[ParamsT, ResultT]:
    """Record describing a single tool invocation during evaluation."""

    # Tool name exactly as requested by the provider.
    name: str
    # Parsed parameters that were passed to the tool handler.
    params: ParamsT
    # Result object returned by the tool handler.
    result: ToolResult[ResultT]
    # Provider-assigned call id when available; None if the provider
    # did not supply one.
    call_id: str | None = None
49
+
50
+
51
@dataclass(slots=True)
class PromptResponse[OutputT]:
    """Structured result emitted by an adapter evaluation."""

    # Name of the prompt that was evaluated (for diagnostics).
    prompt_name: str
    # Final assistant text; None when the text was consumed by output parsing.
    text: str | None
    # Parsed structured output, when the prompt declares an output type.
    output: OutputT | None
    # Every tool call executed during the evaluation loop, in order.
    tool_results: tuple[ToolCallRecord[Any, Any], ...]
    # Raw provider response payload when it could be captured; best-effort.
    provider_payload: dict[str, Any] | None = None
60
+
61
+
62
class PromptEvaluationError(RuntimeError):
    """Raised when evaluation against a provider fails.

    Carries structured context (prompt name, failure stage, raw provider
    payload) so callers can branch on the failure mode.
    """

    def __init__(
        self,
        message: str,
        *,
        prompt_name: str,
        stage: str,
        provider_payload: dict[str, Any] | None = None,
    ) -> None:
        """Store the failure message alongside its structured context.

        Args:
            message: Human-readable description of the failure.
            prompt_name: Name of the prompt being evaluated.
            stage: Phase that failed (e.g. "request", "response", "tool").
            provider_payload: Raw provider response, when one was captured.
        """
        super().__init__(message)
        # Keep every piece of context as a public attribute for callers.
        self.message = message
        self.stage = stage
        self.prompt_name = prompt_name
        self.provider_payload = provider_payload
78
+
79
+
80
+ __all__ = [
81
+ "ProviderAdapter",
82
+ "PromptEvaluationError",
83
+ "PromptResponse",
84
+ "ToolCallRecord",
85
+ ]
@@ -0,0 +1,361 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Optional OpenAI adapter utilities."""
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ from collections.abc import Mapping, Sequence
19
+ from importlib import import_module
20
+ from typing import Any, Protocol, cast
21
+
22
+ from ..prompts._types import SupportsDataclass
23
+ from ..prompts.prompt import Prompt
24
+ from ..prompts.structured import OutputParseError
25
+ from ..prompts.structured import parse_output as parse_structured_output
26
+ from ..prompts.tool import Tool
27
+ from ..serde import dump, parse, schema
28
+ from .core import PromptEvaluationError, PromptResponse, ToolCallRecord
29
+
30
+ _ERROR_MESSAGE = (
31
+ "OpenAI support requires the optional 'openai' dependency. "
32
+ "Install it with `uv sync --extra openai` or `pip install weakincentives[openai]`."
33
+ )
34
+
35
+
36
class _CompletionFunctionCall(Protocol):
    """Shape of a function call inside a provider tool call."""

    # Function name plus the raw JSON-encoded arguments string
    # (may be None or empty when the provider sends no arguments).
    name: str
    arguments: str | None


class _ToolCall(Protocol):
    """Shape of a single tool call entry on an assistant message."""

    id: str
    function: _CompletionFunctionCall


class _Message(Protocol):
    """Shape of an assistant message in a completion choice."""

    content: str | None
    tool_calls: Sequence[_ToolCall] | None


class _CompletionChoice(Protocol):
    """Shape of one choice in a completion response."""

    message: _Message


class _CompletionResponse(Protocol):
    """Shape of the overall completion response."""

    choices: Sequence[_CompletionChoice]


class _CompletionsAPI(Protocol):
    """Subset of the SDK's completions API used by this adapter."""

    def create(self, *args: object, **kwargs: object) -> _CompletionResponse: ...


class _ChatAPI(Protocol):
    """Subset of the SDK's chat namespace used by this adapter."""

    completions: _CompletionsAPI


class _OpenAIProtocol(Protocol):
    """Structural type for the OpenAI client."""

    chat: _ChatAPI


class _OpenAIClientFactory(Protocol):
    """Callable that builds an OpenAI client from keyword arguments."""

    def __call__(self, **kwargs: object) -> _OpenAIProtocol: ...


# Public alias so callers can type-annotate clients without importing
# a private name.
OpenAIProtocol = _OpenAIProtocol


class _OpenAIModule(Protocol):
    """Structural type for the lazily imported ``openai`` module."""

    OpenAI: _OpenAIClientFactory
82
+
83
+
84
def _load_openai_module() -> _OpenAIModule:
    """Import the optional ``openai`` package.

    Raises:
        RuntimeError: With installation guidance when the package is absent.
    """
    try:
        module = import_module("openai")
    except ModuleNotFoundError as exc:
        raise RuntimeError(_ERROR_MESSAGE) from exc
    else:
        # Narrow the module to the structural type the adapter relies on.
        return cast(_OpenAIModule, module)
90
+
91
+
92
def create_openai_client(**kwargs: object) -> _OpenAIProtocol:
    """Create an OpenAI client, raising a helpful error if the extra is missing.

    Args:
        **kwargs: Forwarded verbatim to the SDK's ``OpenAI`` constructor.
    """

    return _load_openai_module().OpenAI(**kwargs)
97
+
98
+
99
class OpenAIAdapter:
    """Adapter that evaluates prompts against OpenAI's Chat Completions API.

    Renders a prompt into a system message, runs the model, executes any
    requested tool calls locally, and loops until the model produces a
    final (tool-call-free) response.
    """

    def __init__(
        self,
        *,
        model: str,
        tool_choice: str | Mapping[str, Any] | None = "auto",
        client: _OpenAIProtocol | None = None,
        client_factory: _OpenAIClientFactory | None = None,
        client_kwargs: Mapping[str, object] | None = None,
    ) -> None:
        """Configure the adapter.

        Args:
            model: Model identifier sent with every request.
            tool_choice: Forwarded as ``tool_choice`` when tools are present;
                ``None`` omits the field entirely.
            client: Pre-built client; mutually exclusive with
                ``client_factory`` and ``client_kwargs``.
            client_factory: Callable used to build a client when one is not
                supplied; defaults to ``create_openai_client``.
            client_kwargs: Keyword arguments passed to the factory.

        Raises:
            ValueError: If ``client`` is combined with ``client_factory`` or
                a non-empty ``client_kwargs``.
        """
        if client is not None:
            # An explicit client must not be mixed with factory options.
            if client_factory is not None:
                raise ValueError(
                    "client_factory cannot be provided when an explicit client is supplied.",
                )
            if client_kwargs:
                raise ValueError(
                    "client_kwargs cannot be provided when an explicit client is supplied.",
                )
        else:
            factory = client_factory or create_openai_client
            client = factory(**dict(client_kwargs or {}))

        self._client = client
        self._model = model
        self._tool_choice = tool_choice

    def evaluate[OutputT](
        self,
        prompt: Prompt[OutputT],
        *params: SupportsDataclass,
        parse_output: bool = True,
    ) -> PromptResponse[OutputT]:
        """Render *prompt*, run the tool-call loop, and return the response.

        Args:
            prompt: Prompt to render and evaluate.
            *params: Dataclass instances used to fill the prompt placeholders.
            parse_output: When True and the rendered prompt declares an output
                type, parse the final assistant text into that type.

        Raises:
            PromptEvaluationError: On request failure (stage "request"),
                malformed responses or output-parse failures (stage
                "response"), and tool lookup/parse/execution failures
                (stage "tool").
        """
        prompt_name = prompt.name or prompt.__class__.__name__
        rendered = prompt.render(*params)  # type: ignore[reportArgumentType]
        # The rendered prompt text becomes the system message.
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": rendered.text},
        ]

        tools = list(rendered.tools)
        tool_specs = [_tool_to_openai_spec(tool) for tool in tools]
        tool_registry = {tool.name: tool for tool in tools}
        tool_records: list[ToolCallRecord[Any, Any]] = []
        provider_payload: dict[str, Any] | None = None

        # NOTE(review): this loop is unbounded — a provider that keeps
        # requesting tool calls would never terminate. Consider adding a
        # max-iteration guard.
        while True:
            request_payload: dict[str, Any] = {
                "model": self._model,
                "messages": messages,
            }
            if tool_specs:
                request_payload["tools"] = tool_specs
                # tool_choice is only meaningful when tools are attached.
                if self._tool_choice is not None:
                    request_payload["tool_choice"] = self._tool_choice

            try:
                response = self._client.chat.completions.create(**request_payload)
            except Exception as error:  # pragma: no cover - network/SDK failure
                raise PromptEvaluationError(
                    "OpenAI request failed.",
                    prompt_name=prompt_name,
                    stage="request",
                ) from error

            # Capture the raw payload for diagnostics before any parsing.
            provider_payload = _extract_payload(response)
            choice = _first_choice(response, prompt_name=prompt_name)
            message = choice.message
            tool_calls = list(message.tool_calls or [])

            if not tool_calls:
                # Terminal response: no more tool work requested.
                final_text = message.content or ""
                output: OutputT | None = None
                text_value: str | None = final_text or None

                if (
                    parse_output
                    and rendered.output_type is not None
                    and rendered.output_container is not None
                ):
                    try:
                        output = parse_structured_output(final_text, rendered)
                    except OutputParseError as error:
                        raise PromptEvaluationError(
                            error.message,
                            prompt_name=prompt_name,
                            stage="response",
                            provider_payload=provider_payload,
                        ) from error
                    # Once parsed, the raw text is consumed; expose only
                    # the structured output.
                    text_value = None

                return PromptResponse(
                    prompt_name=prompt_name,
                    text=text_value,
                    output=output,
                    tool_results=tuple(tool_records),
                    provider_payload=provider_payload,
                )

            # Echo the assistant turn (with its tool calls) back into the
            # conversation before appending tool results.
            assistant_tool_calls = [_serialize_tool_call(call) for call in tool_calls]
            messages.append(
                {
                    "role": "assistant",
                    "content": message.content or "",
                    "tool_calls": assistant_tool_calls,
                }
            )

            for tool_call in tool_calls:
                function = tool_call.function
                tool_name = function.name
                tool = tool_registry.get(tool_name)
                if tool is None:
                    raise PromptEvaluationError(
                        f"Unknown tool '{tool_name}' requested by provider.",
                        prompt_name=prompt_name,
                        stage="tool",
                        provider_payload=provider_payload,
                    )
                if tool.handler is None:
                    raise PromptEvaluationError(
                        f"Tool '{tool_name}' does not have a registered handler.",
                        prompt_name=prompt_name,
                        stage="tool",
                        provider_payload=provider_payload,
                    )

                arguments_mapping = _parse_tool_arguments(
                    function.arguments,
                    prompt_name=prompt_name,
                    provider_payload=provider_payload,
                )

                # Validate arguments against the tool's params dataclass;
                # extra="forbid" rejects unexpected keys.
                try:
                    tool_params = parse(
                        tool.params_type,
                        arguments_mapping,
                        extra="forbid",
                    )
                except (TypeError, ValueError) as error:
                    raise PromptEvaluationError(
                        f"Failed to parse params for tool '{tool_name}'.",
                        prompt_name=prompt_name,
                        stage="tool",
                        provider_payload=provider_payload,
                    ) from error

                try:
                    tool_result = tool.handler(tool_params)
                except Exception as error:  # pragma: no cover - handler bug
                    raise PromptEvaluationError(
                        f"Tool '{tool_name}' raised an exception.",
                        prompt_name=prompt_name,
                        stage="tool",
                        provider_payload=provider_payload,
                    ) from error

                tool_records.append(
                    ToolCallRecord(
                        name=tool_name,
                        params=tool_params,
                        result=tool_result,
                        call_id=getattr(tool_call, "id", None),
                    )
                )

                # Feed the tool result back to the model as a JSON-encoded
                # "tool" message keyed to the originating call id.
                payload = dump(tool_result.payload, exclude_none=True)
                tool_content = {
                    "message": tool_result.message,
                    "payload": payload,
                }
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": getattr(tool_call, "id", None),
                        "content": json.dumps(tool_content),
                    }
                )
278
+
279
+
280
def _tool_to_openai_spec(tool: Tool[Any, Any]) -> dict[str, Any]:
    """Build the Chat Completions function-tool spec for *tool*."""
    parameters = schema(tool.params_type, extra="forbid")
    # The generated schema title is noise for the provider; drop it.
    parameters.pop("title", None)
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
291
+
292
+
293
+ def _extract_payload(response: _CompletionResponse) -> dict[str, Any] | None:
294
+ model_dump = getattr(response, "model_dump", None)
295
+ if callable(model_dump):
296
+ try:
297
+ payload = model_dump()
298
+ except Exception: # pragma: no cover - defensive
299
+ return None
300
+ if isinstance(payload, Mapping):
301
+ mapping_payload = cast(Mapping[str, Any], payload)
302
+ return dict(mapping_payload)
303
+ return None
304
+ if isinstance(response, Mapping): # pragma: no cover - defensive
305
+ return dict(response)
306
+ return None
307
+
308
+
309
+ def _first_choice(
310
+ response: _CompletionResponse, *, prompt_name: str
311
+ ) -> _CompletionChoice:
312
+ try:
313
+ return response.choices[0]
314
+ except (AttributeError, IndexError) as error: # pragma: no cover - defensive
315
+ raise PromptEvaluationError(
316
+ "Provider response did not include any choices.",
317
+ prompt_name=prompt_name,
318
+ stage="response",
319
+ ) from error
320
+
321
+
322
+ def _serialize_tool_call(tool_call: _ToolCall) -> dict[str, Any]:
323
+ function = tool_call.function
324
+ return {
325
+ "id": getattr(tool_call, "id", None),
326
+ "type": "function",
327
+ "function": {
328
+ "name": function.name,
329
+ "arguments": function.arguments or "{}",
330
+ },
331
+ }
332
+
333
+
334
+ def _parse_tool_arguments(
335
+ arguments_json: str | None,
336
+ *,
337
+ prompt_name: str,
338
+ provider_payload: dict[str, Any] | None,
339
+ ) -> dict[str, Any]:
340
+ if not arguments_json:
341
+ return {}
342
+ try:
343
+ parsed = json.loads(arguments_json)
344
+ except json.JSONDecodeError as error:
345
+ raise PromptEvaluationError(
346
+ "Failed to decode tool call arguments.",
347
+ prompt_name=prompt_name,
348
+ stage="tool",
349
+ provider_payload=provider_payload,
350
+ ) from error
351
+ if not isinstance(parsed, Mapping):
352
+ raise PromptEvaluationError(
353
+ "Tool call arguments must be a JSON object.",
354
+ prompt_name=prompt_name,
355
+ stage="tool",
356
+ provider_payload=provider_payload,
357
+ )
358
+ return dict(cast(Mapping[str, Any], parsed))
359
+
360
+
361
+ __all__ = ["OpenAIAdapter", "OpenAIProtocol"]
@@ -0,0 +1,45 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Prompt module scaffolding."""
14
+
15
+ from __future__ import annotations
16
+
17
+ from ._types import SupportsDataclass
18
+ from .errors import (
19
+ PromptError,
20
+ PromptRenderError,
21
+ PromptValidationError,
22
+ SectionPath,
23
+ )
24
+ from .prompt import Prompt, PromptSectionNode, RenderedPrompt
25
+ from .section import Section
26
+ from .structured import OutputParseError, parse_output
27
+ from .text import TextSection
28
+ from .tool import Tool, ToolResult
29
+
30
+ __all__ = [
31
+ "Prompt",
32
+ "RenderedPrompt",
33
+ "PromptSectionNode",
34
+ "PromptError",
35
+ "PromptRenderError",
36
+ "PromptValidationError",
37
+ "Section",
38
+ "SectionPath",
39
+ "SupportsDataclass",
40
+ "TextSection",
41
+ "Tool",
42
+ "ToolResult",
43
+ "OutputParseError",
44
+ "parse_output",
45
+ ]
@@ -0,0 +1,27 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Internal typing helpers for the prompts package."""
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, ClassVar, Protocol, runtime_checkable
18
+
19
+
20
@runtime_checkable
class SupportsDataclass(Protocol):
    """Protocol satisfied by dataclass types and instances."""

    # Dataclasses expose their fields through this class attribute; its
    # presence is what `isinstance(obj, SupportsDataclass)` checks at runtime.
    __dataclass_fields__: ClassVar[dict[str, Any]]
25
+
26
+
27
+ __all__ = ["SupportsDataclass"]
@@ -0,0 +1,57 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ from collections.abc import Sequence
16
+
17
+ SectionPath = tuple[str, ...]
18
+
19
+
20
+ def _normalize_section_path(section_path: Sequence[str] | None) -> SectionPath:
21
+ if section_path is None:
22
+ return ()
23
+ return tuple(section_path)
24
+
25
+
26
class PromptError(Exception):
    """Base class for prompt-related failures providing structured context."""

    def __init__(
        self,
        message: str,
        *,
        section_path: Sequence[str] | None = None,
        dataclass_type: type | None = None,
        placeholder: str | None = None,
    ) -> None:
        """Record the failure message and its structured context.

        Args:
            message: Human-readable description of the failure.
            section_path: Path of section names leading to the failure site;
                normalized to a tuple (empty when omitted).
            dataclass_type: Params dataclass involved, when relevant.
            placeholder: Offending placeholder name, when relevant.
        """
        super().__init__(message)
        self.message = message
        # Normalize the optional path into an immutable tuple.
        self.section_path: SectionPath = (
            () if section_path is None else tuple(section_path)
        )
        self.dataclass_type = dataclass_type
        self.placeholder = placeholder
42
+
43
+
44
class PromptValidationError(PromptError):
    """Raised when prompt construction validation fails."""


class PromptRenderError(PromptError):
    """Raised when rendering a prompt fails."""
50
+
51
+
52
+ __all__ = [
53
+ "PromptError",
54
+ "PromptValidationError",
55
+ "PromptRenderError",
56
+ "SectionPath",
57
+ ]