weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Curated public surface for :mod:`weakincentives`."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from .adapters import PromptResponse
|
|
18
|
+
from .dbc import (
|
|
19
|
+
dbc_active,
|
|
20
|
+
dbc_enabled,
|
|
21
|
+
ensure,
|
|
22
|
+
invariant,
|
|
23
|
+
pure,
|
|
24
|
+
require,
|
|
25
|
+
skip_invariant,
|
|
26
|
+
)
|
|
27
|
+
from .deadlines import Deadline
|
|
28
|
+
from .prompt import (
|
|
29
|
+
MarkdownSection,
|
|
30
|
+
Prompt,
|
|
31
|
+
SupportsDataclass,
|
|
32
|
+
Tool,
|
|
33
|
+
ToolContext,
|
|
34
|
+
ToolHandler,
|
|
35
|
+
ToolResult,
|
|
36
|
+
parse_structured_output,
|
|
37
|
+
)
|
|
38
|
+
from .runtime import StructuredLogger, configure_logging, get_logger
|
|
39
|
+
from .types import JSONValue
|
|
40
|
+
|
|
41
|
+
__all__ = [
    "Deadline",
    "JSONValue",
    "MarkdownSection",
    "Prompt",
    "PromptResponse",
    "StructuredLogger",
    "SupportsDataclass",
    "Tool",
    "ToolContext",
    "ToolHandler",
    "ToolResult",
    "configure_logging",
    "dbc_active",
    "dbc_enabled",
    "ensure",
    "get_logger",
    "invariant",
    "parse_structured_output",
    "pure",
    "require",
    "skip_invariant",
]


def __dir__() -> list[str]:
    """Return the sorted union of module globals and the curated ``__all__``.

    Module-level ``__dir__`` hook (PEP 562): ``dir()`` on the package always
    lists the curated public names alongside the current module namespace.
    """
    visible = set(globals())
    visible.update(__all__)
    return sorted(visible)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Integration adapters for optional third-party providers."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from ._names import LITELLM_ADAPTER_NAME, OPENAI_ADAPTER_NAME, AdapterName
|
|
18
|
+
from .core import (
|
|
19
|
+
PromptEvaluationError,
|
|
20
|
+
PromptResponse,
|
|
21
|
+
ProviderAdapter,
|
|
22
|
+
SessionProtocol,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
    "LITELLM_ADAPTER_NAME",
    "OPENAI_ADAPTER_NAME",
    "AdapterName",
    "PromptEvaluationError",
    "PromptResponse",
    "ProviderAdapter",
    "SessionProtocol",
]


def __dir__() -> list[str]:
    """Return the sorted union of module globals and the curated ``__all__``.

    Module-level ``__dir__`` hook (PEP 562) so ``dir()`` advertises the adapter
    package's curated public surface.
    """
    visible = set(globals())
    visible.update(__all__)
    return sorted(visible)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Semantic adapter name definitions shared across provider integrations."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Final, Literal
|
|
18
|
+
|
|
19
|
+
AdapterName = Literal["openai", "litellm"]
"""Identifiers accepted for the built-in provider adapters."""

OPENAI_ADAPTER_NAME: Final[AdapterName] = "openai"
"""Identifier used by the OpenAI adapter."""

LITELLM_ADAPTER_NAME: Final[AdapterName] = "litellm"
"""Identifier used by the LiteLLM adapter."""

