weakincentives 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of weakincentives might be problematic. Click here for more details.
- weakincentives/__init__.py +26 -2
- weakincentives/adapters/__init__.py +6 -5
- weakincentives/adapters/core.py +7 -17
- weakincentives/adapters/litellm.py +594 -0
- weakincentives/adapters/openai.py +286 -57
- weakincentives/events.py +103 -0
- weakincentives/examples/__init__.py +67 -0
- weakincentives/examples/code_review_prompt.py +118 -0
- weakincentives/examples/code_review_session.py +171 -0
- weakincentives/examples/code_review_tools.py +376 -0
- weakincentives/{prompts → prompt}/__init__.py +6 -8
- weakincentives/{prompts → prompt}/_types.py +1 -1
- weakincentives/{prompts/text.py → prompt/markdown.py} +19 -9
- weakincentives/{prompts → prompt}/prompt.py +216 -66
- weakincentives/{prompts → prompt}/response_format.py +9 -6
- weakincentives/{prompts → prompt}/section.py +25 -4
- weakincentives/{prompts/structured.py → prompt/structured_output.py} +16 -5
- weakincentives/{prompts → prompt}/tool.py +6 -6
- weakincentives/prompt/versioning.py +144 -0
- weakincentives/serde/__init__.py +0 -14
- weakincentives/serde/dataclass_serde.py +3 -17
- weakincentives/session/__init__.py +31 -0
- weakincentives/session/reducers.py +60 -0
- weakincentives/session/selectors.py +45 -0
- weakincentives/session/session.py +168 -0
- weakincentives/tools/__init__.py +69 -0
- weakincentives/tools/errors.py +22 -0
- weakincentives/tools/planning.py +538 -0
- weakincentives/tools/vfs.py +590 -0
- weakincentives-0.3.0.dist-info/METADATA +231 -0
- weakincentives-0.3.0.dist-info/RECORD +35 -0
- weakincentives-0.2.0.dist-info/METADATA +0 -173
- weakincentives-0.2.0.dist-info/RECORD +0 -20
- /weakincentives/{prompts → prompt}/errors.py +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,594 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Optional LiteLLM adapter utilities."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import re
|
|
19
|
+
from collections.abc import Mapping, Sequence
|
|
20
|
+
from importlib import import_module
|
|
21
|
+
from typing import TYPE_CHECKING, Any, Final, Literal, Protocol, cast
|
|
22
|
+
|
|
23
|
+
from ..events import EventBus, PromptExecuted, ToolInvoked
|
|
24
|
+
from ..prompt._types import SupportsDataclass
|
|
25
|
+
from ..prompt.prompt import Prompt, RenderedPrompt
|
|
26
|
+
from ..prompt.structured_output import (
|
|
27
|
+
ARRAY_WRAPPER_KEY,
|
|
28
|
+
OutputParseError,
|
|
29
|
+
parse_structured_output,
|
|
30
|
+
)
|
|
31
|
+
from ..prompt.tool import Tool, ToolResult
|
|
32
|
+
from ..serde import parse, schema
|
|
33
|
+
from .core import PromptEvaluationError, PromptResponse
|
|
34
|
+
|
|
35
|
+
# Error text shown whenever the optional LiteLLM dependency is missing; shared
# by the eager import probe below and the lazy loader `_load_litellm_module`.
_ERROR_MESSAGE: Final[str] = (
    "LiteLLM support requires the optional 'litellm' dependency. "
    "Install it with `uv sync --extra litellm` or `pip install weakincentives[litellm]`."
)

# Probe the optional dependency once at import time so tooling can see it;
# absence is tolerated here and surfaced later by `_load_litellm_module`.
try:  # pragma: no cover - optional dependency import for tooling
    import litellm as _optional_litellm  # type: ignore[import]
except ModuleNotFoundError:  # pragma: no cover - handled lazily in loader
    _optional_litellm = None  # type: ignore[assignment]

if TYPE_CHECKING:  # pragma: no cover - optional dependency for typing only
    import litellm

    # Touch the module so static analyzers do not flag the import as unused.
    _ = litellm.__name__
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class _CompletionFunctionCall(Protocol):
    """Structural view of the ``function`` field on a provider tool call."""

    name: str
    arguments: str | None


class _ToolCall(Protocol):
    """Structural view of a single provider tool call."""

    id: str
    function: _CompletionFunctionCall


