weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Shared typing primitives for event integrations."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from typing import TYPE_CHECKING, Protocol, cast, override
|
|
21
|
+
from uuid import UUID, uuid4
|
|
22
|
+
|
|
23
|
+
from ...prompt._types import SupportsDataclass
|
|
24
|
+
from ...prompt.tool_result import ToolResult
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from ...adapters._names import AdapterName
|
|
28
|
+
|
|
29
|
+
# Signature of a bus subscriber: called with the published event object; the
# callable is typed to return ``None``.
EventHandler = Callable[[object], None]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class EventBus(Protocol):
    """Structural interface for a synchronous publish/subscribe bus.

    Any object exposing compatible ``subscribe`` and ``publish`` methods
    satisfies this protocol; explicit inheritance is not required.
    """

    def subscribe(self, event_type: type[object], handler: EventHandler) -> None:
        """Attach ``handler`` to events of ``event_type``."""
        ...

    def publish(self, event: object) -> PublishResult:
        """Hand ``event`` to its subscribers and return a summary of the run."""
        ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass(slots=True, frozen=True)
|
|
45
|
+
class HandlerFailure:
|
|
46
|
+
"""Container describing a handler error captured during publish."""
|
|
47
|
+
|
|
48
|
+
handler: EventHandler
|
|
49
|
+
error: BaseException
|
|
50
|
+
|
|
51
|
+
@override
|
|
52
|
+
def __str__(self) -> str:
|
|
53
|
+
return f"{self.handler!r} -> {self.error!r}"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass(slots=True, frozen=True)
|
|
57
|
+
class PublishResult:
|
|
58
|
+
"""Summary of an event publish invocation."""
|
|
59
|
+
|
|
60
|
+
event: object
|
|
61
|
+
handlers_invoked: tuple[EventHandler, ...]
|
|
62
|
+
errors: tuple[HandlerFailure, ...]
|
|
63
|
+
handled_count: int = field(init=False)
|
|
64
|
+
|
|
65
|
+
def __post_init__(self) -> None:
|
|
66
|
+
object.__setattr__(self, "handled_count", len(self.handlers_invoked))
|
|
67
|
+
|
|
68
|
+
@property
|
|
69
|
+
def ok(self) -> bool:
|
|
70
|
+
"""Return ``True`` when no handler failures were recorded."""
|
|
71
|
+
|
|
72
|
+
return not self.errors
|
|
73
|
+
|
|
74
|
+
def raise_if_errors(self) -> None:
|
|
75
|
+
"""Raise an ``ExceptionGroup`` if any handlers failed."""
|
|
76
|
+
|
|
77
|
+
if not self.errors:
|
|
78
|
+
return
|
|
79
|
+
|
|
80
|
+
failures = ", ".join(str(failure) for failure in self.errors)
|
|
81
|
+
message = f"Errors while publishing {type(self.event).__name__}: {failures}"
|
|
82
|
+
raise ExceptionGroup(
|
|
83
|
+
message,
|
|
84
|
+
tuple(cast(Exception, failure.error) for failure in self.errors),
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(slots=True, frozen=True)
|
|
89
|
+
class ToolInvoked:
|
|
90
|
+
"""Event emitted after an adapter executes a tool handler."""
|
|
91
|
+
|
|
92
|
+
prompt_name: str
|
|
93
|
+
adapter: AdapterName
|
|
94
|
+
name: str
|
|
95
|
+
params: SupportsDataclass
|
|
96
|
+
result: ToolResult[object]
|
|
97
|
+
session_id: UUID | None
|
|
98
|
+
created_at: datetime
|
|
99
|
+
value: SupportsDataclass | None = None
|
|
100
|
+
call_id: str | None = None
|
|
101
|
+
event_id: UUID = field(default_factory=uuid4)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Public names exported by the event typing module.
__all__ = [
    "EventBus",
    "EventHandler",
    "HandlerFailure",
    "PublishResult",
    "ToolInvoked",
]
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Structured logging helpers for :mod:`weakincentives`."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import logging.config
|
|
20
|
+
import os
|
|
21
|
+
from collections.abc import Mapping, MutableMapping
|
|
22
|
+
from datetime import UTC, datetime
|
|
23
|
+
from typing import Protocol, cast, override
|
|
24
|
+
|
|
25
|
+
from ..types import JSONValue
|
|
26
|
+
|
|
27
|
+
# JSON-compatible key/value mapping accepted as bound or per-call log context.
type StructuredLogPayload = Mapping[str, JSONValue]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SupportsLogMessage(Protocol):
|
|
31
|
+
"""Protocol describing values that logging can stringify."""
|
|
32
|
+
|
|
33
|
+
@override
|
|
34
|
+
def __str__(self) -> str: ...
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Public surface of the structured logging helpers.
__all__ = [
    # typing helpers
    "JSONValue",
    "StructuredLogPayload",
    "StructuredLogger",
    "SupportsLogMessage",
    # entry points
    "configure_logging",
    "get_logger",
]
|
|
45
|
+
|
|
46
|
+
_LOG_LEVEL_ENV = "WEAKINCENTIVES_LOG_LEVEL"
|
|
47
|
+
_LOG_FORMAT_ENV = "WEAKINCENTIVES_LOG_FORMAT"
|
|
48
|
+
_LEVEL_NAMES = {
|
|
49
|
+
"CRITICAL": logging.CRITICAL,
|
|
50
|
+
"ERROR": logging.ERROR,
|
|
51
|
+
"WARNING": logging.WARNING,
|
|
52
|
+
"INFO": logging.INFO,
|
|
53
|
+
"DEBUG": logging.DEBUG,
|
|
54
|
+
"NOTSET": logging.NOTSET,
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class StructuredLogger(logging.LoggerAdapter[logging.Logger]):
|
|
59
|
+
"""Logger adapter enforcing a minimal structured event schema."""
|
|
60
|
+
|
|
61
|
+
def __init__(
|
|
62
|
+
self,
|
|
63
|
+
logger: logging.Logger,
|
|
64
|
+
*,
|
|
65
|
+
context: StructuredLogPayload | None = None,
|
|
66
|
+
) -> None:
|
|
67
|
+
base_context: dict[str, JSONValue] = (
|
|
68
|
+
dict(context) if context is not None else {}
|
|
69
|
+
)
|
|
70
|
+
super().__init__(logger, base_context)
|
|
71
|
+
self._context: dict[str, JSONValue] = base_context
|
|
72
|
+
|
|
73
|
+
def bind(self, **context: JSONValue) -> StructuredLogger:
|
|
74
|
+
"""Return a new adapter with ``context`` merged into the baseline payload."""
|
|
75
|
+
|
|
76
|
+
merged: dict[str, JSONValue] = {**dict(self._context), **context}
|
|
77
|
+
return type(self)(self.logger, context=merged)
|
|
78
|
+
|
|
79
|
+
@override
|
|
80
|
+
def process(
|
|
81
|
+
self, msg: SupportsLogMessage, kwargs: MutableMapping[str, object]
|
|
82
|
+
) -> tuple[SupportsLogMessage, MutableMapping[str, object]]:
|
|
83
|
+
extra_value = kwargs.setdefault("extra", {})
|
|
84
|
+
if extra_value is None:
|
|
85
|
+
extra_mapping: MutableMapping[str, JSONValue] = {}
|
|
86
|
+
kwargs["extra"] = extra_mapping
|
|
87
|
+
elif isinstance(extra_value, MutableMapping):
|
|
88
|
+
extra_mapping = cast(MutableMapping[str, JSONValue], extra_value)
|
|
89
|
+
else: # pragma: no cover - defensive guard
|
|
90
|
+
raise TypeError(
|
|
91
|
+
"Structured logs require a mutable mapping for extra context."
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
context_payload: dict[str, JSONValue] = dict(self._context)
|
|
95
|
+
|
|
96
|
+
inline_context = kwargs.pop("context", None)
|
|
97
|
+
if inline_context is not None:
|
|
98
|
+
if isinstance(inline_context, Mapping):
|
|
99
|
+
context_payload.update(cast(StructuredLogPayload, inline_context))
|
|
100
|
+
else: # pragma: no cover - defensive guard
|
|
101
|
+
raise TypeError("context must be a mapping when provided.")
|
|
102
|
+
|
|
103
|
+
for key in tuple(extra_mapping.keys()):
|
|
104
|
+
if key == "event":
|
|
105
|
+
continue
|
|
106
|
+
context_payload[key] = extra_mapping.pop(key)
|
|
107
|
+
|
|
108
|
+
event_obj = kwargs.pop("event", None)
|
|
109
|
+
if event_obj is None:
|
|
110
|
+
event_obj = extra_mapping.pop("event", None)
|
|
111
|
+
if not isinstance(event_obj, str):
|
|
112
|
+
raise TypeError("Structured logs require an 'event' field.")
|
|
113
|
+
|
|
114
|
+
extra_mapping.clear()
|
|
115
|
+
extra_mapping.update(
|
|
116
|
+
{
|
|
117
|
+
"event": event_obj,
|
|
118
|
+
"context": context_payload,
|
|
119
|
+
}
|
|
120
|
+
)
|
|
121
|
+
return msg, kwargs
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_logger(
|
|
125
|
+
name: str,
|
|
126
|
+
*,
|
|
127
|
+
logger_override: logging.Logger
|
|
128
|
+
| logging.LoggerAdapter[logging.Logger]
|
|
129
|
+
| None = None,
|
|
130
|
+
context: StructuredLogPayload | None = None,
|
|
131
|
+
) -> StructuredLogger:
|
|
132
|
+
"""Return a :class:`StructuredLogger` scoped to ``name``.
|
|
133
|
+
|
|
134
|
+
When ``logger_override`` is provided, the returned adapter reuses the supplied
|
|
135
|
+
logger and merges its contextual ``extra`` payload when available.
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
base_context: dict[str, JSONValue] = dict(context or {})
|
|
139
|
+
base_logger: logging.Logger
|
|
140
|
+
|
|
141
|
+
if isinstance(logger_override, StructuredLogger):
|
|
142
|
+
base_logger = logger_override.logger
|
|
143
|
+
base_context = {
|
|
144
|
+
**dict(cast(StructuredLogPayload, logger_override.extra)),
|
|
145
|
+
**base_context,
|
|
146
|
+
}
|
|
147
|
+
elif isinstance(logger_override, logging.Logger):
|
|
148
|
+
base_logger = logger_override
|
|
149
|
+
elif isinstance(logger_override, logging.LoggerAdapter):
|
|
150
|
+
base_logger = _unwrap_logger(cast(_SupportsNestedLogger, logger_override))
|
|
151
|
+
adapter_extra = getattr(logger_override, "extra", None)
|
|
152
|
+
if isinstance(adapter_extra, Mapping):
|
|
153
|
+
base_context = {
|
|
154
|
+
**dict(cast(StructuredLogPayload, adapter_extra)),
|
|
155
|
+
**base_context,
|
|
156
|
+
}
|
|
157
|
+
else:
|
|
158
|
+
base_logger = logging.getLogger(name)
|
|
159
|
+
|
|
160
|
+
return StructuredLogger(base_logger, context=base_context)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def configure_logging(
    *,
    level: int | str | None = None,
    json_mode: bool | None = None,
    env: Mapping[str, str] | None = None,
    force: bool = False,
) -> None:
    """Configure the root logger with sensible defaults.

    ``level`` and ``json_mode`` can be supplied directly or via the
    ``WEAKINCENTIVES_LOG_LEVEL`` and ``WEAKINCENTIVES_LOG_FORMAT`` environment
    variables respectively (``json`` enables structured output, any other value
    keeps the plain formatter). Explicit arguments win over the environment.

    The function avoids installing duplicate handlers when the host application has
    already configured logging unless ``force=True`` is supplied.

    Args:
        level: Level name or number; when ``None`` the environment is consulted.
        json_mode: ``True`` for JSON lines, ``False`` for text, ``None`` to
            consult the environment (text when unset).
        env: Mapping read instead of :data:`os.environ`. An explicitly empty
            mapping disables environment lookups.
        force: Reconfigure even when the root logger already has handlers.
    """

    # ``env or os.environ`` would silently swap an explicitly empty mapping for
    # the real process environment; only fall back when nothing was passed.
    if env is None:
        env = os.environ

    resolved_level = _coerce_level(
        level if level is not None else env.get(_LOG_LEVEL_ENV)
    )

    if json_mode is None:
        format_value = env.get(_LOG_FORMAT_ENV)
        json_mode = (
            format_value is not None and format_value.strip().lower() == "json"
        )

    root_logger = logging.getLogger()

    if root_logger.handlers and not force:
        # NOTE(review): only the level is adjusted here. When neither ``level``
        # nor the env var is set, ``resolved_level`` falls back to INFO and so
        # overrides whatever root level the host configured — confirm intended.
        root_logger.setLevel(resolved_level)
        return

    formatter_key = "json" if json_mode else "text"
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "text": {
                    "format": "%(asctime)s %(levelname)s %(name)s %(event)s %(message)s %(context)s",
                    "datefmt": "%Y-%m-%d %H:%M:%S",
                    # Records from plain (non-structured) loggers, including
                    # third-party libraries, have no ``event``/``context``
                    # attributes. Without defaults, formatting them fails and
                    # logging prints "--- Logging error ---" instead of the
                    # message. The ``defaults`` key requires Python 3.12+.
                    "defaults": {"event": "-", "context": "-"},
                },
                "json": {
                    "()": "weakincentives.runtime.logging._JsonFormatter",
                },
            },
            "handlers": {
                "stderr": {
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                    "formatter": formatter_key,
                }
            },
            "root": {
                "handlers": ["stderr"],
                "level": resolved_level,
            },
        }
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class _SupportsNestedLogger(Protocol):
|
|
231
|
+
logger: logging.Logger | logging.LoggerAdapter[logging.Logger]
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class _JsonFormatter(logging.Formatter):
    """Formatter that renders structured records as compact JSON.

    Every line contains ``timestamp`` (UTC, ISO-8601), ``level``, ``logger``
    and ``message``. ``event``, ``context``, ``exc_info`` and ``stack_info``
    are added only when present on the record. Values that ``json`` cannot
    encode natively go through ``_json_default``.
    """

    @override
    def format(self, record: logging.LogRecord) -> str:
        payload: dict[str, JSONValue] = {
            "timestamp": datetime.fromtimestamp(record.created, tz=UTC).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # ``event``/``context`` are only present on records emitted through
        # a StructuredLogger; plain records simply omit them.
        event = getattr(record, "event", None)
        if event is not None:
            payload["event"] = event
        context = getattr(record, "context", None)
        if context:
            payload["context"] = context
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            # Honour ``stack_info=True`` on logging calls; without this the
            # requested stack trace was silently dropped in JSON mode.
            payload["stack_info"] = self.formatStack(record.stack_info)
        return json.dumps(payload, default=_json_default, separators=(",", ":"))
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _json_default(value: object) -> JSONValue:
|
|
257
|
+
"""Fallback serializer returning ``repr`` for unsupported values."""
|
|
258
|
+
|
|
259
|
+
return repr(value)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
# Module-level alias of ``_JsonFormatter``; nothing in this module reads it.
# NOTE(review): presumably kept as a stable handle for tests or external code —
# confirm before removing.
_JSON_FORMATTER_CLASS = _JsonFormatter
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _unwrap_logger(adapter: _SupportsNestedLogger) -> logging.Logger:
|
|
266
|
+
"""Return the underlying :class:`logging.Logger` from an adapter."""
|
|
267
|
+
|
|
268
|
+
logger_value = cast(object, adapter.logger)
|
|
269
|
+
if isinstance(logger_value, logging.LoggerAdapter):
|
|
270
|
+
return _unwrap_logger(cast(_SupportsNestedLogger, logger_value))
|
|
271
|
+
if isinstance(logger_value, logging.Logger):
|
|
272
|
+
return logger_value
|
|
273
|
+
raise TypeError("LoggerAdapter.logger must be a logging.Logger instance.")
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _coerce_level(level: int | str | None) -> int:
|
|
277
|
+
if isinstance(level, int):
|
|
278
|
+
return level
|
|
279
|
+
if isinstance(level, str):
|
|
280
|
+
try:
|
|
281
|
+
return _LEVEL_NAMES[level.upper()]
|
|
282
|
+
except KeyError: # pragma: no cover - defensive guard
|
|
283
|
+
raise TypeError(f"Unknown log level: {level!r}") from None
|
|
284
|
+
return logging.INFO
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Session state container for agent runs."""
|
|
14
|
+
|
|
15
|
+
from ._types import ReducerContextProtocol, ReducerEvent, TypedReducer
|
|
16
|
+
from .protocols import SessionProtocol, SnapshotProtocol
|
|
17
|
+
from .reducer_context import ReducerContext, build_reducer_context
|
|
18
|
+
from .reducers import append, replace_latest, upsert_by
|
|
19
|
+
from .selectors import select_all, select_latest, select_where
|
|
20
|
+
from .session import DataEvent, Session
|
|
21
|
+
from .snapshots import (
|
|
22
|
+
Snapshot,
|
|
23
|
+
SnapshotRestoreError,
|
|
24
|
+
SnapshotSerializationError,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Names re-exported as the public API of the session package.
__all__ = [
    # classes, protocols and errors
    "DataEvent",
    "ReducerContext",
    "ReducerContextProtocol",
    "ReducerEvent",
    "Session",
    "SessionProtocol",
    "Snapshot",
    "SnapshotProtocol",
    "SnapshotRestoreError",
    "SnapshotSerializationError",
    "TypedReducer",
    # reducer, context and selector helpers
    "append",
    "build_reducer_context",
    "replace_latest",
    "select_all",
    "select_latest",
    "select_where",
    "upsert_by",
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Shared slice typing helpers for :mod:`weakincentives.runtime.session`."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from ...prompt._types import SupportsDataclass
|
|
18
|
+
|
|
19
|
+
type SessionSliceType = type[SupportsDataclass]
|
|
20
|
+
|
|
21
|
+
type SessionSlice = tuple[SupportsDataclass, ...]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
__all__ = ["SessionSlice", "SessionSliceType"]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Shared typing helpers for session reducers."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import TYPE_CHECKING, Protocol, TypeVar
|
|
18
|
+
|
|
19
|
+
from ...prompt._types import SupportsDataclass
|
|
20
|
+
from ..events import EventBus
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from .protocols import SessionProtocol
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ReducerEvent(Protocol):
    """Structural type for events handed to reducers: exposes a read-only ``value``."""

    @property
    def value(self) -> SupportsDataclass | None:
        """Dataclass payload carried by the event, or ``None``."""
        ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Element type of a reducer's slice; restricted to dataclass-like values.
S = TypeVar("S", bound=SupportsDataclass)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ReducerContextProtocol(Protocol):
    """Shape of the object reducers receive through their ``context=`` argument."""

    session: SessionProtocol  # session the reducer is operating on
    event_bus: EventBus  # bus made available to the reducer
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TypedReducer(Protocol[S]):
    """Callable protocol for reducers maintained by :class:`Session`.

    A reducer is called with the current slice contents and an event, plus a
    keyword-only ``context``. It returns a tuple of slice values (presumably
    the replacement contents for the slice).
    """

    def __call__(
        self,
        slice_values: tuple[S, ...],
        event: ReducerEvent,
        *,
        context: ReducerContextProtocol,
    ) -> tuple[S, ...]: ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Typing helpers re-exported by the session package.
__all__ = ["ReducerContextProtocol", "ReducerEvent", "TypedReducer"]
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Dataclass helper utilities for session modules."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import is_dataclass
|
|
18
|
+
from typing import TypeGuard
|
|
19
|
+
|
|
20
|
+
from ...prompt._types import SupportsDataclass
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def is_dataclass_instance(value: object) -> TypeGuard[SupportsDataclass]:
|
|
24
|
+
"""Return ``True`` when ``value`` is a dataclass instance."""
|
|
25
|
+
|
|
26
|
+
return is_dataclass(value) and not isinstance(value, type)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ["is_dataclass_instance"]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Protocols describing Session behavior exposed to other modules."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Protocol
|
|
18
|
+
|
|
19
|
+
from .snapshots import Snapshot
|
|
20
|
+
|
|
21
|
+
type SnapshotProtocol = Snapshot
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SessionProtocol(Protocol):
|
|
25
|
+
"""Structural protocol implemented by session state containers."""
|
|
26
|
+
|
|
27
|
+
def snapshot(self) -> SnapshotProtocol: ...
|
|
28
|
+
|
|
29
|
+
def rollback(self, snapshot: SnapshotProtocol) -> None: ...
|
|
30
|
+
|
|
31
|
+
def reset(self) -> None: ...
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
__all__ = ["SessionProtocol", "SnapshotProtocol"]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Reducer context threaded through session reducer invocations."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
|
|
19
|
+
from ..events import EventBus
|
|
20
|
+
from ._types import ReducerContextProtocol
|
|
21
|
+
from .protocols import SessionProtocol
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(slots=True, frozen=True)
class ReducerContext(ReducerContextProtocol):
    """Frozen bundle of the session and event bus shared with reducers."""

    session: SessionProtocol
    event_bus: EventBus


def build_reducer_context(
    *, session: SessionProtocol, event_bus: EventBus
) -> ReducerContext:
    """Package ``session`` and ``event_bus`` into a new :class:`ReducerContext`."""

    context = ReducerContext(session=session, event_bus=event_bus)
    return context


__all__ = ["ReducerContext", "build_reducer_context"]
|