__all__ = ["LITELLM_ADAPTER_NAME", "OPENAI_ADAPTER_NAME", "AdapterName"]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Structural typing primitives shared across provider adapters."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Sequence
|
|
18
|
+
from typing import Protocol
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"ProviderChoice",
|
|
22
|
+
"ProviderCompletionCallable",
|
|
23
|
+
"ProviderCompletionResponse",
|
|
24
|
+
"ProviderFunctionCall",
|
|
25
|
+
"ProviderMessage",
|
|
26
|
+
"ProviderToolCall",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ProviderFunctionCall(Protocol):
|
|
31
|
+
"""Structural Protocol describing a provider function call payload."""
|
|
32
|
+
|
|
33
|
+
name: str
|
|
34
|
+
arguments: str | None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ProviderToolCall(Protocol):
|
|
38
|
+
"""Structural Protocol describing a provider tool call payload."""
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def function(self) -> ProviderFunctionCall: ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ProviderMessage(Protocol):
|
|
45
|
+
"""Structural Protocol describing a provider message payload."""
|
|
46
|
+
|
|
47
|
+
content: str | Sequence[object] | None
|
|
48
|
+
tool_calls: Sequence[ProviderToolCall] | None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ProviderChoice(Protocol):
|
|
52
|
+
"""Structural Protocol describing a provider choice payload."""
|
|
53
|
+
|
|
54
|
+
@property
|
|
55
|
+
def message(self) -> ProviderMessage: ...
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ProviderCompletionResponse(Protocol):
|
|
59
|
+
"""Structural Protocol describing a provider completion response."""
|
|
60
|
+
|
|
61
|
+
choices: Sequence[ProviderChoice]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ProviderCompletionCallable(Protocol):
|
|
65
|
+
"""Structural Protocol describing a provider completion callable."""
|
|
66
|
+
|
|
67
|
+
def __call__(
|
|
68
|
+
self, *args: object, **kwargs: object
|
|
69
|
+
) -> ProviderCompletionResponse: ...
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Shared helpers for serialising tool results into provider messages."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from collections.abc import Mapping, Sequence
|
|
19
|
+
from dataclasses import is_dataclass
|
|
20
|
+
from typing import Final, cast
|
|
21
|
+
|
|
22
|
+
from ..prompt._types import SupportsToolResult
|
|
23
|
+
from ..prompt.tool import ToolResult
|
|
24
|
+
from ..serde import dump
|
|
25
|
+
|
|
26
|
+
type JsonPrimitive = str | int | float | bool | None
|
|
27
|
+
type JsonValue = JsonPrimitive | list["JsonValue"] | dict[str, "JsonValue"]
|
|
28
|
+
|
|
29
|
+
_UNSET: Final = object()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def serialize_tool_message(
    result: ToolResult[SupportsToolResult], *, payload: object = _UNSET
) -> str:
    """Render a tool invocation outcome as a JSON string for provider APIs.

    The JSON object always carries ``message`` and ``success``. Unless the result
    opts out via ``exclude_value_from_context``, a ``payload`` key is added when
    the chosen value is not ``None``. The value is the explicit ``payload``
    argument if one was given, otherwise ``result.value``.
    """

    body: dict[str, JsonValue] = {
        "message": result.message,
        "success": result.success,
    }
    if result.exclude_value_from_context:
        return json.dumps(body, ensure_ascii=False)

    # An explicitly supplied payload (even ``None``) overrides ``result.value``.
    candidate = payload if payload is not _UNSET else result.value
    if candidate is not None:
        body["payload"] = _serialize_value(candidate)
    return json.dumps(body, ensure_ascii=False)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _serialize_value(value: object) -> JsonValue:
    """Convert a tool result payload into a JSON-compatible structure.

    Conversion rules, in order:

    * dataclass *instances* are dumped with ``dump(..., exclude_none=True)`` and
      their fields serialised recursively;
    * mappings become dicts with ``str``-coerced keys;
    * non-string sequences become lists;
    * JSON scalars (``str``/``int``/``float``/``bool``/``None``) pass through;
    * anything else falls back to ``str(value)``.
    """

    # ``is_dataclass`` is also true for dataclass *classes*. Only instances can
    # be dumped, so a class object falls through to the ``str`` fallback below.
    if is_dataclass(value) and not isinstance(value, type):
        return _serialize_str_mapping(dump(value, exclude_none=True))

    if isinstance(value, Mapping):
        return _serialize_mapping(cast(Mapping[object, object], value))

    # Text and byte strings are Sequences too, but must not be split into items.
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        return [_serialize_value(item) for item in cast(Sequence[object], value)]

    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    return str(value)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _serialize_mapping(mapping: Mapping[object, object]) -> dict[str, JsonValue]:
    """Serialise an arbitrary mapping, coercing each key with ``str``.

    If two keys stringify identically, the later one wins.
    """
    return {str(key): _serialize_value(item) for key, item in mapping.items()}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _serialize_str_mapping(mapping: Mapping[str, object]) -> dict[str, JsonValue]:
    """Serialise a mapping whose keys are already strings, recursing into values."""
    result: dict[str, JsonValue] = {}
    for name, entry in mapping.items():
        result[name] = _serialize_value(entry)
    return result