class _Message(Protocol):
    """Structural view of a chat completion message."""

    content: str | Sequence[object] | None
    tool_calls: Sequence[_ToolCall] | None


class _CompletionChoice(Protocol):
    """Structural view of one completion choice."""

    message: _Message


class _CompletionResponse(Protocol):
    """Structural view of a completion response."""

    choices: Sequence[_CompletionChoice]


class _CompletionCallable(Protocol):
    """Callable signature matching ``litellm.completion``."""

    def __call__(self, *args: object, **kwargs: object) -> _CompletionResponse: ...


class _LiteLLMModule(Protocol):
    """Subset of the ``litellm`` module surface this adapter relies on."""

    def completion(self, *args: object, **kwargs: object) -> _CompletionResponse: ...


class _LiteLLMCompletionFactory(Protocol):
    """Factory that builds a completion callable from default keyword args."""

    def __call__(self, **kwargs: object) -> _CompletionCallable: ...


# Public alias for the structural completion-callable type.
LiteLLMCompletion = _CompletionCallable
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _load_litellm_module() -> _LiteLLMModule:
    """Import and return the ``litellm`` module, failing with install guidance."""
    try:
        loaded = import_module("litellm")
    except ModuleNotFoundError as exc:  # pragma: no cover - dependency guard
        raise RuntimeError(_ERROR_MESSAGE) from exc
    else:
        return cast(_LiteLLMModule, loaded)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def create_litellm_completion(**kwargs: object) -> LiteLLMCompletion:
    """Return a LiteLLM completion callable, guarding the optional dependency."""

    module = _load_litellm_module()
    # Without defaults there is nothing to wrap; hand back the raw callable.
    if not kwargs:
        return module.completion

    defaults: dict[str, object] = dict(kwargs)

    def _wrapped_completion(
        *args: object, **request_kwargs: object
    ) -> _CompletionResponse:
        # Per-request kwargs win over the baked-in defaults.
        call_kwargs = {**defaults, **request_kwargs}
        return module.completion(*args, **call_kwargs)

    return _wrapped_completion
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Either the provider-managed "auto" mode, an explicit provider-specific
# mapping (e.g. a forced function call), or None to omit the directive.
ToolChoice = Literal["auto"] | Mapping[str, Any] | None
"""Supported tool choice directives for provider APIs."""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class LiteLLMAdapter:
    """Adapter that evaluates prompts via LiteLLM's completion helper."""

    def __init__(
        self,
        *,
        model: str,
        tool_choice: ToolChoice = "auto",
        completion: LiteLLMCompletion | None = None,
        completion_factory: _LiteLLMCompletionFactory | None = None,
        completion_kwargs: Mapping[str, object] | None = None,
    ) -> None:
        """Configure the adapter.

        Args:
            model: Model identifier forwarded to LiteLLM on every request.
            tool_choice: Initial tool-choice directive sent alongside tools.
            completion: Explicit completion callable; mutually exclusive with
                ``completion_factory`` and ``completion_kwargs``.
            completion_factory: Factory used to build the callable when no
                explicit ``completion`` is supplied.
            completion_kwargs: Default keyword args baked into the
                factory-built callable.

        Raises:
            ValueError: If ``completion`` is combined with a factory or kwargs.
        """
        if completion is not None:
            # An explicit callable already encodes its own defaults, so the
            # factory-based construction knobs must not be mixed in.
            if completion_factory is not None:
                raise ValueError(
                    "completion_factory cannot be provided when an explicit completion is supplied.",
                )
            if completion_kwargs:
                raise ValueError(
                    "completion_kwargs cannot be provided when an explicit completion is supplied.",
                )
        else:
            # Build the callable lazily so the optional dependency is only
            # required when no explicit completion was injected.
            factory = completion_factory or create_litellm_completion
            completion = factory(**dict(completion_kwargs or {}))

        self._completion = completion
        self._model = model
        self._tool_choice: ToolChoice = tool_choice

    def evaluate[OutputT](
        self,
        prompt: Prompt[OutputT],
        *params: SupportsDataclass,
        parse_output: bool = True,
        bus: EventBus,
    ) -> PromptResponse[OutputT]:
        """Render *prompt*, drive the completion/tool loop, return a response.

        Publishes a ``ToolInvoked`` event per executed tool call and one
        ``PromptExecuted`` event when the final response is produced.

        Raises:
            PromptEvaluationError: On request failure, missing or unparseable
                structured output, or tool lookup/parse/execution failure.
        """
        prompt_name = prompt.name or prompt.__class__.__name__
        # Peeks at private prompt attributes to detect a declared structured
        # output before rendering.
        has_structured_output = (
            getattr(prompt, "_output_type", None) is not None
            and getattr(prompt, "_output_container", None) is not None
        )
        # When the provider enforces a JSON schema, in-prompt output
        # instructions are redundant and get suppressed at render time.
        should_disable_instructions = (
            parse_output
            and has_structured_output
            and getattr(prompt, "inject_output_instructions", False)
        )

        if should_disable_instructions:
            rendered = prompt.render(
                *params,
                inject_output_instructions=False,
            )  # type: ignore[reportArgumentType]
        else:
            rendered = prompt.render(*params)  # type: ignore[reportArgumentType]
        # The rendered prompt text is sent as the system message.
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": rendered.text},
        ]

        should_parse_structured_output = (
            parse_output
            and rendered.output_type is not None
            and rendered.container is not None
        )
        response_format: dict[str, Any] | None = None
        if should_parse_structured_output:
            response_format = _build_json_schema_response_format(rendered, prompt_name)

        tools = list(rendered.tools)
        tool_specs = [_tool_to_litellm_spec(tool) for tool in tools]
        tool_registry = {tool.name: tool for tool in tools}
        tool_events: list[ToolInvoked] = []
        provider_payload: dict[str, Any] | None = None
        next_tool_choice: ToolChoice = self._tool_choice

        # Completion loop: keep calling the provider and executing requested
        # tools until a message without tool calls arrives.
        while True:
            request_payload: dict[str, Any] = {
                "model": self._model,
                "messages": messages,
            }
            if tool_specs:
                request_payload["tools"] = tool_specs
                # A tool_choice directive is only meaningful alongside tools.
                if next_tool_choice is not None:
                    request_payload["tool_choice"] = next_tool_choice
            if response_format is not None:
                request_payload["response_format"] = response_format

            try:
                response = self._completion(**request_payload)
            except Exception as error:  # pragma: no cover - network/SDK failure
                raise PromptEvaluationError(
                    "LiteLLM request failed.",
                    prompt_name=prompt_name,
                    phase="request",
                ) from error

            provider_payload = _extract_payload(response)
            choice = _first_choice(response, prompt_name=prompt_name)
            message = choice.message
            tool_calls = list(message.tool_calls or [])

            if not tool_calls:
                # Terminal turn: no tools requested, build the final result.
                final_text = _message_text_content(message.content)
                output: OutputT | None = None
                text_value: str | None = final_text or None

                if should_parse_structured_output:
                    # Prefer an SDK pre-parsed payload; fall back to parsing
                    # the raw text; fail if neither is available.
                    parsed_payload = _extract_parsed_content(message)
                    if parsed_payload is not None:
                        try:
                            output = cast(
                                OutputT,
                                _parse_schema_constrained_payload(
                                    parsed_payload, rendered
                                ),
                            )
                        except TypeError as error:
                            raise PromptEvaluationError(
                                str(error),
                                prompt_name=prompt_name,
                                phase="response",
                                provider_payload=provider_payload,
                            ) from error
                    elif final_text:
                        try:
                            output = parse_structured_output(final_text, rendered)
                        except OutputParseError as error:
                            raise PromptEvaluationError(
                                error.message,
                                prompt_name=prompt_name,
                                phase="response",
                                provider_payload=provider_payload,
                            ) from error
                    else:
                        raise PromptEvaluationError(
                            "Provider response did not include structured output.",
                            prompt_name=prompt_name,
                            phase="response",
                            provider_payload=provider_payload,
                        )
                    # Structured output supersedes the textual payload.
                    text_value = None

                # NOTE: rebinds `response` from the provider object to the
                # adapter-level result.
                response = PromptResponse(
                    prompt_name=prompt_name,
                    text=text_value,
                    output=output,
                    tool_results=tuple(tool_events),
                    provider_payload=provider_payload,
                )
                bus.publish(
                    PromptExecuted(
                        prompt_name=prompt_name,
                        adapter="litellm",
                        result=cast(PromptResponse[object], response),
                    )
                )
                return response

            # Echo the assistant's tool calls back into the transcript so the
            # follow-up request carries the full conversation state.
            assistant_tool_calls = [_serialize_tool_call(call) for call in tool_calls]
            messages.append(
                {
                    "role": "assistant",
                    "content": message.content or "",
                    "tool_calls": assistant_tool_calls,
                }
            )

            for tool_call in tool_calls:
                function = tool_call.function
                tool_name = function.name
                tool = tool_registry.get(tool_name)
                if tool is None:
                    raise PromptEvaluationError(
                        f"Unknown tool '{tool_name}' requested by provider.",
                        prompt_name=prompt_name,
                        phase="tool",
                        provider_payload=provider_payload,
                    )
                if tool.handler is None:
                    raise PromptEvaluationError(
                        f"Tool '{tool_name}' does not have a registered handler.",
                        prompt_name=prompt_name,
                        phase="tool",
                        provider_payload=provider_payload,
                    )

                arguments_mapping = _parse_tool_arguments(
                    function.arguments,
                    prompt_name=prompt_name,
                    provider_payload=provider_payload,
                )

                # Validate arguments against the tool's declared params type;
                # unknown keys are rejected ("forbid").
                try:
                    tool_params = parse(
                        tool.params_type,
                        arguments_mapping,
                        extra="forbid",
                    )
                except (TypeError, ValueError) as error:
                    raise PromptEvaluationError(
                        f"Failed to parse params for tool '{tool_name}'.",
                        prompt_name=prompt_name,
                        phase="tool",
                        provider_payload=provider_payload,
                    ) from error

                try:
                    tool_result = tool.handler(tool_params)
                except Exception as error:  # pragma: no cover - handler bug
                    raise PromptEvaluationError(
                        f"Tool '{tool_name}' raised an exception.",
                        prompt_name=prompt_name,
                        phase="tool",
                        provider_payload=provider_payload,
                    ) from error

                invocation = ToolInvoked(
                    prompt_name=prompt_name,
                    adapter="litellm",
                    name=tool_name,
                    params=tool_params,
                    result=cast(ToolResult[object], tool_result),
                    call_id=getattr(tool_call, "id", None),
                )
                tool_events.append(invocation)
                bus.publish(invocation)

                # Feed the tool result back so the model can continue.
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": getattr(tool_call, "id", None),
                        "content": tool_result.message,
                    }
                )

            # After honoring a forced function call once, relax the directive
            # so the model is not forced into the same tool on every turn.
            if isinstance(next_tool_choice, Mapping):
                tool_choice_mapping = cast(Mapping[str, object], next_tool_choice)
                if tool_choice_mapping.get("type") == "function":
                    next_tool_choice = "auto"
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _build_json_schema_response_format(
    rendered: RenderedPrompt[Any], prompt_name: str
) -> dict[str, Any] | None:
    """Build a ``response_format`` payload describing the prompt's output schema.

    Returns ``None`` when the rendered prompt declares no structured output.
    """
    output_type = rendered.output_type
    container = rendered.container
    if output_type is None or container is None:
        return None

    allow_extra_keys = rendered.allow_extra_keys
    extra_mode: Literal["ignore", "forbid"] = "ignore" if allow_extra_keys else "forbid"
    item_schema = schema(output_type, extra=extra_mode)
    # Providers do not need the dataclass title in the schema.
    item_schema.pop("title", None)

    if container == "array":
        # Wrap the array in an object: schema-constrained responses require a
        # top-level JSON object.
        wrapper: dict[str, Any] = {
            "type": "object",
            "properties": {
                ARRAY_WRAPPER_KEY: {
                    "type": "array",
                    "items": item_schema,
                }
            },
            "required": [ARRAY_WRAPPER_KEY],
        }
        if not allow_extra_keys:
            wrapper["additionalProperties"] = False
        schema_payload = wrapper
    else:
        schema_payload = item_schema

    return {
        "type": "json_schema",
        "json_schema": {
            "name": _schema_name(prompt_name),
            "schema": schema_payload,
        },
    }
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _schema_name(prompt_name: str) -> str:
|
|
402
|
+
sanitized = re.sub(r"[^a-zA-Z0-9_-]+", "_", prompt_name.strip())
|
|
403
|
+
cleaned = sanitized.strip("_") or "prompt"
|
|
404
|
+
return f"{cleaned}_schema"
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _tool_to_litellm_spec(tool: Tool[Any, Any]) -> dict[str, Any]:
    """Translate a prompt tool into LiteLLM's function-tool specification."""
    params_schema = schema(tool.params_type, extra="forbid")
    # The dataclass title is noise in the provider-facing schema.
    params_schema.pop("title", None)
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": params_schema,
    }
    return {"type": "function", "function": function_spec}
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def _extract_payload(response: _CompletionResponse) -> dict[str, Any] | None:
|
|
421
|
+
model_dump = getattr(response, "model_dump", None)
|
|
422
|
+
if callable(model_dump):
|
|
423
|
+
try:
|
|
424
|
+
payload = model_dump()
|
|
425
|
+
except Exception: # pragma: no cover - defensive
|
|
426
|
+
return None
|
|
427
|
+
if isinstance(payload, Mapping):
|
|
428
|
+
mapping_payload = cast(Mapping[str, Any], payload)
|
|
429
|
+
return dict(mapping_payload)
|
|
430
|
+
return None
|
|
431
|
+
if isinstance(response, Mapping): # pragma: no cover - defensive
|
|
432
|
+
return dict(response)
|
|
433
|
+
return None
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _first_choice(
|
|
437
|
+
response: _CompletionResponse, *, prompt_name: str
|
|
438
|
+
) -> _CompletionChoice:
|
|
439
|
+
try:
|
|
440
|
+
return response.choices[0]
|
|
441
|
+
except (AttributeError, IndexError) as error: # pragma: no cover - defensive
|
|
442
|
+
raise PromptEvaluationError(
|
|
443
|
+
"Provider response did not include any choices.",
|
|
444
|
+
prompt_name=prompt_name,
|
|
445
|
+
phase="response",
|
|
446
|
+
) from error
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def _serialize_tool_call(tool_call: _ToolCall) -> dict[str, Any]:
|
|
450
|
+
function = tool_call.function
|
|
451
|
+
return {
|
|
452
|
+
"id": getattr(tool_call, "id", None),
|
|
453
|
+
"type": "function",
|
|
454
|
+
"function": {
|
|
455
|
+
"name": function.name,
|
|
456
|
+
"arguments": function.arguments or "{}",
|
|
457
|
+
},
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def _parse_tool_arguments(
|
|
462
|
+
arguments_json: str | None,
|
|
463
|
+
*,
|
|
464
|
+
prompt_name: str,
|
|
465
|
+
provider_payload: dict[str, Any] | None,
|
|
466
|
+
) -> dict[str, Any]:
|
|
467
|
+
if not arguments_json:
|
|
468
|
+
return {}
|
|
469
|
+
try:
|
|
470
|
+
parsed = json.loads(arguments_json)
|
|
471
|
+
except json.JSONDecodeError as error:
|
|
472
|
+
raise PromptEvaluationError(
|
|
473
|
+
"Failed to decode tool call arguments.",
|
|
474
|
+
prompt_name=prompt_name,
|
|
475
|
+
phase="tool",
|
|
476
|
+
provider_payload=provider_payload,
|
|
477
|
+
) from error
|
|
478
|
+
if not isinstance(parsed, Mapping):
|
|
479
|
+
raise PromptEvaluationError(
|
|
480
|
+
"Tool call arguments must be a JSON object.",
|
|
481
|
+
prompt_name=prompt_name,
|
|
482
|
+
phase="tool",
|
|
483
|
+
provider_payload=provider_payload,
|
|
484
|
+
)
|
|
485
|
+
return dict(cast(Mapping[str, Any], parsed))
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def _message_text_content(content: object) -> str:
|
|
489
|
+
if isinstance(content, str) or content is None:
|
|
490
|
+
return content or ""
|
|
491
|
+
if isinstance(content, Sequence) and not isinstance(
|
|
492
|
+
content, (str, bytes, bytearray)
|
|
493
|
+
):
|
|
494
|
+
fragments: list[str] = []
|
|
495
|
+
sequence_content = cast(Sequence[object], content) # pyright: ignore[reportUnnecessaryCast]
|
|
496
|
+
for part in sequence_content:
|
|
497
|
+
fragments.append(_content_part_text(part))
|
|
498
|
+
return "".join(fragments)
|
|
499
|
+
return str(content)
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def _content_part_text(part: object) -> str:
|
|
503
|
+
if part is None:
|
|
504
|
+
return ""
|
|
505
|
+
if isinstance(part, Mapping):
|
|
506
|
+
mapping_part = cast(Mapping[str, object], part)
|
|
507
|
+
part_type = mapping_part.get("type")
|
|
508
|
+
if part_type in {"output_text", "text"}:
|
|
509
|
+
text_value = mapping_part.get("text")
|
|
510
|
+
if isinstance(text_value, str):
|
|
511
|
+
return text_value
|
|
512
|
+
return ""
|
|
513
|
+
part_type = getattr(part, "type", None)
|
|
514
|
+
if part_type in {"output_text", "text"}:
|
|
515
|
+
text_value = getattr(part, "text", None)
|
|
516
|
+
if isinstance(text_value, str):
|
|
517
|
+
return text_value
|
|
518
|
+
return ""
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def _extract_parsed_content(message: _Message) -> object | None:
|
|
522
|
+
parsed = getattr(message, "parsed", None)
|
|
523
|
+
if parsed is not None:
|
|
524
|
+
return parsed
|
|
525
|
+
|
|
526
|
+
content = message.content
|
|
527
|
+
if isinstance(content, Sequence) and not isinstance(
|
|
528
|
+
content, (str, bytes, bytearray)
|
|
529
|
+
):
|
|
530
|
+
sequence_content = cast(Sequence[object], content) # pyright: ignore[reportUnnecessaryCast]
|
|
531
|
+
for part in sequence_content:
|
|
532
|
+
payload = _parsed_payload_from_part(part)
|
|
533
|
+
if payload is not None:
|
|
534
|
+
return payload
|
|
535
|
+
return None
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def _parsed_payload_from_part(part: object) -> object | None:
|
|
539
|
+
if isinstance(part, Mapping):
|
|
540
|
+
mapping_part = cast(Mapping[str, object], part)
|
|
541
|
+
if mapping_part.get("type") == "output_json":
|
|
542
|
+
return mapping_part.get("json")
|
|
543
|
+
return None
|
|
544
|
+
part_type = getattr(part, "type", None)
|
|
545
|
+
if part_type == "output_json":
|
|
546
|
+
return getattr(part, "json", None)
|
|
547
|
+
return None
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def _parse_schema_constrained_payload(
    payload: object, rendered: RenderedPrompt[Any]
) -> object:
    """Parse a schema-constrained provider payload into the declared dataclass.

    Raises:
        TypeError: If the prompt declares no structured output or the payload
            shape does not match the declared container.
    """
    dataclass_type = rendered.output_type
    container = rendered.container
    if dataclass_type is None or container is None:
        raise TypeError("Prompt does not declare structured output.")

    extra_mode: Literal["ignore", "forbid"] = (
        "ignore" if rendered.allow_extra_keys else "forbid"
    )

    if container == "object":
        if not isinstance(payload, Mapping):
            raise TypeError("Expected provider payload to be a JSON object.")
        return parse(
            dataclass_type, cast(Mapping[str, object], payload), extra=extra_mode
        )

    if container == "array":
        # Unwrap the object wrapper that schema-constrained responses use.
        if isinstance(payload, Mapping):
            if ARRAY_WRAPPER_KEY not in payload:
                raise TypeError("Expected provider payload to be a JSON array.")
            payload = cast(Mapping[str, object], payload)[ARRAY_WRAPPER_KEY]
        is_item_sequence = isinstance(payload, Sequence) and not isinstance(
            payload, (str, bytes, bytearray)
        )
        if not is_item_sequence:
            raise TypeError("Expected provider payload to be a JSON array.")
        results: list[object] = []
        for index, item in enumerate(cast(Sequence[object], payload)):
            if not isinstance(item, Mapping):
                raise TypeError(f"Array item at index {index} is not an object.")
            results.append(
                parse(
                    dataclass_type,
                    cast(Mapping[str, object], item),
                    extra=extra_mode,
                )
            )
        return results

    raise TypeError("Unknown output container declared.")
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
# Public surface of this module.
__all__ = ["LiteLLMAdapter", "LiteLLMCompletion", "create_litellm_completion"]
|