weakincentives 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of weakincentives might be problematic. Click here for more details.
- weakincentives/__init__.py +26 -2
- weakincentives/adapters/__init__.py +6 -5
- weakincentives/adapters/core.py +7 -17
- weakincentives/adapters/litellm.py +594 -0
- weakincentives/adapters/openai.py +286 -57
- weakincentives/events.py +103 -0
- weakincentives/examples/__init__.py +67 -0
- weakincentives/examples/code_review_prompt.py +118 -0
- weakincentives/examples/code_review_session.py +171 -0
- weakincentives/examples/code_review_tools.py +376 -0
- weakincentives/{prompts → prompt}/__init__.py +6 -8
- weakincentives/{prompts → prompt}/_types.py +1 -1
- weakincentives/{prompts/text.py → prompt/markdown.py} +19 -9
- weakincentives/{prompts → prompt}/prompt.py +216 -66
- weakincentives/{prompts → prompt}/response_format.py +9 -6
- weakincentives/{prompts → prompt}/section.py +25 -4
- weakincentives/{prompts/structured.py → prompt/structured_output.py} +16 -5
- weakincentives/{prompts → prompt}/tool.py +6 -6
- weakincentives/prompt/versioning.py +144 -0
- weakincentives/serde/__init__.py +0 -14
- weakincentives/serde/dataclass_serde.py +3 -17
- weakincentives/session/__init__.py +31 -0
- weakincentives/session/reducers.py +60 -0
- weakincentives/session/selectors.py +45 -0
- weakincentives/session/session.py +168 -0
- weakincentives/tools/__init__.py +69 -0
- weakincentives/tools/errors.py +22 -0
- weakincentives/tools/planning.py +538 -0
- weakincentives/tools/vfs.py +590 -0
- weakincentives-0.3.0.dist-info/METADATA +231 -0
- weakincentives-0.3.0.dist-info/RECORD +35 -0
- weakincentives-0.2.0.dist-info/METADATA +0 -173
- weakincentives-0.2.0.dist-info/RECORD +0 -20
- /weakincentives/{prompts → prompt}/errors.py +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,20 +16,20 @@ import inspect
|
|
|
16
16
|
import re
|
|
17
17
|
from collections.abc import Callable
|
|
18
18
|
from dataclasses import dataclass, field, is_dataclass
|
|
19
|
-
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
|
19
|
+
from typing import Annotated, Any, Final, cast, get_args, get_origin, get_type_hints
|
|
20
20
|
|
|
21
21
|
from ._types import SupportsDataclass
|
|
22
22
|
from .errors import PromptValidationError
|
|
23
23
|
|
|
24
|
-
_NAME_PATTERN = re.compile(r"^[a-z0-9_]{1,64}$")
|
|
24
|
+
_NAME_PATTERN: Final[re.Pattern[str]] = re.compile(r"^[a-z0-9_-]{1,64}$")
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
@dataclass(slots=True)
|
|
28
|
-
class ToolResult[
|
|
28
|
+
class ToolResult[ResultValueT]:
|
|
29
29
|
"""Structured response emitted by a tool handler."""
|
|
30
30
|
|
|
31
31
|
message: str
|
|
32
|
-
|
|
32
|
+
value: ResultValueT
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
@dataclass(slots=True)
|
|
@@ -91,13 +91,13 @@ class Tool[ParamsT: SupportsDataclass, ResultT: SupportsDataclass]:
|
|
|
91
91
|
name_clean = raw_name
|
|
92
92
|
if not name_clean:
|
|
93
93
|
raise PromptValidationError(
|
|
94
|
-
"Tool name must
|
|
94
|
+
"Tool name must match the OpenAI function name constraints (1-64 lowercase ASCII letters, digits, underscores, or hyphens).",
|
|
95
95
|
dataclass_type=params_type,
|
|
96
96
|
placeholder=stripped_name,
|
|
97
97
|
)
|
|
98
98
|
if len(name_clean) > 64 or not _NAME_PATTERN.fullmatch(name_clean):
|
|
99
99
|
raise PromptValidationError(
|
|
100
|
-
"Tool name must
|
|
100
|
+
"Tool name must match the OpenAI function name constraints (pattern: ^[a-z0-9_-]{1,64}$).",
|
|
101
101
|
dataclass_type=params_type,
|
|
102
102
|
placeholder=name_clean,
|
|
103
103
|
)
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from hashlib import sha256
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Protocol
|
|
19
|
+
|
|
20
|
+
from ..serde.dataclass_serde import schema
|
|
21
|
+
from .tool import Tool
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _tool_override_mapping_factory() -> dict[str, ToolOverride]:
|
|
25
|
+
return {}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _param_description_mapping_factory() -> dict[str, str]:
|
|
29
|
+
return {}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from .prompt import Prompt
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(slots=True)
|
|
37
|
+
class SectionDescriptor:
|
|
38
|
+
"""Hash metadata for a single section within a prompt."""
|
|
39
|
+
|
|
40
|
+
path: tuple[str, ...]
|
|
41
|
+
content_hash: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass(slots=True)
|
|
45
|
+
class ToolDescriptor:
|
|
46
|
+
"""Stable metadata describing a tool exposed by a prompt."""
|
|
47
|
+
|
|
48
|
+
path: tuple[str, ...]
|
|
49
|
+
name: str
|
|
50
|
+
contract_hash: str
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass(slots=True)
|
|
54
|
+
class PromptDescriptor:
|
|
55
|
+
"""Stable metadata describing a prompt and its hash-aware sections."""
|
|
56
|
+
|
|
57
|
+
ns: str
|
|
58
|
+
key: str
|
|
59
|
+
sections: list[SectionDescriptor]
|
|
60
|
+
tools: list[ToolDescriptor]
|
|
61
|
+
|
|
62
|
+
@classmethod
|
|
63
|
+
def from_prompt(cls, prompt: Prompt[Any]) -> PromptDescriptor:
|
|
64
|
+
sections: list[SectionDescriptor] = []
|
|
65
|
+
tools: list[ToolDescriptor] = []
|
|
66
|
+
for node in prompt.sections:
|
|
67
|
+
template = node.section.original_body_template()
|
|
68
|
+
if template is not None:
|
|
69
|
+
content_hash = sha256(template.encode("utf-8")).hexdigest()
|
|
70
|
+
sections.append(SectionDescriptor(node.path, content_hash))
|
|
71
|
+
for tool in node.section.tools():
|
|
72
|
+
tools.append(
|
|
73
|
+
ToolDescriptor(
|
|
74
|
+
path=node.path,
|
|
75
|
+
name=tool.name,
|
|
76
|
+
contract_hash=_tool_contract_hash(tool),
|
|
77
|
+
)
|
|
78
|
+
)
|
|
79
|
+
return cls(prompt.ns, prompt.key, sections, tools)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass(slots=True)
class ToolOverride:
    """Replacement descriptions for one tool, tied to the contract they target.

    The override records the contract hash it was written for. It is
    validated against a tool contract hash before use.
    """

    # Name of the tool being overridden.
    name: str
    # Contract hash this override was authored against; presumably compared
    # with ``ToolDescriptor.contract_hash`` so stale overrides are rejected.
    expected_contract_hash: str
    # Replacement tool description; ``None`` presumably means "leave as is".
    description: str | None = None
    # Replacement parameter descriptions, presumably keyed by parameter name.
    param_descriptions: dict[str, str] = field(
        default_factory=_param_description_mapping_factory
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass(slots=True)
class PromptOverride:
    """Section and tool replacements for one prompt, as served by a version store.

    This is the non-``None`` result type of ``PromptVersionStore.resolve``.
    """

    # Namespace of the prompt the override applies to.
    ns: str
    # Key of the prompt the override applies to.
    prompt_key: str
    # Tag the override was resolved under (``resolve`` defaults to "latest").
    tag: str
    # Replacement text per section; keys presumably match
    # ``SectionDescriptor.path`` — TODO confirm against the render path.
    overrides: dict[tuple[str, ...], str]
    # Per-tool description overrides; presumably keyed by tool name — verify.
    tool_overrides: dict[str, ToolOverride] = field(
        default_factory=_tool_override_mapping_factory
    )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class PromptVersionStore(Protocol):
    """Structural interface for anything that can supply prompt overrides.

    Implementations look up an override for the prompt described by a
    ``PromptDescriptor`` under a tag and return ``None`` when nothing applies.
    """

    def resolve(
        self,
        descriptor: PromptDescriptor,
        tag: str = "latest",
    ) -> PromptOverride | None:
        """Return the override stored for ``descriptor`` under ``tag``, if any."""
        ...
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Public names of the prompt-versioning module. The hashing helpers
# (``hash_text``, ``hash_json``) are defined here but not listed, so star
# imports do not pick them up.
__all__ = [
    "PromptDescriptor",
    "PromptOverride",
    "PromptVersionStore",
    "SectionDescriptor",
    "ToolDescriptor",
    "ToolOverride",
]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _tool_contract_hash(tool: Tool[Any, Any]) -> str:
    """Fingerprint the externally visible contract of ``tool``.

    The contract has three parts, each hashed separately:

    - the description;
    - the parameter schema (generated with ``extra="forbid"``);
    - the result schema (``extra="ignore"``).

    The three digests are joined with ``"::"`` and hashed once more.
    ``tool.name`` does not contribute to the fingerprint.
    """
    parts = (
        hash_text(tool.description),
        hash_json(schema(tool.params_type, extra="forbid")),
        hash_json(schema(tool.result_type, extra="ignore")),
    )
    return hash_text("::".join(parts))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def hash_text(value: str) -> str:
    """Return the hex SHA-256 digest of ``value`` encoded as UTF-8."""
    hasher = sha256()
    hasher.update(value.encode("utf-8"))
    return hasher.hexdigest()
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def hash_json(value: object) -> str:
    """Hash ``value`` after serialising it to canonical JSON.

    Keys are sorted, separators carry no whitespace and non-ASCII characters
    are escaped. As a result, equal JSON-compatible values hash identically
    regardless of dict insertion order.
    """
    serialized = json.dumps(
        value,
        ensure_ascii=True,
        separators=(",", ":"),
        sort_keys=True,
    )
    return hash_text(serialized)
|
weakincentives/serde/__init__.py
CHANGED
|
@@ -10,20 +10,6 @@
|
|
|
10
10
|
# See the License for the specific language governing permissions and
|
|
11
11
|
# limitations under the License.
|
|
12
12
|
|
|
13
|
-
# Copyright 2025 weak incentives
|
|
14
|
-
#
|
|
15
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
16
|
-
# you may not use this file except in compliance with the License.
|
|
17
|
-
# You may obtain a copy of the License at
|
|
18
|
-
#
|
|
19
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
20
|
-
#
|
|
21
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
22
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
23
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
24
|
-
# See the License for the specific language governing permissions and
|
|
25
|
-
# limitations under the License.
|
|
26
|
-
|
|
27
13
|
"""Stdlib dataclass serde utilities."""
|
|
28
14
|
|
|
29
15
|
from .dataclass_serde import clone, dump, parse, schema
|
|
@@ -10,20 +10,6 @@
|
|
|
10
10
|
# See the License for the specific language governing permissions and
|
|
11
11
|
# limitations under the License.
|
|
12
12
|
|
|
13
|
-
# Copyright 2025 weak incentives
|
|
14
|
-
#
|
|
15
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
16
|
-
# you may not use this file except in compliance with the License.
|
|
17
|
-
# You may obtain a copy of the License at
|
|
18
|
-
#
|
|
19
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
20
|
-
#
|
|
21
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
22
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
23
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
24
|
-
# See the License for the specific language governing permissions and
|
|
25
|
-
# limitations under the License.
|
|
26
|
-
|
|
27
13
|
from __future__ import annotations
|
|
28
14
|
|
|
29
15
|
import dataclasses
|
|
@@ -36,10 +22,10 @@ from enum import Enum
|
|
|
36
22
|
from pathlib import Path
|
|
37
23
|
from re import Pattern
|
|
38
24
|
from typing import Any as _AnyType
|
|
39
|
-
from typing import Literal, Union, cast, get_args, get_origin, get_type_hints
|
|
25
|
+
from typing import Final, Literal, Union, cast, get_args, get_origin, get_type_hints
|
|
40
26
|
from uuid import UUID
|
|
41
27
|
|
|
42
|
-
MISSING_SENTINEL: object = object()
|
|
28
|
+
MISSING_SENTINEL: Final[object] = object()
|
|
43
29
|
|
|
44
30
|
|
|
45
31
|
class _ExtrasDescriptor:
|
|
@@ -63,7 +49,7 @@ class _ExtrasDescriptor:
|
|
|
63
49
|
self._store[key] = dict(value)
|
|
64
50
|
|
|
65
51
|
|
|
66
|
-
_SLOTTED_EXTRAS: dict[type[object], _ExtrasDescriptor] = {}
|
|
52
|
+
_SLOTTED_EXTRAS: Final[dict[type[object], _ExtrasDescriptor]] = {}
|
|
67
53
|
|
|
68
54
|
|
|
69
55
|
def _ordered_values(values: Iterable[object]) -> list[object]:
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Session state container for agent runs."""
|
|
14
|
+
|
|
15
|
+
from .reducers import append, replace_latest, upsert_by
|
|
16
|
+
from .selectors import select_all, select_latest, select_where
|
|
17
|
+
from .session import DataEvent, PromptData, Session, ToolData, TypedReducer
|
|
18
|
+
|
|
19
|
+
# Public surface of the session package: the container, its event wrappers
# and reducer type, plus the built-in reducers and selectors.
__all__ = [
    "Session",
    "DataEvent",
    "ToolData",
    "PromptData",
    "TypedReducer",
    "append",
    "upsert_by",
    "replace_latest",
    "select_all",
    "select_latest",
    "select_where",
]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Built-in reducers for Session state slices."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from typing import cast
|
|
19
|
+
|
|
20
|
+
from .session import DataEvent, TypedReducer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def append[T](slice_values: tuple[T, ...], event: DataEvent) -> tuple[T, ...]:
|
|
24
|
+
"""Append the event value if it is not already present."""
|
|
25
|
+
|
|
26
|
+
value = cast(T, event.value)
|
|
27
|
+
if value in slice_values:
|
|
28
|
+
return slice_values
|
|
29
|
+
return slice_values + (value,)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def upsert_by[T, K](key_fn: Callable[[T], K]) -> TypedReducer[T]:
|
|
33
|
+
"""Return a reducer that upserts items sharing the same derived key."""
|
|
34
|
+
|
|
35
|
+
def reducer(slice_values: tuple[T, ...], event: DataEvent) -> tuple[T, ...]:
|
|
36
|
+
value = cast(T, event.value)
|
|
37
|
+
key = key_fn(value)
|
|
38
|
+
updated: list[T] = []
|
|
39
|
+
replaced = False
|
|
40
|
+
for existing in slice_values:
|
|
41
|
+
if key_fn(existing) == key:
|
|
42
|
+
if not replaced:
|
|
43
|
+
updated.append(value)
|
|
44
|
+
replaced = True
|
|
45
|
+
continue
|
|
46
|
+
updated.append(existing)
|
|
47
|
+
if not replaced:
|
|
48
|
+
updated.append(value)
|
|
49
|
+
return tuple(updated)
|
|
50
|
+
|
|
51
|
+
return reducer
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def replace_latest[T](slice_values: tuple[T, ...], event: DataEvent) -> tuple[T, ...]:
|
|
55
|
+
"""Keep only the most recent event value."""
|
|
56
|
+
|
|
57
|
+
return (cast(T, event.value),)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Built-in reducers exported by this module.
__all__ = ["append", "upsert_by", "replace_latest"]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Helpers for querying Session slices."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
|
|
19
|
+
from .session import Session
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def select_all[T](session: Session, slice_type: type[T]) -> tuple[T, ...]:
|
|
23
|
+
"""Return the entire slice for the provided type."""
|
|
24
|
+
|
|
25
|
+
return session.select_all(slice_type)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def select_latest[T](session: Session, slice_type: type[T]) -> T | None:
|
|
29
|
+
"""Return the most recent item in the slice, if any."""
|
|
30
|
+
|
|
31
|
+
values = session.select_all(slice_type)
|
|
32
|
+
if not values:
|
|
33
|
+
return None
|
|
34
|
+
return values[-1]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def select_where[T](
|
|
38
|
+
session: Session, slice_type: type[T], predicate: Callable[[T], bool]
|
|
39
|
+
) -> tuple[T, ...]:
|
|
40
|
+
"""Return items that satisfy the predicate."""
|
|
41
|
+
|
|
42
|
+
return tuple(value for value in session.select_all(slice_type) if predicate(value))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Functional query helpers over Session slices.
__all__ = ["select_all", "select_latest", "select_where"]
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Session state container synchronized with the event bus."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Callable, Iterable
|
|
19
|
+
from dataclasses import dataclass, is_dataclass
|
|
20
|
+
from typing import Any, cast
|
|
21
|
+
|
|
22
|
+
from ..events import EventBus, NullEventBus, PromptExecuted, ToolInvoked
|
|
23
|
+
from ..prompt._types import SupportsDataclass
|
|
24
|
+
|
|
25
|
+
# Module logger; Session._dispatch_data_event reports failing reducers here.
logger: logging.Logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(slots=True, frozen=True)
|
|
29
|
+
class ToolData[T: SupportsDataclass]:
|
|
30
|
+
"""Wrapper containing tool payloads and their originating event."""
|
|
31
|
+
|
|
32
|
+
value: T
|
|
33
|
+
source: ToolInvoked
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(slots=True, frozen=True)
|
|
37
|
+
class PromptData[T: SupportsDataclass]:
|
|
38
|
+
"""Wrapper containing prompt outputs and their originating event."""
|
|
39
|
+
|
|
40
|
+
value: T
|
|
41
|
+
source: PromptExecuted
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Wrapped payload handed to reducers: either a tool result or a prompt output.
type DataEvent = ToolData[SupportsDataclass] | PromptData[SupportsDataclass]

# Reducer signature: (current slice tuple, incoming event) -> new slice tuple.
type TypedReducer[S] = Callable[[tuple[S, ...], DataEvent], tuple[S, ...]]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass(slots=True)
|
|
50
|
+
class _ReducerRegistration:
|
|
51
|
+
reducer: TypedReducer[Any]
|
|
52
|
+
slice_type: type[Any]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Session:
|
|
56
|
+
"""Collect dataclass payloads from prompt executions and tool invocations."""
|
|
57
|
+
|
|
58
|
+
def __init__(
|
|
59
|
+
self,
|
|
60
|
+
*,
|
|
61
|
+
bus: EventBus | None = None,
|
|
62
|
+
session_id: str | None = None,
|
|
63
|
+
created_at: str | None = None,
|
|
64
|
+
) -> None:
|
|
65
|
+
self.session_id = session_id
|
|
66
|
+
self.created_at = created_at
|
|
67
|
+
self._bus = bus or NullEventBus()
|
|
68
|
+
self._reducers: dict[type[SupportsDataclass], list[_ReducerRegistration]] = {}
|
|
69
|
+
self._state: dict[type[Any], tuple[Any, ...]] = {}
|
|
70
|
+
|
|
71
|
+
if bus is not None:
|
|
72
|
+
bus.subscribe(ToolInvoked, self._on_tool_invoked)
|
|
73
|
+
bus.subscribe(PromptExecuted, self._on_prompt_executed)
|
|
74
|
+
|
|
75
|
+
def register_reducer[S](
|
|
76
|
+
self,
|
|
77
|
+
data_type: type[SupportsDataclass],
|
|
78
|
+
reducer: TypedReducer[S],
|
|
79
|
+
*,
|
|
80
|
+
slice_type: type[S] | None = None,
|
|
81
|
+
) -> None:
|
|
82
|
+
"""Register a reducer for the provided data type."""
|
|
83
|
+
|
|
84
|
+
target_slice_type: type[Any] = data_type if slice_type is None else slice_type
|
|
85
|
+
registration = _ReducerRegistration(
|
|
86
|
+
reducer=cast(TypedReducer[Any], reducer),
|
|
87
|
+
slice_type=target_slice_type,
|
|
88
|
+
)
|
|
89
|
+
bucket = self._reducers.setdefault(data_type, [])
|
|
90
|
+
bucket.append(registration)
|
|
91
|
+
self._state.setdefault(target_slice_type, ())
|
|
92
|
+
|
|
93
|
+
def select_all[S](self, slice_type: type[S]) -> tuple[S, ...]:
|
|
94
|
+
"""Return the tuple slice maintained for the provided type."""
|
|
95
|
+
|
|
96
|
+
return cast(tuple[S, ...], self._state.get(slice_type, ()))
|
|
97
|
+
|
|
98
|
+
def _on_tool_invoked(self, event: object) -> None:
|
|
99
|
+
if isinstance(event, ToolInvoked):
|
|
100
|
+
self._handle_tool_invoked(event)
|
|
101
|
+
|
|
102
|
+
def _on_prompt_executed(self, event: object) -> None:
|
|
103
|
+
if isinstance(event, PromptExecuted):
|
|
104
|
+
self._handle_prompt_executed(event)
|
|
105
|
+
|
|
106
|
+
def _handle_tool_invoked(self, event: ToolInvoked) -> None:
|
|
107
|
+
payload = event.result.value
|
|
108
|
+
if not _is_dataclass_instance(payload):
|
|
109
|
+
return
|
|
110
|
+
dataclass_payload = cast(SupportsDataclass, payload)
|
|
111
|
+
data = ToolData(value=dataclass_payload, source=event)
|
|
112
|
+
self._dispatch_data_event(type(dataclass_payload), data)
|
|
113
|
+
|
|
114
|
+
def _handle_prompt_executed(self, event: PromptExecuted) -> None:
|
|
115
|
+
output = event.result.output
|
|
116
|
+
if _is_dataclass_instance(output):
|
|
117
|
+
dataclass_output = cast(SupportsDataclass, output)
|
|
118
|
+
data = PromptData(value=dataclass_output, source=event)
|
|
119
|
+
self._dispatch_data_event(type(dataclass_output), data)
|
|
120
|
+
return
|
|
121
|
+
if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
|
|
122
|
+
for item in cast(Iterable[object], output):
|
|
123
|
+
if _is_dataclass_instance(item):
|
|
124
|
+
dataclass_item = cast(SupportsDataclass, item)
|
|
125
|
+
data = PromptData(value=dataclass_item, source=event)
|
|
126
|
+
self._dispatch_data_event(type(dataclass_item), data)
|
|
127
|
+
|
|
128
|
+
def _dispatch_data_event(
|
|
129
|
+
self, data_type: type[SupportsDataclass], event: DataEvent
|
|
130
|
+
) -> None:
|
|
131
|
+
registrations = self._reducers.get(data_type)
|
|
132
|
+
if not registrations:
|
|
133
|
+
from .reducers import append
|
|
134
|
+
|
|
135
|
+
registrations = [
|
|
136
|
+
_ReducerRegistration(
|
|
137
|
+
reducer=cast(TypedReducer[Any], append),
|
|
138
|
+
slice_type=data_type,
|
|
139
|
+
)
|
|
140
|
+
]
|
|
141
|
+
|
|
142
|
+
for registration in registrations:
|
|
143
|
+
slice_type = registration.slice_type
|
|
144
|
+
previous = self._state.get(slice_type, ())
|
|
145
|
+
try:
|
|
146
|
+
result = registration.reducer(previous, event)
|
|
147
|
+
except Exception: # noqa: BLE001
|
|
148
|
+
logger.exception(
|
|
149
|
+
"Reducer %r failed for data type %s",
|
|
150
|
+
registration.reducer,
|
|
151
|
+
data_type,
|
|
152
|
+
)
|
|
153
|
+
continue
|
|
154
|
+
normalized = tuple(result)
|
|
155
|
+
self._state[slice_type] = normalized
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _is_dataclass_instance(value: object) -> bool:
|
|
159
|
+
return is_dataclass(value) and not isinstance(value, type)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# Public names of the session module: the container, its payload wrappers
# and the reducer-related type aliases.
__all__ = [
    "Session",
    "ToolData",
    "PromptData",
    "DataEvent",
    "TypedReducer",
]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Public surface for built-in tool suites."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from .errors import ToolValidationError
|
|
18
|
+
from .planning import (
|
|
19
|
+
AddStep,
|
|
20
|
+
ClearPlan,
|
|
21
|
+
MarkStep,
|
|
22
|
+
NewPlanStep,
|
|
23
|
+
Plan,
|
|
24
|
+
PlanningToolsSection,
|
|
25
|
+
PlanStatus,
|
|
26
|
+
PlanStep,
|
|
27
|
+
ReadPlan,
|
|
28
|
+
SetupPlan,
|
|
29
|
+
StepStatus,
|
|
30
|
+
UpdateStep,
|
|
31
|
+
)
|
|
32
|
+
from .vfs import (
|
|
33
|
+
DeleteEntry,
|
|
34
|
+
HostMount,
|
|
35
|
+
ListDirectory,
|
|
36
|
+
ListDirectoryResult,
|
|
37
|
+
ReadFile,
|
|
38
|
+
VfsFile,
|
|
39
|
+
VfsPath,
|
|
40
|
+
VfsToolsSection,
|
|
41
|
+
VirtualFileSystem,
|
|
42
|
+
WriteFile,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# Public surface: the shared validation error, the planning tool suite and
# the virtual-file-system tool suite.
__all__ = [
    "ToolValidationError",
    "Plan",
    "PlanStep",
    "PlanStatus",
    "StepStatus",
    "NewPlanStep",
    "SetupPlan",
    "AddStep",
    "UpdateStep",
    "MarkStep",
    "ClearPlan",
    "ReadPlan",
    "PlanningToolsSection",
    "VirtualFileSystem",
    "VfsFile",
    "VfsPath",
    "HostMount",
    "ListDirectory",
    "ListDirectoryResult",
    "ReadFile",
    "WriteFile",
    "DeleteEntry",
    "VfsToolsSection",
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Shared error types for built-in tools."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ToolValidationError(ValueError):
    """Signal that parameters given to a built-in tool were rejected.

    Subclasses ``ValueError``, so generic ``except ValueError`` handlers
    still catch it.
    """


# The module exports only its error type.
__all__ = ["ToolValidationError"]
|