__all__ = ["serialize_tool_message"]
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Core adapter interfaces shared across provider integrations."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar
|
|
19
|
+
|
|
20
|
+
from ..deadlines import Deadline
|
|
21
|
+
from ..prompt._types import SupportsDataclass
|
|
22
|
+
from ..runtime.session.protocols import SessionProtocol
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ..prompt.overrides import PromptOverridesStore
|
|
26
|
+
from ..prompt.prompt import Prompt
|
|
27
|
+
from ..runtime.events._types import EventBus, ToolInvoked
|
|
28
|
+
|
|
29
|
+
# Type of the structured output produced for a prompt.
OutputT = TypeVar("OutputT")


class ProviderAdapter(Protocol[OutputT]):
    """Synchronous contract implemented by every provider adapter.

    An adapter takes a prompt plus its dataclass parameters, talks to a model
    provider, and returns a :class:`PromptResponse`.
    """

    def evaluate(
        self,
        prompt: Prompt[OutputT],
        *params: SupportsDataclass,
        parse_output: bool = True,
        bus: EventBus,
        session: SessionProtocol,
        deadline: Deadline | None = None,
        overrides_store: PromptOverridesStore | None = None,
        overrides_tag: str = "latest",
    ) -> PromptResponse[OutputT]:
        """Evaluate ``prompt`` against the provider and return its response.

        Args:
            prompt: Prompt to render and evaluate.
            *params: Dataclass instances used to render the prompt.
            parse_output: Whether structured output should be parsed.
            bus: Event bus handed to the implementation.
            session: Session the evaluation runs within.
            deadline: Optional deadline bounding the evaluation.
            overrides_store: Optional store of prompt overrides used at render time.
            overrides_tag: Override tag to resolve; defaults to ``"latest"``.

        Returns:
            The structured response produced by the adapter.
        """

        ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(slots=True)
