weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Built-in reducers for Session state slices."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from typing import cast
|
|
19
|
+
|
|
20
|
+
from ...dbc import pure
|
|
21
|
+
from ...prompt._types import SupportsDataclass
|
|
22
|
+
from ._types import ReducerContextProtocol, ReducerEvent, TypedReducer
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pure
|
|
26
|
+
def append[T: SupportsDataclass](
|
|
27
|
+
slice_values: tuple[T, ...],
|
|
28
|
+
event: ReducerEvent,
|
|
29
|
+
*,
|
|
30
|
+
context: ReducerContextProtocol,
|
|
31
|
+
) -> tuple[T, ...]:
|
|
32
|
+
"""Append the event value if it is not already present."""
|
|
33
|
+
|
|
34
|
+
del context
|
|
35
|
+
value = cast(T, event.value)
|
|
36
|
+
if value in slice_values:
|
|
37
|
+
return slice_values
|
|
38
|
+
return (*slice_values, value)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def upsert_by[T: SupportsDataclass, K](key_fn: Callable[[T], K]) -> TypedReducer[T]:
|
|
42
|
+
"""Return a reducer that upserts items sharing the same derived key."""
|
|
43
|
+
|
|
44
|
+
def reducer(
|
|
45
|
+
slice_values: tuple[T, ...],
|
|
46
|
+
event: ReducerEvent,
|
|
47
|
+
*,
|
|
48
|
+
context: ReducerContextProtocol,
|
|
49
|
+
) -> tuple[T, ...]:
|
|
50
|
+
del context
|
|
51
|
+
value = cast(T, event.value)
|
|
52
|
+
key = key_fn(value)
|
|
53
|
+
updated: list[T] = []
|
|
54
|
+
replaced = False
|
|
55
|
+
for existing in slice_values:
|
|
56
|
+
if key_fn(existing) == key:
|
|
57
|
+
if not replaced:
|
|
58
|
+
updated.append(value)
|
|
59
|
+
replaced = True
|
|
60
|
+
continue
|
|
61
|
+
updated.append(existing)
|
|
62
|
+
if not replaced:
|
|
63
|
+
updated.append(value)
|
|
64
|
+
return tuple(updated)
|
|
65
|
+
|
|
66
|
+
return pure(reducer)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pure
|
|
70
|
+
def replace_latest[T: SupportsDataclass](
|
|
71
|
+
slice_values: tuple[T, ...],
|
|
72
|
+
event: ReducerEvent,
|
|
73
|
+
*,
|
|
74
|
+
context: ReducerContextProtocol,
|
|
75
|
+
) -> tuple[T, ...]:
|
|
76
|
+
"""Keep only the most recent event value."""
|
|
77
|
+
|
|
78
|
+
del context
|
|
79
|
+
return (cast(T, event.value),)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Public reducers exported by this module.
__all__ = ["append", "replace_latest", "upsert_by"]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Helpers for querying Session slices."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
|
|
19
|
+
from ...dbc import pure
|
|
20
|
+
from ...prompt._types import SupportsDataclass
|
|
21
|
+
from .session import Session
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pure
|
|
25
|
+
def select_all[T: SupportsDataclass](
|
|
26
|
+
session: Session, slice_type: type[T]
|
|
27
|
+
) -> tuple[T, ...]:
|
|
28
|
+
"""Return the entire slice for the provided type."""
|
|
29
|
+
|
|
30
|
+
return session.select_all(slice_type)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pure
|
|
34
|
+
def select_latest[T: SupportsDataclass](
|
|
35
|
+
session: Session, slice_type: type[T]
|
|
36
|
+
) -> T | None:
|
|
37
|
+
"""Return the most recent item in the slice, if any."""
|
|
38
|
+
|
|
39
|
+
values = session.select_all(slice_type)
|
|
40
|
+
if not values:
|
|
41
|
+
return None
|
|
42
|
+
return values[-1]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@pure
|
|
46
|
+
def select_where[T: SupportsDataclass](
|
|
47
|
+
session: Session,
|
|
48
|
+
slice_type: type[T],
|
|
49
|
+
predicate: Callable[[T], bool],
|
|
50
|
+
) -> tuple[T, ...]:
|
|
51
|
+
"""Return items that satisfy the predicate."""
|
|
52
|
+
|
|
53
|
+
return tuple(value for value in session.select_all(slice_type) if predicate(value))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Public selector helpers exported by this module.
__all__ = ["select_all", "select_latest", "select_where"]
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Session state container synchronized with the event bus."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Iterable
|
|
18
|
+
from dataclasses import dataclass, replace
|
|
19
|
+
from datetime import UTC, datetime
|
|
20
|
+
from threading import RLock
|
|
21
|
+
from typing import Any, cast, override
|
|
22
|
+
from uuid import UUID, uuid4
|
|
23
|
+
|
|
24
|
+
from ...dbc import invariant
|
|
25
|
+
from ...prompt._types import SupportsDataclass
|
|
26
|
+
from ..events import EventBus, PromptExecuted, PromptRendered, ToolInvoked
|
|
27
|
+
from ..logging import StructuredLogger, get_logger
|
|
28
|
+
from ._slice_types import SessionSlice, SessionSliceType
|
|
29
|
+
from ._types import ReducerContextProtocol, ReducerEvent, TypedReducer
|
|
30
|
+
from .dataclasses import is_dataclass_instance
|
|
31
|
+
from .protocols import SessionProtocol, SnapshotProtocol
|
|
32
|
+
from .reducers import append
|
|
33
|
+
from .snapshots import (
|
|
34
|
+
Snapshot,
|
|
35
|
+
SnapshotRestoreError,
|
|
36
|
+
SnapshotSerializationError,
|
|
37
|
+
SnapshotState,
|
|
38
|
+
normalize_snapshot_state,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
logger: StructuredLogger = get_logger(__name__, context={"component": "session"})
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
type DataEvent = PromptExecuted | PromptRendered | ToolInvoked
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
_PROMPT_RENDERED_TYPE: type[SupportsDataclass] = cast(
|
|
48
|
+
type[SupportsDataclass], PromptRendered
|
|
49
|
+
)
|
|
50
|
+
_TOOL_INVOKED_TYPE: type[SupportsDataclass] = cast(type[SupportsDataclass], ToolInvoked)
|
|
51
|
+
_PROMPT_EXECUTED_TYPE: type[SupportsDataclass] = cast(
|
|
52
|
+
type[SupportsDataclass], PromptExecuted
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
EMPTY_SLICE: SessionSlice = ()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _append_event(
    slice_values: tuple[SupportsDataclass, ...],
    event: ReducerEvent,
    *,
    context: ReducerContextProtocol,
) -> tuple[SupportsDataclass, ...]:
    """Reducer that records the event object itself at the end of the slice.

    Unlike ``append``, the whole event (not its ``value``) is stored and no
    de-duplication is performed. ``context`` is unused.
    """
    del context
    return (*slice_values, cast("SupportsDataclass", event))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass(slots=True)
class _ReducerRegistration:
    """Pair a reducer with the slice type its results are written into."""

    # Called as reducer(slice_values, event, context=...) during dispatch.
    reducer: TypedReducer[Any]
    # Key in the session state that receives the reducer's returned tuple.
    slice_type: SessionSliceType
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _session_id_is_well_formed(session: "Session") -> bool: # noqa: UP037
|
|
76
|
+
return len(session.session_id.bytes) == 16
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _created_at_has_tz(session: "Session") -> bool: # noqa: UP037
|
|
80
|
+
return session.created_at.tzinfo is not None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _created_at_is_utc(session: "Session") -> bool: # noqa: UP037
|
|
84
|
+
return session.created_at.tzinfo == UTC
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@invariant(
|
|
88
|
+
_session_id_is_well_formed,
|
|
89
|
+
_created_at_has_tz,
|
|
90
|
+
_created_at_is_utc,
|
|
91
|
+
)
|
|
92
|
+
class Session(SessionProtocol):
|
|
93
|
+
"""Collect dataclass payloads from prompt executions and tool invocations."""
|
|
94
|
+
|
|
95
|
+
def __init__(
|
|
96
|
+
self,
|
|
97
|
+
*,
|
|
98
|
+
bus: EventBus,
|
|
99
|
+
session_id: UUID | None = None,
|
|
100
|
+
created_at: datetime | None = None,
|
|
101
|
+
) -> None:
|
|
102
|
+
super().__init__()
|
|
103
|
+
resolved_session_id = session_id if session_id is not None else uuid4()
|
|
104
|
+
resolved_created_at = (
|
|
105
|
+
created_at if created_at is not None else datetime.now(UTC)
|
|
106
|
+
)
|
|
107
|
+
if resolved_created_at.tzinfo is None:
|
|
108
|
+
msg = "Session created_at must be timezone-aware."
|
|
109
|
+
raise ValueError(msg)
|
|
110
|
+
|
|
111
|
+
self.session_id: UUID = resolved_session_id
|
|
112
|
+
self.created_at: datetime = resolved_created_at.astimezone(UTC)
|
|
113
|
+
self._bus: EventBus = bus
|
|
114
|
+
self._reducers: dict[SessionSliceType, list[_ReducerRegistration]] = {}
|
|
115
|
+
self._state: dict[SessionSliceType, SessionSlice] = {}
|
|
116
|
+
self._lock = RLock()
|
|
117
|
+
self._subscriptions_attached = False
|
|
118
|
+
self._attach_to_bus(bus)
|
|
119
|
+
|
|
120
|
+
def clone(
|
|
121
|
+
self,
|
|
122
|
+
*,
|
|
123
|
+
bus: EventBus,
|
|
124
|
+
session_id: UUID | None = None,
|
|
125
|
+
created_at: datetime | None = None,
|
|
126
|
+
) -> Session:
|
|
127
|
+
"""Return a new session that mirrors the current state and reducers."""
|
|
128
|
+
|
|
129
|
+
with self._lock:
|
|
130
|
+
reducer_snapshot = [
|
|
131
|
+
(data_type, tuple(registrations))
|
|
132
|
+
for data_type, registrations in self._reducers.items()
|
|
133
|
+
]
|
|
134
|
+
state_snapshot = dict(self._state)
|
|
135
|
+
|
|
136
|
+
clone = Session(
|
|
137
|
+
bus=bus,
|
|
138
|
+
session_id=session_id if session_id is not None else self.session_id,
|
|
139
|
+
created_at=created_at if created_at is not None else self.created_at,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
for data_type, registrations in reducer_snapshot:
|
|
143
|
+
for registration in registrations:
|
|
144
|
+
clone.register_reducer(
|
|
145
|
+
data_type,
|
|
146
|
+
registration.reducer,
|
|
147
|
+
slice_type=registration.slice_type,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
with clone._lock:
|
|
151
|
+
clone._state = state_snapshot
|
|
152
|
+
|
|
153
|
+
return clone
|
|
154
|
+
|
|
155
|
+
def register_reducer[S: SupportsDataclass](
|
|
156
|
+
self,
|
|
157
|
+
data_type: SessionSliceType,
|
|
158
|
+
reducer: TypedReducer[S],
|
|
159
|
+
*,
|
|
160
|
+
slice_type: type[S] | None = None,
|
|
161
|
+
) -> None:
|
|
162
|
+
"""Register a reducer for the provided data type."""
|
|
163
|
+
|
|
164
|
+
target_slice_type: SessionSliceType = (
|
|
165
|
+
data_type if slice_type is None else slice_type
|
|
166
|
+
)
|
|
167
|
+
registration = _ReducerRegistration(
|
|
168
|
+
reducer=cast(TypedReducer[Any], reducer),
|
|
169
|
+
slice_type=target_slice_type,
|
|
170
|
+
)
|
|
171
|
+
with self._lock:
|
|
172
|
+
bucket = self._reducers.setdefault(data_type, [])
|
|
173
|
+
bucket.append(registration)
|
|
174
|
+
_ = self._state.setdefault(target_slice_type, EMPTY_SLICE)
|
|
175
|
+
|
|
176
|
+
def select_all[S: SupportsDataclass](self, slice_type: type[S]) -> tuple[S, ...]:
|
|
177
|
+
"""Return the tuple slice maintained for the provided type."""
|
|
178
|
+
|
|
179
|
+
with self._lock:
|
|
180
|
+
return cast(tuple[S, ...], self._state.get(slice_type, EMPTY_SLICE))
|
|
181
|
+
|
|
182
|
+
def seed_slice[S: SupportsDataclass](
|
|
183
|
+
self, slice_type: type[S], values: Iterable[S]
|
|
184
|
+
) -> None:
|
|
185
|
+
"""Initialize or replace the stored tuple for the provided type."""
|
|
186
|
+
|
|
187
|
+
with self._lock:
|
|
188
|
+
self._state[slice_type] = tuple(values)
|
|
189
|
+
|
|
190
|
+
@override
|
|
191
|
+
def reset(self) -> None:
|
|
192
|
+
"""Clear all stored slices while preserving reducer registrations."""
|
|
193
|
+
|
|
194
|
+
with self._lock:
|
|
195
|
+
slice_types: set[SessionSliceType] = set(self._state)
|
|
196
|
+
for registrations in self._reducers.values():
|
|
197
|
+
for registration in registrations:
|
|
198
|
+
slice_types.add(registration.slice_type)
|
|
199
|
+
|
|
200
|
+
self._state = dict.fromkeys(slice_types, EMPTY_SLICE)
|
|
201
|
+
|
|
202
|
+
@property
|
|
203
|
+
def event_bus(self) -> EventBus:
|
|
204
|
+
"""Return the event bus backing this session."""
|
|
205
|
+
|
|
206
|
+
return self._bus
|
|
207
|
+
|
|
208
|
+
@override
|
|
209
|
+
def snapshot(self) -> SnapshotProtocol:
|
|
210
|
+
"""Capture an immutable snapshot of the current session state."""
|
|
211
|
+
|
|
212
|
+
with self._lock:
|
|
213
|
+
state_snapshot: dict[SessionSliceType, SessionSlice] = dict(self._state)
|
|
214
|
+
for ephemeral in (
|
|
215
|
+
_TOOL_INVOKED_TYPE,
|
|
216
|
+
_PROMPT_EXECUTED_TYPE,
|
|
217
|
+
_PROMPT_RENDERED_TYPE,
|
|
218
|
+
):
|
|
219
|
+
_ = state_snapshot.pop(ephemeral, None)
|
|
220
|
+
try:
|
|
221
|
+
normalized: SnapshotState = normalize_snapshot_state(state_snapshot)
|
|
222
|
+
except ValueError as error:
|
|
223
|
+
msg = "Unable to serialize session slices"
|
|
224
|
+
raise SnapshotSerializationError(msg) from error
|
|
225
|
+
|
|
226
|
+
created_at = datetime.now(UTC)
|
|
227
|
+
return Snapshot(created_at=created_at, slices=normalized)
|
|
228
|
+
|
|
229
|
+
@override
|
|
230
|
+
def rollback(self, snapshot: SnapshotProtocol) -> None:
|
|
231
|
+
"""Restore session slices from the provided snapshot."""
|
|
232
|
+
|
|
233
|
+
registered_slices = self._registered_slice_types()
|
|
234
|
+
missing = [
|
|
235
|
+
slice_type
|
|
236
|
+
for slice_type in snapshot.slices
|
|
237
|
+
if slice_type not in registered_slices
|
|
238
|
+
]
|
|
239
|
+
if missing:
|
|
240
|
+
missing_names = ", ".join(sorted(cls.__qualname__ for cls in missing))
|
|
241
|
+
msg = f"Slice types not registered: {missing_names}"
|
|
242
|
+
raise SnapshotRestoreError(msg)
|
|
243
|
+
|
|
244
|
+
with self._lock:
|
|
245
|
+
new_state: dict[SessionSliceType, SessionSlice] = dict(self._state)
|
|
246
|
+
for slice_type in registered_slices:
|
|
247
|
+
new_state[slice_type] = snapshot.slices.get(slice_type, EMPTY_SLICE)
|
|
248
|
+
|
|
249
|
+
self._state = new_state
|
|
250
|
+
|
|
251
|
+
def _registered_slice_types(self) -> set[SessionSliceType]:
|
|
252
|
+
with self._lock:
|
|
253
|
+
types: set[SessionSliceType] = set(self._state)
|
|
254
|
+
for registrations in self._reducers.values():
|
|
255
|
+
for registration in registrations:
|
|
256
|
+
types.add(registration.slice_type)
|
|
257
|
+
return types
|
|
258
|
+
|
|
259
|
+
def _on_tool_invoked(self, event: object) -> None:
|
|
260
|
+
tool_event = cast(ToolInvoked, event)
|
|
261
|
+
self._handle_tool_invoked(tool_event)
|
|
262
|
+
|
|
263
|
+
def _on_prompt_executed(self, event: object) -> None:
|
|
264
|
+
prompt_event = cast(PromptExecuted, event)
|
|
265
|
+
self._handle_prompt_executed(prompt_event)
|
|
266
|
+
|
|
267
|
+
def _on_prompt_rendered(self, event: object) -> None:
|
|
268
|
+
start_event = cast(PromptRendered, event)
|
|
269
|
+
self._handle_prompt_rendered(start_event)
|
|
270
|
+
|
|
271
|
+
def _handle_tool_invoked(self, event: ToolInvoked) -> None:
|
|
272
|
+
normalized_event = event
|
|
273
|
+
payload = event.value if event.value is not None else event.result.value
|
|
274
|
+
if event.value is None and is_dataclass_instance(payload):
|
|
275
|
+
normalized_event = replace(event, value=payload)
|
|
276
|
+
|
|
277
|
+
self._dispatch_data_event(
|
|
278
|
+
_TOOL_INVOKED_TYPE,
|
|
279
|
+
cast(ReducerEvent, normalized_event),
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
if normalized_event.value is not None:
|
|
283
|
+
self._dispatch_data_event(
|
|
284
|
+
type(normalized_event.value),
|
|
285
|
+
cast(ReducerEvent, normalized_event),
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
def _handle_prompt_executed(self, event: PromptExecuted) -> None:
|
|
289
|
+
normalized_event = event
|
|
290
|
+
output = event.result.output
|
|
291
|
+
if event.value is None and is_dataclass_instance(output):
|
|
292
|
+
normalized_event = replace(event, value=output)
|
|
293
|
+
|
|
294
|
+
self._dispatch_data_event(
|
|
295
|
+
_PROMPT_EXECUTED_TYPE,
|
|
296
|
+
cast(ReducerEvent, normalized_event),
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
if normalized_event.value is not None:
|
|
300
|
+
self._dispatch_data_event(
|
|
301
|
+
type(normalized_event.value),
|
|
302
|
+
cast(ReducerEvent, normalized_event),
|
|
303
|
+
)
|
|
304
|
+
return
|
|
305
|
+
|
|
306
|
+
if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
|
|
307
|
+
for item in cast(Iterable[object], output):
|
|
308
|
+
if is_dataclass_instance(item):
|
|
309
|
+
enriched_event = replace(normalized_event, value=item)
|
|
310
|
+
self._dispatch_data_event(
|
|
311
|
+
type(item),
|
|
312
|
+
cast(ReducerEvent, enriched_event),
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
def _handle_prompt_rendered(self, event: PromptRendered) -> None:
|
|
316
|
+
self._dispatch_data_event(
|
|
317
|
+
_PROMPT_RENDERED_TYPE,
|
|
318
|
+
cast(ReducerEvent, event),
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
def _dispatch_data_event(
|
|
322
|
+
self, data_type: SessionSliceType, event: ReducerEvent
|
|
323
|
+
) -> None:
|
|
324
|
+
from .reducer_context import build_reducer_context
|
|
325
|
+
|
|
326
|
+
with self._lock:
|
|
327
|
+
registrations = list(self._reducers.get(data_type, ()))
|
|
328
|
+
if not registrations:
|
|
329
|
+
default_reducer: TypedReducer[Any]
|
|
330
|
+
if data_type in {_TOOL_INVOKED_TYPE, _PROMPT_EXECUTED_TYPE}:
|
|
331
|
+
default_reducer = cast(TypedReducer[Any], _append_event)
|
|
332
|
+
else:
|
|
333
|
+
default_reducer = cast(TypedReducer[Any], append)
|
|
334
|
+
registrations = [
|
|
335
|
+
_ReducerRegistration(
|
|
336
|
+
reducer=default_reducer,
|
|
337
|
+
slice_type=data_type,
|
|
338
|
+
)
|
|
339
|
+
]
|
|
340
|
+
event_bus = self._bus
|
|
341
|
+
|
|
342
|
+
context = build_reducer_context(session=self, event_bus=event_bus)
|
|
343
|
+
|
|
344
|
+
for registration in registrations:
|
|
345
|
+
slice_type = registration.slice_type
|
|
346
|
+
while True:
|
|
347
|
+
with self._lock:
|
|
348
|
+
previous = self._state.get(slice_type, EMPTY_SLICE)
|
|
349
|
+
try:
|
|
350
|
+
result = registration.reducer(previous, event, context=context)
|
|
351
|
+
except Exception: # log and continue
|
|
352
|
+
reducer_name = getattr(
|
|
353
|
+
registration.reducer, "__qualname__", repr(registration.reducer)
|
|
354
|
+
)
|
|
355
|
+
logger.exception(
|
|
356
|
+
"Reducer application failed.",
|
|
357
|
+
event="session_reducer_failed",
|
|
358
|
+
context={
|
|
359
|
+
"reducer": reducer_name,
|
|
360
|
+
"data_type": data_type.__qualname__,
|
|
361
|
+
"slice_type": slice_type.__qualname__,
|
|
362
|
+
},
|
|
363
|
+
)
|
|
364
|
+
break
|
|
365
|
+
normalized = tuple(result)
|
|
366
|
+
with self._lock:
|
|
367
|
+
current = self._state.get(slice_type, EMPTY_SLICE)
|
|
368
|
+
if current is previous or current == normalized:
|
|
369
|
+
self._state[slice_type] = normalized
|
|
370
|
+
break
|
|
371
|
+
|
|
372
|
+
def _attach_to_bus(self, bus: EventBus) -> None:
|
|
373
|
+
with self._lock:
|
|
374
|
+
if self._subscriptions_attached and self._bus is bus:
|
|
375
|
+
return
|
|
376
|
+
self._bus = bus
|
|
377
|
+
self._subscriptions_attached = True
|
|
378
|
+
bus.subscribe(ToolInvoked, self._on_tool_invoked)
|
|
379
|
+
bus.subscribe(PromptExecuted, self._on_prompt_executed)
|
|
380
|
+
bus.subscribe(PromptRendered, self._on_prompt_rendered)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
# Public names exported by this module; ``TypedReducer`` is re-exported from
# ``._types``.
__all__ = [
    "DataEvent",
    "Session",
    "TypedReducer",
]
|