|
|
52
|
+
class PromptResponse[OutputT]:
|
|
53
|
+
"""Structured result emitted by an adapter evaluation."""
|
|
54
|
+
|
|
55
|
+
prompt_name: str
|
|
56
|
+
text: str | None
|
|
57
|
+
output: OutputT | None
|
|
58
|
+
tool_results: tuple[ToolInvoked, ...]
|
|
59
|
+
provider_payload: dict[str, Any] | None = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class PromptEvaluationError(RuntimeError):
    """Raised when evaluation against a provider fails.

    Attributes:
        message: Human-readable failure description (also ``str(error)``).
        prompt_name: Name of the prompt being evaluated.
        phase: Evaluation phase in which the failure occurred.
        provider_payload: Optional provider-specific diagnostic payload.
    """

    def __init__(
        self,
        message: str,
        *,
        prompt_name: str,
        phase: PromptEvaluationPhase,
        provider_payload: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.prompt_name = prompt_name
        self.phase: PromptEvaluationPhase = phase
        self.provider_payload = provider_payload

    def __reduce__(self) -> tuple[object, ...]:
        """Make the exception picklable and copyable.

        ``BaseException.__reduce__`` would rebuild the error as ``cls(*self.args)``.
        That call omits the required keyword-only ``prompt_name`` and ``phase``
        and raises ``TypeError`` on unpickling or ``copy.copy``. Here they are
        re-bound through ``functools.partial``. ``__dict__`` is returned as the
        state so any extra attributes (e.g. ``__notes__``) survive.
        """
        from functools import partial

        rebuild = partial(
            type(self),
            prompt_name=self.prompt_name,
            phase=self.phase,
            provider_payload=self.provider_payload,
        )
        return rebuild, (self.message,), self.__dict__
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
PromptEvaluationPhase = Literal["request", "response", "tool"]
"""Stage of prompt evaluation at which a failure was detected."""

PROMPT_EVALUATION_PHASE_REQUEST: PromptEvaluationPhase = "request"
"""The failure happened while sending the request to the provider."""

PROMPT_EVALUATION_PHASE_RESPONSE: PromptEvaluationPhase = "response"
"""The failure happened while processing the provider's response."""

PROMPT_EVALUATION_PHASE_TOOL: PromptEvaluationPhase = "tool"
"""The failure happened while handling a tool invocation."""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
__all__ = [
    # Phase identifiers.
    "PROMPT_EVALUATION_PHASE_REQUEST",
    "PROMPT_EVALUATION_PHASE_RESPONSE",
    "PROMPT_EVALUATION_PHASE_TOOL",
    # Errors, results and contracts.
    "PromptEvaluationError",
    "PromptEvaluationPhase",
    "PromptResponse",
    "ProviderAdapter",
    "SessionProtocol",
]
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Optional LiteLLM adapter utilities."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Mapping, Sequence
|
|
18
|
+
from dataclasses import replace
|
|
19
|
+
from datetime import timedelta
|
|
20
|
+
from importlib import import_module
|
|
21
|
+
from typing import TYPE_CHECKING, Any, Final, Protocol, cast
|
|
22
|
+
|
|
23
|
+
from ..deadlines import Deadline
|
|
24
|
+
from ..prompt._types import SupportsDataclass
|
|
25
|
+
from ..prompt.prompt import Prompt
|
|
26
|
+
from ..runtime.events import EventBus
|
|
27
|
+
from ..runtime.logging import StructuredLogger, get_logger
|
|
28
|
+
from . import shared as _shared
|
|
29
|
+
from ._provider_protocols import (
|
|
30
|
+
ProviderChoice,
|
|
31
|
+
ProviderCompletionCallable,
|
|
32
|
+
ProviderCompletionResponse,
|
|
33
|
+
)
|
|
34
|
+
from ._tool_messages import serialize_tool_message
|
|
35
|
+
from .core import (
|
|
36
|
+
PROMPT_EVALUATION_PHASE_REQUEST,
|
|
37
|
+
PromptEvaluationError,
|
|
38
|
+
PromptResponse,
|
|
39
|
+
SessionProtocol,
|
|
40
|
+
)
|
|
41
|
+
from .shared import (
|
|
42
|
+
LITELLM_ADAPTER_NAME,
|
|
43
|
+
ToolChoice,
|
|
44
|
+
build_json_schema_response_format,
|
|
45
|
+
deadline_provider_payload,
|
|
46
|
+
first_choice,
|
|
47
|
+
format_publish_failures,
|
|
48
|
+
parse_tool_arguments,
|
|
49
|
+
run_conversation,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
if TYPE_CHECKING:
|
|
53
|
+
from ..adapters.core import ProviderAdapter
|
|
54
|
+
from ..prompt.overrides import PromptOverridesStore
|
|
55
|
+
|
|
56
|
+
_ERROR_MESSAGE: Final[str] = (
|
|
57
|
+
"LiteLLM support requires the optional 'litellm' dependency. "
|
|
58
|
+
"Install it with `uv sync --extra litellm` or `pip install weakincentives[litellm]`."
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _LiteLLMModule(Protocol):
|
|
63
|
+
def completion(
|
|
64
|
+
self, *args: object, **kwargs: object
|
|
65
|
+
) -> ProviderCompletionResponse: ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _LiteLLMCompletionFactory(Protocol):
|
|
69
|
+
def __call__(self, **kwargs: object) -> ProviderCompletionCallable: ...
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Public alias: a LiteLLM completion is any provider completion callable.
LiteLLMCompletion = ProviderCompletionCallable
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _load_litellm_module() -> _LiteLLMModule:
    """Import and return the optional ``litellm`` module.

    Raises:
        RuntimeError: If the module cannot be found. The original
            ``ModuleNotFoundError`` is chained as ``__cause__``.
    """
    try:
        litellm_module = import_module("litellm")
    except ModuleNotFoundError as missing:  # pragma: no cover - dependency guard
        raise RuntimeError(_ERROR_MESSAGE) from missing
    else:
        return cast(_LiteLLMModule, litellm_module)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def create_litellm_completion(**kwargs: object) -> LiteLLMCompletion:
    """Return a LiteLLM completion callable, optionally pre-binding options.

    With no keyword arguments, ``litellm.completion`` is returned unchanged.
    Otherwise the returned wrapper merges ``kwargs`` into every call, and
    per-call keyword arguments take precedence over the pre-bound ones.

    Raises:
        RuntimeError: If the optional ``litellm`` dependency is not installed.
    """

    litellm_module = _load_litellm_module()
    if kwargs:

        def _wrapped_completion(
            *args: object, **request_kwargs: object
        ) -> ProviderCompletionResponse:
            # Later mapping wins: request-time options override defaults.
            return litellm_module.completion(*args, **{**kwargs, **request_kwargs})

        return _wrapped_completion

    return litellm_module.completion
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# Module logger bound with the context {"component": "adapter.litellm"}.
logger: StructuredLogger = get_logger(__name__, context={"component": "adapter.litellm"})
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class LiteLLMAdapter:
|
|
106
|
+
"""Adapter that evaluates prompts via LiteLLM's completion helper."""
|
|
107
|
+
|
|
108
|
+
def __init__(
|
|
109
|
+
self,
|
|
110
|
+
*,
|
|
111
|
+
model: str,
|
|
112
|
+
tool_choice: ToolChoice = "auto",
|
|
113
|
+
completion: LiteLLMCompletion | None = None,
|
|
114
|
+
completion_factory: _LiteLLMCompletionFactory | None = None,
|
|
115
|
+
completion_kwargs: Mapping[str, object] | None = None,
|
|
116
|
+
) -> None:
|
|
117
|
+
super().__init__()
|
|
118
|
+
if completion is not None:
|
|
119
|
+
if completion_factory is not None:
|
|
120
|
+
raise ValueError(
|
|
121
|
+
"completion_factory cannot be provided when an explicit completion is supplied.",
|
|
122
|
+
)
|
|
123
|
+
if completion_kwargs:
|
|
124
|
+
raise ValueError(
|
|
125
|
+
"completion_kwargs cannot be provided when an explicit completion is supplied.",
|
|
126
|
+
)
|
|
127
|
+
else:
|
|
128
|
+
factory = completion_factory or create_litellm_completion
|
|
129
|
+
completion = factory(**dict(completion_kwargs or {}))
|
|
130
|
+
|
|
131
|
+
self._completion = completion
|
|
132
|
+
self._model = model
|
|
133
|
+
self._tool_choice: ToolChoice = tool_choice
|
|
134
|
+
|
|
135
|
+
def evaluate[OutputT](
|
|
136
|
+
self,
|
|
137
|
+
prompt: Prompt[OutputT],
|
|
138
|
+
*params: SupportsDataclass,
|
|
139
|
+
parse_output: bool = True,
|
|
140
|
+
bus: EventBus,
|
|
141
|
+
session: SessionProtocol,
|
|
142
|
+
deadline: Deadline | None = None,
|
|
143
|
+
overrides_store: PromptOverridesStore | None = None,
|
|
144
|
+
overrides_tag: str = "latest",
|
|
145
|
+
) -> PromptResponse[OutputT]:
|
|
146
|
+
prompt_name = prompt.name or prompt.__class__.__name__
|
|
147
|
+
render_inputs: tuple[SupportsDataclass, ...] = tuple(params)
|
|
148
|
+
if deadline is not None and deadline.remaining() <= timedelta(0):
|
|
149
|
+
raise PromptEvaluationError(
|
|
150
|
+
"Deadline expired before evaluation started.",
|
|
151
|
+
prompt_name=prompt_name,
|
|
152
|
+
phase=PROMPT_EVALUATION_PHASE_REQUEST,
|
|
153
|
+
provider_payload=deadline_provider_payload(deadline),
|
|
154
|
+
)
|
|
155
|
+
has_structured_output = prompt.structured_output is not None
|
|
156
|
+
should_disable_instructions = (
|
|
157
|
+
parse_output
|
|
158
|
+
and has_structured_output
|
|
159
|
+
and getattr(prompt, "inject_output_instructions", False)
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
if should_disable_instructions:
|
|
163
|
+
rendered = prompt.render(
|
|
164
|
+
*params,
|
|
165
|
+
overrides_store=overrides_store,
|
|
166
|
+
tag=overrides_tag,
|
|
167
|
+
inject_output_instructions=False,
|
|
168
|
+
)
|
|
169
|
+
else:
|
|
170
|
+
rendered = prompt.render(
|
|
171
|
+
*params,
|
|
172
|
+
overrides_store=overrides_store,
|
|
173
|
+
tag=overrides_tag,
|
|
174
|
+
)
|
|
175
|
+
if deadline is not None:
|
|
176
|
+
rendered = replace(rendered, deadline=deadline)
|
|
177
|
+
response_format: dict[str, Any] | None = None
|
|
178
|
+
should_parse_structured_output = (
|
|
179
|
+
parse_output
|
|
180
|
+
and rendered.output_type is not None
|
|
181
|
+
and rendered.container is not None
|
|
182
|
+
)
|
|
183
|
+
if should_parse_structured_output:
|
|
184
|
+
response_format = build_json_schema_response_format(rendered, prompt_name)
|
|
185
|
+
|
|
186
|
+
def _call_provider(
|
|
187
|
+
messages: list[dict[str, Any]],
|
|
188
|
+
tool_specs: Sequence[Mapping[str, Any]],
|
|
189
|
+
tool_choice_directive: ToolChoice | None,
|
|
190
|
+
response_format_payload: Mapping[str, Any] | None,
|
|
191
|
+
) -> object:
|
|
192
|
+
request_payload: dict[str, Any] = {
|
|
193
|
+
"model": self._model,
|
|
194
|
+
"messages": messages,
|
|
195
|
+
}
|
|
196
|
+
if tool_specs:
|
|
197
|
+
request_payload["tools"] = list(tool_specs)
|
|
198
|
+
if tool_choice_directive is not None:
|
|
199
|
+
request_payload["tool_choice"] = tool_choice_directive
|
|
200
|
+
if response_format_payload is not None:
|
|
201
|
+
request_payload["response_format"] = response_format_payload
|
|
202
|
+
|
|
203
|
+
try:
|
|
204
|
+
return self._completion(**request_payload)
|
|
205
|
+
except Exception as error: # pragma: no cover - network/SDK failure
|
|
206
|
+
raise PromptEvaluationError(
|
|
207
|
+
"LiteLLM request failed.",
|
|
208
|
+
prompt_name=prompt_name,
|
|
209
|
+
phase=PROMPT_EVALUATION_PHASE_REQUEST,
|
|
210
|
+
) from error
|
|
211
|
+
|
|
212
|
+
def _select_choice(response: object) -> ProviderChoice:
|
|
213
|
+
return cast(
|
|
214
|
+
ProviderChoice,
|
|
215
|
+
first_choice(response, prompt_name=prompt_name),
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
return run_conversation(
|
|
219
|
+
adapter_name=LITELLM_ADAPTER_NAME,
|
|
220
|
+
adapter=cast("ProviderAdapter[OutputT]", self),
|
|
221
|
+
prompt=prompt,
|
|
222
|
+
prompt_name=prompt_name,
|
|
223
|
+
rendered=rendered,
|
|
224
|
+
render_inputs=render_inputs,
|
|
225
|
+
initial_messages=[{"role": "system", "content": rendered.text}],
|
|
226
|
+
parse_output=parse_output,
|
|
227
|
+
bus=bus,
|
|
228
|
+
session=session,
|
|
229
|
+
tool_choice=self._tool_choice,
|
|
230
|
+
response_format=response_format,
|
|
231
|
+
require_structured_output_text=True,
|
|
232
|
+
call_provider=_call_provider,
|
|
233
|
+
select_choice=_select_choice,
|
|
234
|
+
serialize_tool_message_fn=serialize_tool_message,
|
|
235
|
+
format_publish_failures=format_publish_failures,
|
|
236
|
+
parse_arguments=parse_tool_arguments,
|
|
237
|
+
logger_override=logger,
|
|
238
|
+
deadline=deadline,
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Re-export the shared response-parsing helpers from this module's namespace.
message_text_content = _shared.message_text_content
extract_parsed_content = _shared.extract_parsed_content
parse_schema_constrained_payload = _shared.parse_schema_constrained_payload


__all__ = [
    "LiteLLMAdapter",
    "LiteLLMCompletion",
    "create_litellm_completion",
    "extract_parsed_content",
    "message_text_content",
    "parse_schema_constrained_payload",
]
|