agnt5 0.2.2__cp39-abi3-macosx_11_0_arm64.whl → 0.2.4__cp39-abi3-macosx_11_0_arm64.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of agnt5 might be problematic. Click here for more details.
- agnt5/__init__.py +12 -12
- agnt5/_core.abi3.so +0 -0
- agnt5/_retry_utils.py +169 -0
- agnt5/_schema_utils.py +312 -0
- agnt5/_telemetry.py +28 -7
- agnt5/agent.py +153 -140
- agnt5/client.py +50 -12
- agnt5/context.py +36 -756
- agnt5/entity.py +368 -1160
- agnt5/function.py +208 -235
- agnt5/lm.py +71 -12
- agnt5/tool.py +25 -11
- agnt5/tracing.py +196 -0
- agnt5/worker.py +205 -173
- agnt5/workflow.py +444 -20
- {agnt5-0.2.2.dist-info → agnt5-0.2.4.dist-info}/METADATA +2 -1
- agnt5-0.2.4.dist-info/RECORD +22 -0
- agnt5-0.2.2.dist-info/RECORD +0 -19
- {agnt5-0.2.2.dist-info → agnt5-0.2.4.dist-info}/WHEEL +0 -0
agnt5/entity.py
CHANGED
|
@@ -1,488 +1,431 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Entity component for stateful operations with single-writer consistency.
|
|
3
|
-
|
|
4
|
-
Entities provide isolated state per unique key with automatic consistency guarantees.
|
|
5
|
-
In Phase 1, entities use in-memory state with asyncio locks for single-writer semantics.
|
|
6
3
|
"""
|
|
7
4
|
|
|
8
5
|
import asyncio
|
|
6
|
+
import contextvars
|
|
9
7
|
import functools
|
|
10
8
|
import inspect
|
|
11
|
-
import
|
|
12
|
-
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, TypeVar
|
|
9
|
+
from typing import Any, Dict, Optional, Tuple
|
|
13
10
|
|
|
14
|
-
from .
|
|
15
|
-
from .exceptions import
|
|
16
|
-
from .function import _extract_function_schemas, _extract_function_metadata
|
|
11
|
+
from ._schema_utils import extract_function_metadata, extract_function_schemas
|
|
12
|
+
from .exceptions import ExecutionError
|
|
17
13
|
from ._telemetry import setup_module_logger
|
|
18
14
|
|
|
19
15
|
logger = setup_module_logger(__name__)
|
|
20
16
|
|
|
21
|
-
#
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
# Global storage for in-memory entity state and locks
|
|
26
|
-
# Phase 2 will replace these with platform-backed durable storage
|
|
27
|
-
_entity_states: Dict[Tuple[str, str], Dict[str, Any]] = {} # (type, key) -> state
|
|
28
|
-
_entity_locks: Dict[Tuple[str, str], asyncio.Lock] = {} # (type, key) -> lock
|
|
17
|
+
# Context variable for worker-scoped state manager
|
|
18
|
+
# This is set by Worker before entity execution and accessed by Entity instances
|
|
19
|
+
_entity_state_manager_ctx: contextvars.ContextVar[Optional["EntityStateManager"]] = \
|
|
20
|
+
contextvars.ContextVar('_entity_state_manager', default=None)
|
|
29
21
|
|
|
30
22
|
# Global entity registry
|
|
31
23
|
_ENTITY_REGISTRY: Dict[str, "EntityType"] = {}
|
|
32
24
|
|
|
33
25
|
|
|
34
|
-
class EntityRegistry:
|
|
35
|
-
"""Registry for entity types."""
|
|
36
|
-
|
|
37
|
-
@staticmethod
|
|
38
|
-
def register(entity_type: "EntityType") -> None:
|
|
39
|
-
"""Register an entity type."""
|
|
40
|
-
if entity_type.name in _ENTITY_REGISTRY:
|
|
41
|
-
logger.warning(f"Overwriting existing entity type '{entity_type.name}'")
|
|
42
|
-
_ENTITY_REGISTRY[entity_type.name] = entity_type
|
|
43
|
-
logger.debug(f"Registered entity type '{entity_type.name}'")
|
|
44
|
-
|
|
45
|
-
@staticmethod
|
|
46
|
-
def get(name: str) -> Optional["EntityType"]:
|
|
47
|
-
"""Get entity type by name."""
|
|
48
|
-
return _ENTITY_REGISTRY.get(name)
|
|
49
|
-
|
|
50
|
-
@staticmethod
|
|
51
|
-
def all() -> Dict[str, "EntityType"]:
|
|
52
|
-
"""Get all registered entities."""
|
|
53
|
-
return _ENTITY_REGISTRY.copy()
|
|
54
|
-
|
|
55
|
-
@staticmethod
|
|
56
|
-
def clear() -> None:
|
|
57
|
-
"""Clear all registered entities."""
|
|
58
|
-
_ENTITY_REGISTRY.clear()
|
|
59
|
-
logger.debug("Cleared entity registry")
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class EntityType:
|
|
26
|
+
class EntityStateManager:
|
|
63
27
|
"""
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
28
|
+
Worker-scoped state and lock management for entities.
|
|
29
|
+
|
|
30
|
+
This class provides isolated state management per Worker instance,
|
|
31
|
+
replacing the global dict approach. Each Worker gets its own state manager,
|
|
32
|
+
which provides:
|
|
33
|
+
- State storage per entity (type, key)
|
|
34
|
+
- Single-writer locks per entity
|
|
35
|
+
- Version tracking for optimistic locking
|
|
36
|
+
- Platform state loading/saving via Rust EntityStateManager
|
|
68
37
|
"""
|
|
69
38
|
|
|
70
|
-
def __init__(self,
|
|
39
|
+
def __init__(self, rust_entity_state_manager=None):
|
|
71
40
|
"""
|
|
72
|
-
Initialize
|
|
41
|
+
Initialize empty state manager.
|
|
73
42
|
|
|
74
43
|
Args:
|
|
75
|
-
|
|
76
|
-
|
|
44
|
+
rust_entity_state_manager: Optional Rust EntityStateManager for gRPC communication.
|
|
45
|
+
TODO: Wire this up once PyO3 bindings are complete.
|
|
77
46
|
"""
|
|
78
|
-
self.
|
|
79
|
-
self.
|
|
80
|
-
self.
|
|
81
|
-
self.
|
|
82
|
-
|
|
83
|
-
logger.debug(f"Created entity type: {name}")
|
|
47
|
+
self._states: Dict[Tuple[str, str], Dict[str, Any]] = {}
|
|
48
|
+
self._locks: Dict[Tuple[str, str], asyncio.Lock] = {}
|
|
49
|
+
self._versions: Dict[Tuple[str, str], int] = {}
|
|
50
|
+
self._rust_manager = rust_entity_state_manager # TODO: Use for load/save
|
|
51
|
+
logger.debug("Created EntityStateManager")
|
|
84
52
|
|
|
85
|
-
def
|
|
53
|
+
def get_or_create_state(self, state_key: Tuple[str, str]) -> Dict[str, Any]:
|
|
86
54
|
"""
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
Methods receive a Context as the first parameter and can access
|
|
90
|
-
entity state via ctx.get/set/delete.
|
|
55
|
+
Get or create state dict for entity instance.
|
|
91
56
|
|
|
92
57
|
Args:
|
|
93
|
-
|
|
58
|
+
state_key: Tuple of (entity_type, entity_key)
|
|
94
59
|
|
|
95
60
|
Returns:
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
Example:
|
|
99
|
-
```python
|
|
100
|
-
Counter = entity("Counter")
|
|
101
|
-
|
|
102
|
-
@Counter.method
|
|
103
|
-
async def increment(ctx: Context, amount: int = 1) -> int:
|
|
104
|
-
current = ctx.get("count", 0)
|
|
105
|
-
new_count = current + amount
|
|
106
|
-
ctx.set("count", new_count)
|
|
107
|
-
return new_count
|
|
108
|
-
```
|
|
61
|
+
State dict for the entity instance
|
|
109
62
|
"""
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
params = list(sig.parameters.values())
|
|
114
|
-
|
|
115
|
-
if not params:
|
|
116
|
-
raise ConfigurationError(
|
|
117
|
-
f"Entity method {f.__name__} must have at least one parameter (ctx: Context)"
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
first_param = params[0]
|
|
121
|
-
# Check if first parameter is Context (can be class, string "Context", or empty)
|
|
122
|
-
annotation = first_param.annotation
|
|
123
|
-
is_context = (
|
|
124
|
-
annotation == Context
|
|
125
|
-
or annotation == "Context"
|
|
126
|
-
or annotation is Context
|
|
127
|
-
or annotation == inspect.Parameter.empty
|
|
128
|
-
or (hasattr(annotation, "__name__") and annotation.__name__ == "Context")
|
|
129
|
-
)
|
|
130
|
-
if not is_context:
|
|
131
|
-
raise ConfigurationError(
|
|
132
|
-
f"Entity method {f.__name__} first parameter must be 'ctx: Context', "
|
|
133
|
-
f"got '{annotation}' (type: {type(annotation).__name__})"
|
|
134
|
-
)
|
|
135
|
-
|
|
136
|
-
# Convert sync to async if needed
|
|
137
|
-
if not asyncio.iscoroutinefunction(f):
|
|
138
|
-
original_func = f
|
|
139
|
-
|
|
140
|
-
@functools.wraps(original_func)
|
|
141
|
-
async def async_wrapper(*args, **kwargs):
|
|
142
|
-
return original_func(*args, **kwargs)
|
|
143
|
-
|
|
144
|
-
f = async_wrapper
|
|
145
|
-
|
|
146
|
-
# Extract schemas from type hints (use original func before async wrapping)
|
|
147
|
-
original_func = original_func if 'original_func' in locals() else f
|
|
148
|
-
input_schema, output_schema = _extract_function_schemas(original_func)
|
|
149
|
-
|
|
150
|
-
# Extract metadata (description, etc.)
|
|
151
|
-
method_metadata = _extract_function_metadata(original_func)
|
|
152
|
-
|
|
153
|
-
# Register method
|
|
154
|
-
method_name = f.__name__
|
|
155
|
-
if method_name in self._methods:
|
|
156
|
-
logger.warning(
|
|
157
|
-
f"Overwriting existing method '{method_name}' on entity type '{self.name}'"
|
|
158
|
-
)
|
|
159
|
-
|
|
160
|
-
self._methods[method_name] = f
|
|
161
|
-
self._method_schemas[method_name] = (input_schema, output_schema)
|
|
162
|
-
self._method_metadata[method_name] = method_metadata
|
|
163
|
-
logger.debug(f"Registered method '{method_name}' on entity type '{self.name}'")
|
|
63
|
+
if state_key not in self._states:
|
|
64
|
+
self._states[state_key] = {}
|
|
65
|
+
return self._states[state_key]
|
|
164
66
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
if func is None:
|
|
168
|
-
return decorator
|
|
169
|
-
return decorator(func)
|
|
170
|
-
|
|
171
|
-
def __call__(self, key: str) -> "EntityInstance":
|
|
67
|
+
def get_or_create_lock(self, state_key: Tuple[str, str]) -> asyncio.Lock:
|
|
172
68
|
"""
|
|
173
|
-
|
|
69
|
+
Get or create async lock for entity instance.
|
|
174
70
|
|
|
175
71
|
Args:
|
|
176
|
-
|
|
72
|
+
state_key: Tuple of (entity_type, entity_key)
|
|
177
73
|
|
|
178
74
|
Returns:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
Example:
|
|
182
|
-
```python
|
|
183
|
-
Counter = entity("Counter")
|
|
184
|
-
|
|
185
|
-
counter1 = Counter(key="user-123")
|
|
186
|
-
await counter1.increment(amount=5)
|
|
187
|
-
```
|
|
75
|
+
Async lock for single-writer guarantee
|
|
188
76
|
"""
|
|
189
|
-
|
|
190
|
-
|
|
77
|
+
if state_key not in self._locks:
|
|
78
|
+
self._locks[state_key] = asyncio.Lock()
|
|
79
|
+
return self._locks[state_key]
|
|
191
80
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
"""
|
|
199
|
-
|
|
200
|
-
def __init__(self, entity_type: EntityType, key: str):
|
|
81
|
+
def load_state_from_platform(
|
|
82
|
+
self,
|
|
83
|
+
state_key: Tuple[str, str],
|
|
84
|
+
platform_state_json: str,
|
|
85
|
+
version: int = 0
|
|
86
|
+
) -> None:
|
|
201
87
|
"""
|
|
202
|
-
|
|
88
|
+
Load state from platform for entity persistence.
|
|
203
89
|
|
|
204
90
|
Args:
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
91
|
+
state_key: Tuple of (entity_type, entity_key)
|
|
92
|
+
platform_state_json: JSON string of state from platform
|
|
93
|
+
version: Current version from platform
|
|
94
|
+
"""
|
|
95
|
+
import json
|
|
96
|
+
try:
|
|
97
|
+
state = json.loads(platform_state_json)
|
|
98
|
+
self._states[state_key] = state
|
|
99
|
+
self._versions[state_key] = version
|
|
100
|
+
logger.debug(
|
|
101
|
+
f"Loaded platform state: {state_key[0]}/{state_key[1]} (version {version})"
|
|
102
|
+
)
|
|
103
|
+
except json.JSONDecodeError as e:
|
|
104
|
+
logger.warning(f"Failed to parse platform state: {e}")
|
|
105
|
+
self._states[state_key] = {}
|
|
106
|
+
self._versions[state_key] = 0
|
|
212
107
|
|
|
213
|
-
def
|
|
108
|
+
def get_state_for_persistence(
|
|
109
|
+
self,
|
|
110
|
+
state_key: Tuple[str, str]
|
|
111
|
+
) -> tuple[Dict[str, Any], int, int]:
|
|
214
112
|
"""
|
|
215
|
-
|
|
113
|
+
Get state and version info for platform persistence.
|
|
216
114
|
|
|
217
115
|
Args:
|
|
218
|
-
|
|
116
|
+
state_key: Tuple of (entity_type, entity_key)
|
|
219
117
|
|
|
220
118
|
Returns:
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
Raises:
|
|
224
|
-
AttributeError: If method doesn't exist on this entity type
|
|
119
|
+
Tuple of (state_dict, expected_version, new_version)
|
|
225
120
|
"""
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
if method_name not in self._entity_type._methods:
|
|
231
|
-
available = ", ".join(self._entity_type._methods.keys())
|
|
232
|
-
raise AttributeError(
|
|
233
|
-
f"Entity type '{self._entity_type.name}' has no method '{method_name}'. "
|
|
234
|
-
f"Available methods: {available or 'none'}"
|
|
235
|
-
)
|
|
121
|
+
state_dict = self._states.get(state_key, {})
|
|
122
|
+
expected_version = self._versions.get(state_key, 0)
|
|
123
|
+
new_version = expected_version + 1
|
|
236
124
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
@functools.wraps(method_func)
|
|
240
|
-
async def method_wrapper(*args, **kwargs) -> Any:
|
|
241
|
-
"""
|
|
242
|
-
Execute entity method with single-writer guarantee.
|
|
243
|
-
|
|
244
|
-
This wrapper:
|
|
245
|
-
1. Acquires lock for this entity instance (single-writer)
|
|
246
|
-
2. Creates Context with entity state
|
|
247
|
-
3. Executes method
|
|
248
|
-
4. Updates state from Context
|
|
249
|
-
"""
|
|
250
|
-
# Get or create lock for this entity instance (single-writer guarantee)
|
|
251
|
-
if self._state_key not in _entity_locks:
|
|
252
|
-
_entity_locks[self._state_key] = asyncio.Lock()
|
|
253
|
-
lock = _entity_locks[self._state_key]
|
|
254
|
-
|
|
255
|
-
async with lock:
|
|
256
|
-
# Get or create state for this entity instance
|
|
257
|
-
if self._state_key not in _entity_states:
|
|
258
|
-
_entity_states[self._state_key] = {}
|
|
259
|
-
state_dict = _entity_states[self._state_key]
|
|
260
|
-
|
|
261
|
-
# Create Context with entity state
|
|
262
|
-
# Context state is a reference to the entity's state dict
|
|
263
|
-
ctx = Context(
|
|
264
|
-
run_id=f"{self._entity_type.name}:{self._key}:{method_name}",
|
|
265
|
-
component_type="entity",
|
|
266
|
-
object_id=self._key,
|
|
267
|
-
method_name=method_name
|
|
268
|
-
)
|
|
125
|
+
# Update version for next execution
|
|
126
|
+
self._versions[state_key] = new_version
|
|
269
127
|
|
|
270
|
-
|
|
271
|
-
# This allows ctx.get/set/delete to operate on entity state
|
|
272
|
-
ctx._state = state_dict
|
|
273
|
-
|
|
274
|
-
try:
|
|
275
|
-
# Execute method
|
|
276
|
-
logger.debug(
|
|
277
|
-
f"Executing {self._entity_type.name}:{self._key}.{method_name}"
|
|
278
|
-
)
|
|
279
|
-
result = await method_func(ctx, *args, **kwargs)
|
|
280
|
-
logger.debug(
|
|
281
|
-
f"Completed {self._entity_type.name}:{self._key}.{method_name}"
|
|
282
|
-
)
|
|
283
|
-
return result
|
|
284
|
-
|
|
285
|
-
except Exception as e:
|
|
286
|
-
logger.error(
|
|
287
|
-
f"Error in {self._entity_type.name}:{self._key}.{method_name}: {e}",
|
|
288
|
-
exc_info=True
|
|
289
|
-
)
|
|
290
|
-
raise ExecutionError(
|
|
291
|
-
f"Entity method {method_name} failed: {e}"
|
|
292
|
-
) from e
|
|
293
|
-
|
|
294
|
-
return method_wrapper
|
|
128
|
+
return state_dict, expected_version, new_version
|
|
295
129
|
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
130
|
+
def clear_all(self) -> None:
|
|
131
|
+
"""Clear all state, locks, and versions (for testing)."""
|
|
132
|
+
self._states.clear()
|
|
133
|
+
self._locks.clear()
|
|
134
|
+
self._versions.clear()
|
|
135
|
+
logger.debug("Cleared EntityStateManager")
|
|
300
136
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
return self.
|
|
137
|
+
def get_state(self, entity_type: str, key: str) -> Optional[Dict[str, Any]]:
|
|
138
|
+
"""Get state for debugging/testing."""
|
|
139
|
+
state_key = (entity_type, key)
|
|
140
|
+
return self._states.get(state_key)
|
|
305
141
|
|
|
142
|
+
def get_all_keys(self, entity_type: str) -> list[str]:
|
|
143
|
+
"""Get all keys for entity type (for debugging/testing)."""
|
|
144
|
+
return [
|
|
145
|
+
key for (etype, key) in self._states.keys()
|
|
146
|
+
if etype == entity_type
|
|
147
|
+
]
|
|
306
148
|
|
|
307
|
-
def entity(name: str) -> EntityType:
|
|
308
|
-
"""
|
|
309
|
-
Create a new entity type.
|
|
310
149
|
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
150
|
+
def _get_state_manager() -> EntityStateManager:
|
|
151
|
+
"""
|
|
152
|
+
Get the current entity state manager from context.
|
|
314
153
|
|
|
315
|
-
|
|
316
|
-
|
|
154
|
+
The state manager must be set by Worker before entity execution.
|
|
155
|
+
This ensures proper worker-scoped state isolation.
|
|
317
156
|
|
|
318
157
|
Returns:
|
|
319
|
-
|
|
158
|
+
EntityStateManager instance
|
|
320
159
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
counter1 = Counter(key="user-123")
|
|
342
|
-
counter2 = Counter(key="user-456")
|
|
343
|
-
|
|
344
|
-
# Invoke methods (guaranteed single-writer per key)
|
|
345
|
-
result = await counter1.increment(amount=5) # Returns 5
|
|
346
|
-
result = await counter1.increment(amount=3) # Returns 8
|
|
347
|
-
|
|
348
|
-
# Different keys execute in parallel
|
|
349
|
-
await asyncio.gather(
|
|
350
|
-
counter1.increment(amount=1), # Parallel
|
|
351
|
-
counter2.increment(amount=1) # Parallel
|
|
160
|
+
Raises:
|
|
161
|
+
RuntimeError: If called outside of Worker context (state manager not set)
|
|
162
|
+
"""
|
|
163
|
+
manager = _entity_state_manager_ctx.get()
|
|
164
|
+
if manager is None:
|
|
165
|
+
raise RuntimeError(
|
|
166
|
+
"Entity requires state manager context.\n\n"
|
|
167
|
+
"In production:\n"
|
|
168
|
+
" Entities run automatically through Worker.\n\n"
|
|
169
|
+
"In tests, use one of:\n"
|
|
170
|
+
" Option 1 - Decorator:\n"
|
|
171
|
+
" @with_entity_context\n"
|
|
172
|
+
" async def test_cart():\n"
|
|
173
|
+
" cart = ShoppingCart('key')\n"
|
|
174
|
+
" await cart.add_item(...)\n\n"
|
|
175
|
+
" Option 2 - Fixture:\n"
|
|
176
|
+
" async def test_cart(entity_context):\n"
|
|
177
|
+
" cart = ShoppingCart('key')\n"
|
|
178
|
+
" await cart.add_item(...)\n\n"
|
|
179
|
+
"See: https://docs.agnt5.dev/sdk/entities#testing"
|
|
352
180
|
)
|
|
353
|
-
|
|
181
|
+
return manager
|
|
354
182
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
183
|
+
|
|
184
|
+
# ============================================================================
|
|
185
|
+
# Testing Helpers
|
|
186
|
+
# ============================================================================
|
|
187
|
+
|
|
188
|
+
def with_entity_context(func):
|
|
189
|
+
"""
|
|
190
|
+
Decorator that sets up entity state manager for tests.
|
|
191
|
+
|
|
192
|
+
Usage:
|
|
193
|
+
@with_entity_context
|
|
194
|
+
async def test_shopping_cart():
|
|
195
|
+
cart = ShoppingCart(key="test")
|
|
196
|
+
await cart.add_item("item", 1, 10.0)
|
|
197
|
+
assert cart.state.get("items")
|
|
359
198
|
"""
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
199
|
+
@functools.wraps(func)
|
|
200
|
+
async def wrapper(*args, **kwargs):
|
|
201
|
+
manager = EntityStateManager()
|
|
202
|
+
token = _entity_state_manager_ctx.set(manager)
|
|
203
|
+
try:
|
|
204
|
+
return await func(*args, **kwargs)
|
|
205
|
+
finally:
|
|
206
|
+
_entity_state_manager_ctx.reset(token)
|
|
207
|
+
manager.clear_all()
|
|
208
|
+
return wrapper
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def create_entity_context():
|
|
212
|
+
"""
|
|
213
|
+
Create an entity context for testing (can be used as pytest fixture).
|
|
363
214
|
|
|
215
|
+
Usage in conftest.py or test file:
|
|
216
|
+
import pytest
|
|
217
|
+
from agnt5.entity import create_entity_context
|
|
364
218
|
|
|
365
|
-
|
|
219
|
+
@pytest.fixture
|
|
220
|
+
def entity_context():
|
|
221
|
+
manager, token = create_entity_context()
|
|
222
|
+
yield manager
|
|
223
|
+
# Cleanup happens automatically
|
|
366
224
|
|
|
367
|
-
|
|
225
|
+
Returns:
|
|
226
|
+
Tuple of (EntityStateManager, context_token)
|
|
368
227
|
"""
|
|
369
|
-
|
|
228
|
+
manager = EntityStateManager()
|
|
229
|
+
token = _entity_state_manager_ctx.set(manager)
|
|
230
|
+
return manager, token
|
|
370
231
|
|
|
371
|
-
Warning: Only use for testing. This will delete all entity state.
|
|
372
|
-
"""
|
|
373
|
-
_entity_states.clear()
|
|
374
|
-
_entity_locks.clear()
|
|
375
|
-
logger.debug("Cleared all entity state and locks")
|
|
376
232
|
|
|
233
|
+
class EntityRegistry:
|
|
234
|
+
"""Registry for entity types."""
|
|
377
235
|
|
|
378
|
-
|
|
379
|
-
""
|
|
380
|
-
|
|
236
|
+
@staticmethod
|
|
237
|
+
def register(entity_type: "EntityType") -> None:
|
|
238
|
+
"""Register an entity type."""
|
|
239
|
+
if entity_type.name in _ENTITY_REGISTRY:
|
|
240
|
+
logger.warning(f"Overwriting existing entity type '{entity_type.name}'")
|
|
241
|
+
_ENTITY_REGISTRY[entity_type.name] = entity_type
|
|
242
|
+
logger.debug(f"Registered entity type '{entity_type.name}'")
|
|
381
243
|
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
244
|
+
@staticmethod
|
|
245
|
+
def get(name: str) -> Optional["EntityType"]:
|
|
246
|
+
"""Get entity type by name."""
|
|
247
|
+
return _ENTITY_REGISTRY.get(name)
|
|
385
248
|
|
|
386
|
-
|
|
387
|
-
|
|
249
|
+
@staticmethod
|
|
250
|
+
def all() -> Dict[str, "EntityType"]:
|
|
251
|
+
"""Get all registered entities."""
|
|
252
|
+
return _ENTITY_REGISTRY.copy()
|
|
388
253
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
254
|
+
@staticmethod
|
|
255
|
+
def clear() -> None:
|
|
256
|
+
"""Clear all registered entities."""
|
|
257
|
+
_ENTITY_REGISTRY.clear()
|
|
258
|
+
logger.debug("Cleared entity registry")
|
|
393
259
|
|
|
394
260
|
|
|
395
|
-
|
|
261
|
+
class EntityType:
|
|
396
262
|
"""
|
|
397
|
-
|
|
263
|
+
Metadata about an Entity class.
|
|
398
264
|
|
|
399
|
-
|
|
400
|
-
|
|
265
|
+
Stores entity name, method schemas, and metadata for Worker auto-discovery
|
|
266
|
+
and platform integration. Created automatically when Entity subclasses are defined.
|
|
267
|
+
"""
|
|
401
268
|
|
|
402
|
-
|
|
403
|
-
|
|
269
|
+
def __init__(self, name: str, entity_class: type):
|
|
270
|
+
"""
|
|
271
|
+
Initialize entity type metadata.
|
|
404
272
|
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
273
|
+
Args:
|
|
274
|
+
name: Entity type name (class name)
|
|
275
|
+
entity_class: Reference to the Entity class
|
|
276
|
+
"""
|
|
277
|
+
self.name = name
|
|
278
|
+
self.entity_class = entity_class
|
|
279
|
+
self._method_schemas: Dict[str, Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]] = {}
|
|
280
|
+
self._method_metadata: Dict[str, Dict[str, str]] = {}
|
|
281
|
+
logger.debug("Created entity type: %s", name)
|
|
411
282
|
|
|
412
283
|
|
|
413
284
|
# ============================================================================
|
|
414
|
-
#
|
|
285
|
+
# Class-Based Entity API (Cloudflare Durable Objects style)
|
|
415
286
|
# ============================================================================
|
|
416
287
|
|
|
417
|
-
class
|
|
288
|
+
class EntityState:
|
|
418
289
|
"""
|
|
419
|
-
|
|
290
|
+
Simple state interface for Entity instances.
|
|
420
291
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
292
|
+
Provides a clean API for state management:
|
|
293
|
+
self.state.get(key, default)
|
|
294
|
+
self.state.set(key, value)
|
|
295
|
+
self.state.delete(key)
|
|
296
|
+
self.state.clear()
|
|
425
297
|
|
|
426
|
-
|
|
298
|
+
State operations are synchronous and backed by an internal dict.
|
|
427
299
|
"""
|
|
428
300
|
|
|
429
|
-
def __init__(self,
|
|
430
|
-
"""
|
|
431
|
-
|
|
301
|
+
def __init__(self, state_dict: Dict[str, Any]):
|
|
302
|
+
"""
|
|
303
|
+
Initialize state wrapper with a state dict.
|
|
432
304
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
305
|
+
Args:
|
|
306
|
+
state_dict: Dictionary to use for state storage
|
|
307
|
+
"""
|
|
308
|
+
self._state = state_dict
|
|
436
309
|
|
|
437
|
-
|
|
438
|
-
"""
|
|
439
|
-
|
|
310
|
+
def get(self, key: str, default: Any = None) -> Any:
|
|
311
|
+
"""Get value from state."""
|
|
312
|
+
return self._state.get(key, default)
|
|
440
313
|
|
|
441
|
-
|
|
442
|
-
"""
|
|
443
|
-
|
|
314
|
+
def set(self, key: str, value: Any) -> None:
|
|
315
|
+
"""Set value in state."""
|
|
316
|
+
self._state[key] = value
|
|
444
317
|
|
|
445
|
-
|
|
446
|
-
"""
|
|
447
|
-
|
|
318
|
+
def delete(self, key: str) -> None:
|
|
319
|
+
"""Delete key from state."""
|
|
320
|
+
self._state.pop(key, None)
|
|
448
321
|
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
return self._context.run_id
|
|
322
|
+
def clear(self) -> None:
|
|
323
|
+
"""Clear all state."""
|
|
324
|
+
self._state.clear()
|
|
453
325
|
|
|
454
|
-
@property
|
|
455
|
-
def object_id(self) -> Optional[str]:
|
|
456
|
-
return self._context.object_id
|
|
457
326
|
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
327
|
+
def _create_entity_method_wrapper(entity_type: str, method):
|
|
328
|
+
"""
|
|
329
|
+
Create a wrapper for an entity method that provides single-writer consistency.
|
|
330
|
+
|
|
331
|
+
This wrapper:
|
|
332
|
+
1. Acquires a lock for the entity instance (single-writer guarantee)
|
|
333
|
+
2. Sets up EntityState with the state dict
|
|
334
|
+
3. Executes the method
|
|
335
|
+
4. Cleans up state reference
|
|
336
|
+
5. Handles errors appropriately
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
entity_type: Name of the entity type (class name)
|
|
340
|
+
method: The async method to wrap
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
Wrapped async method with single-writer consistency
|
|
344
|
+
"""
|
|
345
|
+
@functools.wraps(method)
|
|
346
|
+
async def entity_method_wrapper(self, *args, **kwargs):
|
|
347
|
+
"""Execute entity method with single-writer guarantee."""
|
|
348
|
+
state_key = (entity_type, self._key)
|
|
349
|
+
|
|
350
|
+
# Get state manager and lock (single-writer guarantee)
|
|
351
|
+
state_manager = _get_state_manager()
|
|
352
|
+
lock = state_manager.get_or_create_lock(state_key)
|
|
353
|
+
|
|
354
|
+
async with lock:
|
|
355
|
+
# TODO: Load state from platform if not in memory
|
|
356
|
+
# if state_key not in state_manager._states and state_manager._rust_manager:
|
|
357
|
+
# result = await state_manager._rust_manager.load_state(
|
|
358
|
+
# tenant_id, entity_type, self._key
|
|
359
|
+
# )
|
|
360
|
+
# if result.found:
|
|
361
|
+
# state_manager.load_state_from_platform(
|
|
362
|
+
# state_key, result.state_json, result.version
|
|
363
|
+
# )
|
|
364
|
+
|
|
365
|
+
# Get or create state for this entity instance
|
|
366
|
+
state_dict = state_manager.get_or_create_state(state_key)
|
|
367
|
+
|
|
368
|
+
# Set up EntityState on instance for method access
|
|
369
|
+
self._state = EntityState(state_dict)
|
|
370
|
+
|
|
371
|
+
try:
|
|
372
|
+
# Execute method
|
|
373
|
+
logger.debug("Executing %s:%s.%s", entity_type, self._key, method.__name__)
|
|
374
|
+
result = await method(self, *args, **kwargs)
|
|
375
|
+
logger.debug("Completed %s:%s.%s", entity_type, self._key, method.__name__)
|
|
376
|
+
|
|
377
|
+
# TODO: Save state to platform after successful execution
|
|
378
|
+
# if state_manager._rust_manager:
|
|
379
|
+
# state_dict, expected_version, new_version = \
|
|
380
|
+
# state_manager.get_state_for_persistence(state_key)
|
|
381
|
+
# import json
|
|
382
|
+
# state_json = json.dumps(state_dict).encode('utf-8')
|
|
383
|
+
# save_result = await state_manager._rust_manager.save_state(
|
|
384
|
+
# tenant_id, entity_type, self._key, state_json, expected_version
|
|
385
|
+
# )
|
|
386
|
+
# state_manager._versions[state_key] = save_result.new_version
|
|
387
|
+
|
|
388
|
+
return result
|
|
389
|
+
|
|
390
|
+
except Exception as e:
|
|
391
|
+
logger.error(
|
|
392
|
+
"Error in %s:%s.%s: %s",
|
|
393
|
+
entity_type, self._key, method.__name__, e,
|
|
394
|
+
exc_info=True
|
|
395
|
+
)
|
|
396
|
+
raise ExecutionError(
|
|
397
|
+
f"Entity method {method.__name__} failed: {e}"
|
|
398
|
+
) from e
|
|
399
|
+
finally:
|
|
400
|
+
# Clear state reference after execution
|
|
401
|
+
self._state = None
|
|
402
|
+
|
|
403
|
+
return entity_method_wrapper
|
|
461
404
|
|
|
462
405
|
|
|
463
|
-
class
|
|
406
|
+
class Entity:
|
|
464
407
|
"""
|
|
465
|
-
Base class for
|
|
408
|
+
Base class for stateful entities with single-writer consistency.
|
|
466
409
|
|
|
467
|
-
|
|
468
|
-
- State is accessed via self.
|
|
410
|
+
Entities provide a class-based API where:
|
|
411
|
+
- State is accessed via self.state (clean, synchronous API)
|
|
469
412
|
- Methods are regular async methods on the class
|
|
470
413
|
- Each instance is bound to a unique key
|
|
471
|
-
- Single-writer consistency per key is guaranteed
|
|
414
|
+
- Single-writer consistency per key is guaranteed automatically
|
|
472
415
|
|
|
473
416
|
Example:
|
|
474
417
|
```python
|
|
475
|
-
from agnt5 import
|
|
418
|
+
from agnt5 import Entity
|
|
476
419
|
|
|
477
|
-
class ShoppingCart(
|
|
420
|
+
class ShoppingCart(Entity):
|
|
478
421
|
async def add_item(self, item_id: str, quantity: int, price: float) -> dict:
|
|
479
|
-
items =
|
|
422
|
+
items = self.state.get("items", {})
|
|
480
423
|
items[item_id] = {"quantity": quantity, "price": price}
|
|
481
|
-
|
|
424
|
+
self.state.set("items", items)
|
|
482
425
|
return {"total_items": len(items)}
|
|
483
426
|
|
|
484
427
|
async def get_total(self) -> float:
|
|
485
|
-
items =
|
|
428
|
+
items = self.state.get("items", {})
|
|
486
429
|
return sum(item["quantity"] * item["price"] for item in items.values())
|
|
487
430
|
|
|
488
431
|
# Usage
|
|
@@ -493,12 +436,12 @@ class DurableEntity:
|
|
|
493
436
|
|
|
494
437
|
Note:
|
|
495
438
|
Methods are automatically wrapped to provide single-writer consistency per key.
|
|
496
|
-
State operations
|
|
439
|
+
State operations are synchronous for simplicity.
|
|
497
440
|
"""
|
|
498
441
|
|
|
499
442
|
def __init__(self, key: str):
|
|
500
443
|
"""
|
|
501
|
-
Initialize
|
|
444
|
+
Initialize an entity instance.
|
|
502
445
|
|
|
503
446
|
Args:
|
|
504
447
|
key: Unique identifier for this entity instance
|
|
@@ -507,35 +450,47 @@ class DurableEntity:
|
|
|
507
450
|
self._entity_type = self.__class__.__name__
|
|
508
451
|
self._state_key = (self._entity_type, key)
|
|
509
452
|
|
|
510
|
-
#
|
|
511
|
-
self.
|
|
453
|
+
# State will be initialized during method execution by wrapper
|
|
454
|
+
self._state = None
|
|
512
455
|
|
|
513
|
-
logger.debug(
|
|
456
|
+
logger.debug("Created Entity instance: %s:%s", self._entity_type, key)
|
|
514
457
|
|
|
515
458
|
@property
|
|
516
|
-
def
|
|
459
|
+
def state(self) -> EntityState:
|
|
517
460
|
"""
|
|
518
|
-
Get the
|
|
461
|
+
Get the state interface for this entity.
|
|
519
462
|
|
|
520
|
-
Available
|
|
521
|
-
-
|
|
522
|
-
-
|
|
523
|
-
-
|
|
524
|
-
-
|
|
463
|
+
Available operations:
|
|
464
|
+
- self.state.get(key, default)
|
|
465
|
+
- self.state.set(key, value)
|
|
466
|
+
- self.state.delete(key)
|
|
467
|
+
- self.state.clear()
|
|
525
468
|
|
|
526
469
|
Returns:
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
470
|
+
EntityState for synchronous state operations
|
|
471
|
+
|
|
472
|
+
Raises:
|
|
473
|
+
RuntimeError: If accessed outside of an entity method
|
|
474
|
+
"""
|
|
475
|
+
if self._state is None:
|
|
476
|
+
raise RuntimeError(
|
|
477
|
+
f"Entity state can only be accessed within entity methods.\n\n"
|
|
478
|
+
f"You tried to access state on {self._entity_type}(key='{self._key}') "
|
|
479
|
+
f"outside of a method call.\n\n"
|
|
480
|
+
f"❌ Wrong:\n"
|
|
481
|
+
f" cart = ShoppingCart(key='user-123')\n"
|
|
482
|
+
f" items = cart.state.get('items') # Error!\n\n"
|
|
483
|
+
f"✅ Correct:\n"
|
|
484
|
+
f" class ShoppingCart(Entity):\n"
|
|
485
|
+
f" async def get_items(self):\n"
|
|
486
|
+
f" return self.state.get('items', {{}}) # Works!\n\n"
|
|
487
|
+
f" cart = ShoppingCart(key='user-123')\n"
|
|
488
|
+
f" items = await cart.get_items() # Call method instead"
|
|
536
489
|
)
|
|
537
|
-
|
|
538
|
-
|
|
490
|
+
|
|
491
|
+
# Type narrowing: after the raise, self._state is guaranteed to be not None
|
|
492
|
+
assert self._state is not None
|
|
493
|
+
return self._state
|
|
539
494
|
|
|
540
495
|
@property
|
|
541
496
|
def key(self) -> str:
|
|
@@ -547,791 +502,44 @@ class DurableEntity:
|
|
|
547
502
|
"""Get the entity type name."""
|
|
548
503
|
return self._entity_type
|
|
549
504
|
|
|
550
|
-
def __getattribute__(self, name: str):
|
|
551
|
-
"""
|
|
552
|
-
Intercept method calls to add single-writer consistency.
|
|
553
|
-
|
|
554
|
-
This wraps all async methods (except private/magic methods) with:
|
|
555
|
-
1. Lock acquisition (single-writer per key)
|
|
556
|
-
2. Context setup with entity state
|
|
557
|
-
3. Method execution
|
|
558
|
-
4. State persistence
|
|
559
|
-
"""
|
|
560
|
-
attr = object.__getattribute__(self, name)
|
|
561
|
-
|
|
562
|
-
# Don't wrap private methods, properties, non-callables, or specific attributes
|
|
563
|
-
if (name.startswith('_') or
|
|
564
|
-
not callable(attr) or
|
|
565
|
-
not asyncio.iscoroutinefunction(attr) or
|
|
566
|
-
name in ('ctx', 'key', 'entity_type')): # Skip properties
|
|
567
|
-
return attr
|
|
568
|
-
|
|
569
|
-
# Don't wrap if already wrapped
|
|
570
|
-
if hasattr(attr, '_entity_wrapped'):
|
|
571
|
-
return attr
|
|
572
|
-
|
|
573
|
-
@functools.wraps(attr)
|
|
574
|
-
async def entity_method_wrapper(*args, **kwargs):
|
|
575
|
-
"""
|
|
576
|
-
Execute entity method with single-writer guarantee.
|
|
577
|
-
|
|
578
|
-
This wrapper:
|
|
579
|
-
1. Acquires lock for this entity instance (single-writer)
|
|
580
|
-
2. Creates Context with entity state
|
|
581
|
-
3. Executes method
|
|
582
|
-
4. Updates state from Context
|
|
583
|
-
"""
|
|
584
|
-
state_key = object.__getattribute__(self, '_state_key')
|
|
585
|
-
entity_type = object.__getattribute__(self, '_entity_type')
|
|
586
|
-
key = object.__getattribute__(self, '_key')
|
|
587
|
-
|
|
588
|
-
# Get or create lock for this entity instance (single-writer guarantee)
|
|
589
|
-
if state_key not in _entity_locks:
|
|
590
|
-
_entity_locks[state_key] = asyncio.Lock()
|
|
591
|
-
lock = _entity_locks[state_key]
|
|
592
|
-
|
|
593
|
-
async with lock:
|
|
594
|
-
# Get or create state for this entity instance
|
|
595
|
-
if state_key not in _entity_states:
|
|
596
|
-
_entity_states[state_key] = {}
|
|
597
|
-
state_dict = _entity_states[state_key]
|
|
598
|
-
|
|
599
|
-
# Create Context with entity state
|
|
600
|
-
ctx = Context(
|
|
601
|
-
run_id=f"{entity_type}:{key}:{name}",
|
|
602
|
-
component_type="entity",
|
|
603
|
-
object_id=key,
|
|
604
|
-
method_name=name
|
|
605
|
-
)
|
|
606
|
-
|
|
607
|
-
# Replace Context's internal state with entity state
|
|
608
|
-
ctx._state = state_dict
|
|
609
|
-
|
|
610
|
-
# Set context on instance for method access
|
|
611
|
-
object.__setattr__(self, '_ctx', ctx)
|
|
612
|
-
|
|
613
|
-
try:
|
|
614
|
-
# Execute method
|
|
615
|
-
logger.debug(f"Executing {entity_type}:{key}.{name}")
|
|
616
|
-
result = await attr(*args, **kwargs)
|
|
617
|
-
logger.debug(f"Completed {entity_type}:{key}.{name}")
|
|
618
|
-
return result
|
|
619
|
-
|
|
620
|
-
except Exception as e:
|
|
621
|
-
logger.error(
|
|
622
|
-
f"Error in {entity_type}:{key}.{name}: {e}",
|
|
623
|
-
exc_info=True
|
|
624
|
-
)
|
|
625
|
-
raise ExecutionError(
|
|
626
|
-
f"Entity method {name} failed: {e}"
|
|
627
|
-
) from e
|
|
628
|
-
finally:
|
|
629
|
-
# Clear context after execution
|
|
630
|
-
object.__setattr__(self, '_ctx', None)
|
|
631
|
-
|
|
632
|
-
# Mark as wrapped to avoid double-wrapping
|
|
633
|
-
entity_method_wrapper._entity_wrapped = True
|
|
634
|
-
return entity_method_wrapper
|
|
635
|
-
|
|
636
|
-
|
|
637
505
|
def __init_subclass__(cls, **kwargs):
|
|
638
506
|
"""
|
|
639
|
-
Auto-register
|
|
507
|
+
Auto-register Entity subclasses and wrap methods.
|
|
640
508
|
|
|
641
|
-
This is called automatically when a class inherits from
|
|
509
|
+
This is called automatically when a class inherits from Entity.
|
|
510
|
+
It performs two tasks:
|
|
511
|
+
1. Wraps all public async methods with single-writer consistency
|
|
512
|
+
2. Registers the entity type with metadata for platform discovery
|
|
642
513
|
"""
|
|
643
514
|
super().__init_subclass__(**kwargs)
|
|
644
515
|
|
|
645
|
-
# Don't register the base
|
|
646
|
-
if cls.__name__ == '
|
|
516
|
+
# Don't register the base Entity class itself
|
|
517
|
+
if cls.__name__ == 'Entity':
|
|
647
518
|
return
|
|
648
519
|
|
|
649
520
|
# Don't register SDK's built-in base classes (these are meant to be extended by users)
|
|
650
|
-
if cls.__name__ in ('SessionEntity', 'MemoryEntity'
|
|
521
|
+
if cls.__name__ in ('SessionEntity', 'MemoryEntity'):
|
|
651
522
|
return
|
|
652
523
|
|
|
653
524
|
# Create an EntityType for this class, storing the class reference
|
|
654
525
|
entity_type = EntityType(cls.__name__, entity_class=cls)
|
|
655
526
|
|
|
656
|
-
#
|
|
527
|
+
# Wrap all public async methods and register them
|
|
657
528
|
for name, method in inspect.getmembers(cls, predicate=inspect.iscoroutinefunction):
|
|
658
529
|
if not name.startswith('_'):
|
|
659
530
|
# Extract schemas from the method
|
|
660
|
-
input_schema, output_schema =
|
|
661
|
-
method_metadata =
|
|
531
|
+
input_schema, output_schema = extract_function_schemas(method)
|
|
532
|
+
method_metadata = extract_function_metadata(method)
|
|
662
533
|
|
|
663
534
|
# Store in entity type
|
|
664
535
|
entity_type._method_schemas[name] = (input_schema, output_schema)
|
|
665
536
|
entity_type._method_metadata[name] = method_metadata
|
|
666
537
|
|
|
667
|
-
#
|
|
668
|
-
#
|
|
538
|
+
# Wrap the method with single-writer consistency
|
|
539
|
+
# This happens once at class definition time (not per-call)
|
|
540
|
+
wrapped_method = _create_entity_method_wrapper(cls.__name__, method)
|
|
541
|
+
setattr(cls, name, wrapped_method)
|
|
669
542
|
|
|
670
543
|
# Register the entity type
|
|
671
544
|
EntityRegistry.register(entity_type)
|
|
672
|
-
logger.debug(f"Auto-registered
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
class SessionEntity(DurableEntity):
|
|
676
|
-
"""
|
|
677
|
-
Session-based entity with built-in conversation history management.
|
|
678
|
-
|
|
679
|
-
Inspired by Google ADK and OpenAI Agents SDK session patterns.
|
|
680
|
-
Automatically manages message history with trimming and optional summarization.
|
|
681
|
-
|
|
682
|
-
Configuration (class variables):
|
|
683
|
-
max_turns: Maximum conversation turns to keep (default: 20)
|
|
684
|
-
auto_summarize: Enable automatic summarization of old messages (default: False)
|
|
685
|
-
history_key: State key for storing history (default: "_history")
|
|
686
|
-
summary_key: State key for storing summary (default: "_summary")
|
|
687
|
-
|
|
688
|
-
Built-in Methods:
|
|
689
|
-
- add_message(role, content, **metadata): Add message to history
|
|
690
|
-
- get_history(limit=None): Get conversation history
|
|
691
|
-
- clear_history(): Clear all history
|
|
692
|
-
- get_summary(): Get conversation summary (if auto_summarize enabled)
|
|
693
|
-
|
|
694
|
-
Example:
|
|
695
|
-
```python
|
|
696
|
-
from agnt5 import SessionEntity
|
|
697
|
-
|
|
698
|
-
class Conversation(SessionEntity):
|
|
699
|
-
max_turns: int = 20
|
|
700
|
-
auto_summarize: bool = True
|
|
701
|
-
|
|
702
|
-
async def chat(self, message: str) -> str:
|
|
703
|
-
# Add user message (automatic)
|
|
704
|
-
await self.add_message("user", message)
|
|
705
|
-
|
|
706
|
-
# Get history (auto-trimmed)
|
|
707
|
-
history = await self.get_history(limit=10)
|
|
708
|
-
|
|
709
|
-
# Generate AI response
|
|
710
|
-
response = await some_ai_call(history)
|
|
711
|
-
|
|
712
|
-
# Add AI response (automatic)
|
|
713
|
-
await self.add_message("assistant", response)
|
|
714
|
-
|
|
715
|
-
return response
|
|
716
|
-
|
|
717
|
-
# Usage
|
|
718
|
-
conv = Conversation(key="user-123")
|
|
719
|
-
response = await conv.chat("Hello!") # History managed automatically
|
|
720
|
-
```
|
|
721
|
-
"""
|
|
722
|
-
|
|
723
|
-
# Configuration (can be overridden in subclasses)
|
|
724
|
-
max_turns: int = 20
|
|
725
|
-
auto_summarize: bool = False
|
|
726
|
-
history_key: str = "_history"
|
|
727
|
-
summary_key: str = "_summary"
|
|
728
|
-
|
|
729
|
-
async def add_message(
|
|
730
|
-
self,
|
|
731
|
-
role: str,
|
|
732
|
-
content: str,
|
|
733
|
-
**metadata
|
|
734
|
-
) -> dict:
|
|
735
|
-
"""
|
|
736
|
-
Add a message to the conversation history.
|
|
737
|
-
|
|
738
|
-
Args:
|
|
739
|
-
role: Message role (e.g., "user", "assistant", "system")
|
|
740
|
-
content: Message content
|
|
741
|
-
**metadata: Additional metadata (name, timestamp, etc.)
|
|
742
|
-
|
|
743
|
-
Returns:
|
|
744
|
-
dict with message info and current history length
|
|
745
|
-
"""
|
|
746
|
-
import time
|
|
747
|
-
|
|
748
|
-
# Get current history
|
|
749
|
-
history = await self.ctx.get(self.history_key, [])
|
|
750
|
-
|
|
751
|
-
# Create message
|
|
752
|
-
message = {
|
|
753
|
-
"role": role,
|
|
754
|
-
"content": content,
|
|
755
|
-
"timestamp": metadata.get("timestamp", time.time()),
|
|
756
|
-
**metadata
|
|
757
|
-
}
|
|
758
|
-
|
|
759
|
-
# Add to history
|
|
760
|
-
history.append(message)
|
|
761
|
-
|
|
762
|
-
# Trim if needed
|
|
763
|
-
if len(history) > self.max_turns * 2: # 2 messages per turn (user + assistant)
|
|
764
|
-
if self.auto_summarize:
|
|
765
|
-
# Summarize old messages before trimming
|
|
766
|
-
await self._summarize_and_trim(history)
|
|
767
|
-
else:
|
|
768
|
-
# Just trim
|
|
769
|
-
history = history[-(self.max_turns * 2):]
|
|
770
|
-
|
|
771
|
-
# Save history
|
|
772
|
-
await self.ctx.set(self.history_key, history)
|
|
773
|
-
|
|
774
|
-
return {
|
|
775
|
-
"role": role,
|
|
776
|
-
"added": True,
|
|
777
|
-
"history_length": len(history),
|
|
778
|
-
"timestamp": message["timestamp"]
|
|
779
|
-
}
|
|
780
|
-
|
|
781
|
-
async def get_history(self, limit: Optional[int] = None) -> list:
|
|
782
|
-
"""
|
|
783
|
-
Get conversation history.
|
|
784
|
-
|
|
785
|
-
Args:
|
|
786
|
-
limit: Maximum number of messages to return (None = all)
|
|
787
|
-
|
|
788
|
-
Returns:
|
|
789
|
-
List of message dicts
|
|
790
|
-
"""
|
|
791
|
-
history = await self.ctx.get(self.history_key, [])
|
|
792
|
-
|
|
793
|
-
if limit is not None:
|
|
794
|
-
return history[-limit:]
|
|
795
|
-
|
|
796
|
-
return history
|
|
797
|
-
|
|
798
|
-
async def clear_history(self) -> dict:
|
|
799
|
-
"""
|
|
800
|
-
Clear all conversation history.
|
|
801
|
-
|
|
802
|
-
Returns:
|
|
803
|
-
dict with status and cleared count
|
|
804
|
-
"""
|
|
805
|
-
history = await self.ctx.get(self.history_key, [])
|
|
806
|
-
count = len(history)
|
|
807
|
-
|
|
808
|
-
await self.ctx.delete(self.history_key)
|
|
809
|
-
|
|
810
|
-
if self.auto_summarize:
|
|
811
|
-
await self.ctx.delete(self.summary_key)
|
|
812
|
-
|
|
813
|
-
return {
|
|
814
|
-
"cleared": True,
|
|
815
|
-
"message_count": count
|
|
816
|
-
}
|
|
817
|
-
|
|
818
|
-
async def get_summary(self) -> Optional[str]:
|
|
819
|
-
"""
|
|
820
|
-
Get conversation summary (if auto_summarize is enabled).
|
|
821
|
-
|
|
822
|
-
Returns:
|
|
823
|
-
Summary string or None if no summary exists
|
|
824
|
-
"""
|
|
825
|
-
if not self.auto_summarize:
|
|
826
|
-
return None
|
|
827
|
-
|
|
828
|
-
return await self.ctx.get(self.summary_key)
|
|
829
|
-
|
|
830
|
-
async def _summarize_and_trim(self, history: list) -> None:
|
|
831
|
-
"""
|
|
832
|
-
Summarize old messages and trim history.
|
|
833
|
-
|
|
834
|
-
This is a placeholder for future AI-powered summarization.
|
|
835
|
-
For now, it just stores a simple summary and trims.
|
|
836
|
-
|
|
837
|
-
Args:
|
|
838
|
-
history: Current message history
|
|
839
|
-
"""
|
|
840
|
-
# Messages to summarize (oldest half)
|
|
841
|
-
to_summarize = history[:len(history) // 2]
|
|
842
|
-
|
|
843
|
-
# Simple summary (in future, use AI to generate better summary)
|
|
844
|
-
summary_text = f"Conversation summary: {len(to_summarize)} messages exchanged"
|
|
845
|
-
|
|
846
|
-
# Get existing summary
|
|
847
|
-
existing_summary = await self.ctx.get(self.summary_key)
|
|
848
|
-
if existing_summary:
|
|
849
|
-
summary_text = f"{existing_summary}\n{summary_text}"
|
|
850
|
-
|
|
851
|
-
# Store summary
|
|
852
|
-
await self.ctx.set(self.summary_key, summary_text)
|
|
853
|
-
|
|
854
|
-
# Trim history (keep most recent messages)
|
|
855
|
-
trimmed_history = history[len(history) // 2:]
|
|
856
|
-
await self.ctx.set(self.history_key, trimmed_history)
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
class MemoryEntity(DurableEntity):
|
|
860
|
-
"""
|
|
861
|
-
Memory entity for cross-session knowledge storage and retrieval.
|
|
862
|
-
|
|
863
|
-
Provides semantic memory storage with search capabilities.
|
|
864
|
-
In Phase 1: Simple keyword-based search (in-memory)
|
|
865
|
-
Future: Vector embeddings with semantic search (Pinecone, Weaviate, etc.)
|
|
866
|
-
|
|
867
|
-
Configuration (class variables):
|
|
868
|
-
memory_key: State key for storing memories (default: "_memories")
|
|
869
|
-
max_memories: Maximum memories to keep (default: 100)
|
|
870
|
-
|
|
871
|
-
Built-in Methods:
|
|
872
|
-
- store(key, content, **metadata): Store a memory
|
|
873
|
-
- recall(query, limit=5): Search memories
|
|
874
|
-
- forget(key): Delete a memory
|
|
875
|
-
- list_memories(): List all stored memories
|
|
876
|
-
|
|
877
|
-
Example:
|
|
878
|
-
```python
|
|
879
|
-
from agnt5 import MemoryEntity
|
|
880
|
-
|
|
881
|
-
class AgentMemory(MemoryEntity):
|
|
882
|
-
max_memories: int = 50
|
|
883
|
-
|
|
884
|
-
async def remember_fact(self, fact: str, category: str) -> dict:
|
|
885
|
-
# Store with metadata
|
|
886
|
-
return await self.store(
|
|
887
|
-
key=f"fact_{len(await self.list_memories())}",
|
|
888
|
-
content=fact,
|
|
889
|
-
category=category
|
|
890
|
-
)
|
|
891
|
-
|
|
892
|
-
async def find_facts(self, query: str) -> list:
|
|
893
|
-
# Search memories
|
|
894
|
-
results = await self.recall(query, limit=5)
|
|
895
|
-
return [r["content"] for r in results]
|
|
896
|
-
|
|
897
|
-
# Usage
|
|
898
|
-
memory = AgentMemory(key="agent-123")
|
|
899
|
-
await memory.remember_fact("Paris is the capital of France", category="geography")
|
|
900
|
-
results = await memory.find_facts("capital France")
|
|
901
|
-
```
|
|
902
|
-
"""
|
|
903
|
-
|
|
904
|
-
# Configuration
|
|
905
|
-
memory_key: str = "_memories"
|
|
906
|
-
max_memories: int = 100
|
|
907
|
-
|
|
908
|
-
async def store(
|
|
909
|
-
self,
|
|
910
|
-
key: str,
|
|
911
|
-
content: str,
|
|
912
|
-
**metadata
|
|
913
|
-
) -> dict:
|
|
914
|
-
"""
|
|
915
|
-
Store a memory with optional metadata.
|
|
916
|
-
|
|
917
|
-
Args:
|
|
918
|
-
key: Unique identifier for this memory
|
|
919
|
-
content: The memory content to store
|
|
920
|
-
**metadata: Additional metadata (tags, category, timestamp, etc.)
|
|
921
|
-
|
|
922
|
-
Returns:
|
|
923
|
-
dict with storage confirmation
|
|
924
|
-
"""
|
|
925
|
-
import time
|
|
926
|
-
|
|
927
|
-
# Get current memories
|
|
928
|
-
memories = await self.ctx.get(self.memory_key, {})
|
|
929
|
-
|
|
930
|
-
# Create memory entry
|
|
931
|
-
memory = {
|
|
932
|
-
"content": content,
|
|
933
|
-
"timestamp": metadata.get("timestamp", time.time()),
|
|
934
|
-
**metadata
|
|
935
|
-
}
|
|
936
|
-
|
|
937
|
-
# Store memory
|
|
938
|
-
memories[key] = memory
|
|
939
|
-
|
|
940
|
-
# Trim if needed
|
|
941
|
-
if len(memories) > self.max_memories:
|
|
942
|
-
# Remove oldest memories
|
|
943
|
-
sorted_keys = sorted(
|
|
944
|
-
memories.keys(),
|
|
945
|
-
key=lambda k: memories[k].get("timestamp", 0)
|
|
946
|
-
)
|
|
947
|
-
for old_key in sorted_keys[:len(memories) - self.max_memories]:
|
|
948
|
-
del memories[old_key]
|
|
949
|
-
|
|
950
|
-
# Save memories
|
|
951
|
-
await self.ctx.set(self.memory_key, memories)
|
|
952
|
-
|
|
953
|
-
return {
|
|
954
|
-
"stored": True,
|
|
955
|
-
"key": key,
|
|
956
|
-
"total_memories": len(memories)
|
|
957
|
-
}
|
|
958
|
-
|
|
959
|
-
async def recall(
|
|
960
|
-
self,
|
|
961
|
-
query: str,
|
|
962
|
-
limit: int = 5
|
|
963
|
-
) -> list:
|
|
964
|
-
"""
|
|
965
|
-
Search memories using keyword matching.
|
|
966
|
-
|
|
967
|
-
Phase 1: Simple keyword search
|
|
968
|
-
Future: Semantic search with embeddings
|
|
969
|
-
|
|
970
|
-
Args:
|
|
971
|
-
query: Search query
|
|
972
|
-
limit: Maximum results to return
|
|
973
|
-
|
|
974
|
-
Returns:
|
|
975
|
-
List of matching memories (sorted by relevance)
|
|
976
|
-
"""
|
|
977
|
-
memories = await self.ctx.get(self.memory_key, {})
|
|
978
|
-
|
|
979
|
-
if not memories:
|
|
980
|
-
return []
|
|
981
|
-
|
|
982
|
-
# Simple keyword matching (future: use embeddings)
|
|
983
|
-
query_lower = query.lower()
|
|
984
|
-
matches = []
|
|
985
|
-
|
|
986
|
-
for key, memory in memories.items():
|
|
987
|
-
content = memory.get("content", "").lower()
|
|
988
|
-
|
|
989
|
-
# Calculate simple relevance score (number of matching words)
|
|
990
|
-
query_words = set(query_lower.split())
|
|
991
|
-
content_words = set(content.split())
|
|
992
|
-
matching_words = query_words & content_words
|
|
993
|
-
score = len(matching_words)
|
|
994
|
-
|
|
995
|
-
if score > 0 or query_lower in content:
|
|
996
|
-
matches.append({
|
|
997
|
-
"key": key,
|
|
998
|
-
"content": memory["content"],
|
|
999
|
-
"score": score if score > 0 else 0.5, # Substring match gets 0.5
|
|
1000
|
-
"timestamp": memory.get("timestamp"),
|
|
1001
|
-
**{k: v for k, v in memory.items() if k not in ("content", "timestamp")}
|
|
1002
|
-
})
|
|
1003
|
-
|
|
1004
|
-
# Sort by score (descending)
|
|
1005
|
-
matches.sort(key=lambda x: x["score"], reverse=True)
|
|
1006
|
-
|
|
1007
|
-
return matches[:limit]
|
|
1008
|
-
|
|
1009
|
-
async def forget(self, key: str) -> dict:
|
|
1010
|
-
"""
|
|
1011
|
-
Delete a memory.
|
|
1012
|
-
|
|
1013
|
-
Args:
|
|
1014
|
-
key: Memory key to delete
|
|
1015
|
-
|
|
1016
|
-
Returns:
|
|
1017
|
-
dict with deletion status
|
|
1018
|
-
"""
|
|
1019
|
-
memories = await self.ctx.get(self.memory_key, {})
|
|
1020
|
-
|
|
1021
|
-
if key in memories:
|
|
1022
|
-
del memories[key]
|
|
1023
|
-
await self.ctx.set(self.memory_key, memories)
|
|
1024
|
-
return {"deleted": True, "key": key}
|
|
1025
|
-
|
|
1026
|
-
return {"deleted": False, "key": key, "reason": "not_found"}
|
|
1027
|
-
|
|
1028
|
-
async def list_memories(self) -> list:
|
|
1029
|
-
"""
|
|
1030
|
-
List all stored memories.
|
|
1031
|
-
|
|
1032
|
-
Returns:
|
|
1033
|
-
List of all memories with keys
|
|
1034
|
-
"""
|
|
1035
|
-
memories = await self.ctx.get(self.memory_key, {})
|
|
1036
|
-
|
|
1037
|
-
return [
|
|
1038
|
-
{"key": k, **v}
|
|
1039
|
-
for k, v in memories.items()
|
|
1040
|
-
]
|
|
1041
|
-
|
|
1042
|
-
async def clear_all_memories(self) -> dict:
|
|
1043
|
-
"""
|
|
1044
|
-
Clear all memories.
|
|
1045
|
-
|
|
1046
|
-
Returns:
|
|
1047
|
-
dict with status and count
|
|
1048
|
-
"""
|
|
1049
|
-
memories = await self.ctx.get(self.memory_key, {})
|
|
1050
|
-
count = len(memories)
|
|
1051
|
-
|
|
1052
|
-
await self.ctx.delete(self.memory_key)
|
|
1053
|
-
|
|
1054
|
-
return {
|
|
1055
|
-
"cleared": True,
|
|
1056
|
-
"memory_count": count
|
|
1057
|
-
}
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
class WorkflowEntity(DurableEntity):
|
|
1061
|
-
"""
|
|
1062
|
-
Workflow entity for durable multi-step processes.
|
|
1063
|
-
|
|
1064
|
-
Provides orchestration for complex workflows with step tracking,
|
|
1065
|
-
compensation logic, and automatic state persistence.
|
|
1066
|
-
|
|
1067
|
-
Similar to Temporal/Azure Durable Functions patterns.
|
|
1068
|
-
|
|
1069
|
-
Configuration (class variables):
|
|
1070
|
-
workflow_key: State key for workflow state (default: "_workflow")
|
|
1071
|
-
max_retries: Maximum retries per step (default: 3)
|
|
1072
|
-
|
|
1073
|
-
Built-in Methods:
|
|
1074
|
-
- get_status(): Get workflow execution status
|
|
1075
|
-
- mark_step_complete(step_name, result): Mark step as complete
|
|
1076
|
-
- mark_step_failed(step_name, error): Mark step as failed
|
|
1077
|
-
- rollback(to_step): Rollback to specific step
|
|
1078
|
-
- can_retry(step_name): Check if step can be retried
|
|
1079
|
-
|
|
1080
|
-
Example:
|
|
1081
|
-
```python
|
|
1082
|
-
from agnt5 import WorkflowEntity
|
|
1083
|
-
|
|
1084
|
-
class OrderWorkflow(WorkflowEntity):
|
|
1085
|
-
async def process_order(self, order_id: str, items: list) -> dict:
|
|
1086
|
-
# Step 1: Validate order
|
|
1087
|
-
status = await self.get_status()
|
|
1088
|
-
if "validate" not in status["completed_steps"]:
|
|
1089
|
-
try:
|
|
1090
|
-
validation = await self._validate_order(order_id, items)
|
|
1091
|
-
await self.mark_step_complete("validate", validation)
|
|
1092
|
-
except Exception as e:
|
|
1093
|
-
await self.mark_step_failed("validate", str(e))
|
|
1094
|
-
raise
|
|
1095
|
-
|
|
1096
|
-
# Step 2: Charge payment
|
|
1097
|
-
if "payment" not in status["completed_steps"]:
|
|
1098
|
-
try:
|
|
1099
|
-
charge = await self._charge_payment(order_id)
|
|
1100
|
-
await self.mark_step_complete("payment", charge)
|
|
1101
|
-
except Exception as e:
|
|
1102
|
-
await self.mark_step_failed("payment", str(e))
|
|
1103
|
-
# Rollback validation
|
|
1104
|
-
await self.rollback("validate")
|
|
1105
|
-
raise
|
|
1106
|
-
|
|
1107
|
-
# Step 3: Ship order
|
|
1108
|
-
if "shipping" not in status["completed_steps"]:
|
|
1109
|
-
try:
|
|
1110
|
-
shipment = await self._ship_order(order_id)
|
|
1111
|
-
await self.mark_step_complete("shipping", shipment)
|
|
1112
|
-
except Exception as e:
|
|
1113
|
-
await self.mark_step_failed("shipping", str(e))
|
|
1114
|
-
raise
|
|
1115
|
-
|
|
1116
|
-
return await self.get_status()
|
|
1117
|
-
|
|
1118
|
-
async def _validate_order(self, order_id: str, items: list) -> dict:
|
|
1119
|
-
# Validation logic
|
|
1120
|
-
return {"valid": True, "order_id": order_id}
|
|
1121
|
-
|
|
1122
|
-
async def _charge_payment(self, order_id: str) -> dict:
|
|
1123
|
-
# Payment logic
|
|
1124
|
-
return {"charged": True, "amount": 100}
|
|
1125
|
-
|
|
1126
|
-
async def _ship_order(self, order_id: str) -> dict:
|
|
1127
|
-
# Shipping logic
|
|
1128
|
-
return {"shipped": True, "tracking": "TRACK123"}
|
|
1129
|
-
|
|
1130
|
-
# Usage
|
|
1131
|
-
workflow = OrderWorkflow(key="order-456")
|
|
1132
|
-
result = await workflow.process_order("order-456", [{"sku": "ABC"}])
|
|
1133
|
-
status = await workflow.get_status()
|
|
1134
|
-
```
|
|
1135
|
-
"""
|
|
1136
|
-
|
|
1137
|
-
# Configuration
|
|
1138
|
-
workflow_key: str = "_workflow"
|
|
1139
|
-
max_retries: int = 3
|
|
1140
|
-
|
|
1141
|
-
async def get_status(self) -> dict:
|
|
1142
|
-
"""
|
|
1143
|
-
Get current workflow execution status.
|
|
1144
|
-
|
|
1145
|
-
Returns:
|
|
1146
|
-
dict: {
|
|
1147
|
-
"current_step": str or None,
|
|
1148
|
-
"completed_steps": list of step names,
|
|
1149
|
-
"failed_steps": dict mapping step names to error info,
|
|
1150
|
-
"started_at": timestamp or None,
|
|
1151
|
-
"completed_at": timestamp or None
|
|
1152
|
-
}
|
|
1153
|
-
"""
|
|
1154
|
-
workflow_state = await self.ctx.get(
|
|
1155
|
-
self.workflow_key,
|
|
1156
|
-
{
|
|
1157
|
-
"current_step": None,
|
|
1158
|
-
"completed_steps": [],
|
|
1159
|
-
"failed_steps": {},
|
|
1160
|
-
"started_at": None,
|
|
1161
|
-
"completed_at": None,
|
|
1162
|
-
},
|
|
1163
|
-
)
|
|
1164
|
-
return workflow_state
|
|
1165
|
-
|
|
1166
|
-
async def mark_step_complete(self, step_name: str, result: Any = None) -> dict:
|
|
1167
|
-
"""
|
|
1168
|
-
Mark a workflow step as successfully completed.
|
|
1169
|
-
|
|
1170
|
-
Args:
|
|
1171
|
-
step_name: Name of the step
|
|
1172
|
-
result: Optional result data from the step
|
|
1173
|
-
|
|
1174
|
-
Returns:
|
|
1175
|
-
dict: Updated workflow status
|
|
1176
|
-
"""
|
|
1177
|
-
workflow_state = await self.get_status()
|
|
1178
|
-
|
|
1179
|
-
if workflow_state["started_at"] is None:
|
|
1180
|
-
workflow_state["started_at"] = time.time()
|
|
1181
|
-
|
|
1182
|
-
# Add to completed steps if not already there
|
|
1183
|
-
if step_name not in workflow_state["completed_steps"]:
|
|
1184
|
-
workflow_state["completed_steps"].append(step_name)
|
|
1185
|
-
|
|
1186
|
-
# Remove from failed steps if it was there
|
|
1187
|
-
if step_name in workflow_state["failed_steps"]:
|
|
1188
|
-
del workflow_state["failed_steps"][step_name]
|
|
1189
|
-
|
|
1190
|
-
# Store step result
|
|
1191
|
-
if result is not None:
|
|
1192
|
-
step_results_key = f"{self.workflow_key}_results"
|
|
1193
|
-
step_results = await self.ctx.get(step_results_key, {})
|
|
1194
|
-
step_results[step_name] = {
|
|
1195
|
-
"result": result,
|
|
1196
|
-
"completed_at": time.time(),
|
|
1197
|
-
}
|
|
1198
|
-
await self.ctx.set(step_results_key, step_results)
|
|
1199
|
-
|
|
1200
|
-
workflow_state["current_step"] = step_name
|
|
1201
|
-
|
|
1202
|
-
await self.ctx.set(self.workflow_key, workflow_state)
|
|
1203
|
-
return workflow_state
|
|
1204
|
-
|
|
1205
|
-
async def mark_step_failed(
|
|
1206
|
-
self, step_name: str, error: str, retry_count: int = 0
|
|
1207
|
-
) -> dict:
|
|
1208
|
-
"""
|
|
1209
|
-
Mark a workflow step as failed.
|
|
1210
|
-
|
|
1211
|
-
Args:
|
|
1212
|
-
step_name: Name of the step
|
|
1213
|
-
error: Error message or description
|
|
1214
|
-
retry_count: Number of retries attempted
|
|
1215
|
-
|
|
1216
|
-
Returns:
|
|
1217
|
-
dict: Updated workflow status
|
|
1218
|
-
"""
|
|
1219
|
-
workflow_state = await self.get_status()
|
|
1220
|
-
|
|
1221
|
-
if workflow_state["started_at"] is None:
|
|
1222
|
-
workflow_state["started_at"] = time.time()
|
|
1223
|
-
|
|
1224
|
-
# Record failure
|
|
1225
|
-
workflow_state["failed_steps"][step_name] = {
|
|
1226
|
-
"error": error,
|
|
1227
|
-
"failed_at": time.time(),
|
|
1228
|
-
"retry_count": retry_count,
|
|
1229
|
-
}
|
|
1230
|
-
|
|
1231
|
-
workflow_state["current_step"] = step_name
|
|
1232
|
-
|
|
1233
|
-
await self.ctx.set(self.workflow_key, workflow_state)
|
|
1234
|
-
return workflow_state
|
|
1235
|
-
|
|
1236
|
-
async def rollback(self, to_step: str) -> dict:
|
|
1237
|
-
"""
|
|
1238
|
-
Rollback workflow to a specific step (for compensation logic).
|
|
1239
|
-
|
|
1240
|
-
Args:
|
|
1241
|
-
to_step: Step name to rollback to
|
|
1242
|
-
|
|
1243
|
-
Returns:
|
|
1244
|
-
dict: Updated workflow status
|
|
1245
|
-
"""
|
|
1246
|
-
workflow_state = await self.get_status()
|
|
1247
|
-
|
|
1248
|
-
# Find the index of the target step
|
|
1249
|
-
if to_step in workflow_state["completed_steps"]:
|
|
1250
|
-
target_index = workflow_state["completed_steps"].index(to_step)
|
|
1251
|
-
|
|
1252
|
-
# Remove all steps after the target
|
|
1253
|
-
workflow_state["completed_steps"] = workflow_state["completed_steps"][
|
|
1254
|
-
: target_index + 1
|
|
1255
|
-
]
|
|
1256
|
-
|
|
1257
|
-
# Clear failed steps that are after the target
|
|
1258
|
-
workflow_state["failed_steps"] = {}
|
|
1259
|
-
|
|
1260
|
-
workflow_state["current_step"] = to_step
|
|
1261
|
-
|
|
1262
|
-
await self.ctx.set(self.workflow_key, workflow_state)
|
|
1263
|
-
|
|
1264
|
-
return workflow_state
|
|
1265
|
-
|
|
1266
|
-
async def can_retry(self, step_name: str) -> bool:
|
|
1267
|
-
"""
|
|
1268
|
-
Check if a failed step can be retried based on max_retries.
|
|
1269
|
-
|
|
1270
|
-
Args:
|
|
1271
|
-
step_name: Name of the step
|
|
1272
|
-
|
|
1273
|
-
Returns:
|
|
1274
|
-
bool: True if step can be retried
|
|
1275
|
-
"""
|
|
1276
|
-
workflow_state = await self.get_status()
|
|
1277
|
-
|
|
1278
|
-
if step_name in workflow_state["failed_steps"]:
|
|
1279
|
-
retry_count = workflow_state["failed_steps"][step_name].get(
|
|
1280
|
-
"retry_count", 0
|
|
1281
|
-
)
|
|
1282
|
-
return retry_count < self.max_retries
|
|
1283
|
-
|
|
1284
|
-
return True
|
|
1285
|
-
|
|
1286
|
-
async def get_step_result(self, step_name: str) -> Any:
|
|
1287
|
-
"""
|
|
1288
|
-
Get the result of a completed step.
|
|
1289
|
-
|
|
1290
|
-
Args:
|
|
1291
|
-
step_name: Name of the step
|
|
1292
|
-
|
|
1293
|
-
Returns:
|
|
1294
|
-
Any: Step result or None if not found
|
|
1295
|
-
"""
|
|
1296
|
-
step_results_key = f"{self.workflow_key}_results"
|
|
1297
|
-
step_results = await self.ctx.get(step_results_key, {})
|
|
1298
|
-
|
|
1299
|
-
if step_name in step_results:
|
|
1300
|
-
return step_results[step_name].get("result")
|
|
1301
|
-
|
|
1302
|
-
return None
|
|
1303
|
-
|
|
1304
|
-
async def complete_workflow(self) -> dict:
|
|
1305
|
-
"""
|
|
1306
|
-
Mark the entire workflow as completed.
|
|
1307
|
-
|
|
1308
|
-
Returns:
|
|
1309
|
-
dict: Final workflow status
|
|
1310
|
-
"""
|
|
1311
|
-
workflow_state = await self.get_status()
|
|
1312
|
-
workflow_state["completed_at"] = time.time()
|
|
1313
|
-
workflow_state["current_step"] = None
|
|
1314
|
-
await self.ctx.set(self.workflow_key, workflow_state)
|
|
1315
|
-
return workflow_state
|
|
1316
|
-
|
|
1317
|
-
async def reset_workflow(self) -> dict:
|
|
1318
|
-
"""
|
|
1319
|
-
Reset workflow state (use with caution).
|
|
1320
|
-
|
|
1321
|
-
Returns:
|
|
1322
|
-
dict: New empty workflow state
|
|
1323
|
-
"""
|
|
1324
|
-
new_state = {
|
|
1325
|
-
"current_step": None,
|
|
1326
|
-
"completed_steps": [],
|
|
1327
|
-
"failed_steps": {},
|
|
1328
|
-
"started_at": None,
|
|
1329
|
-
"completed_at": None,
|
|
1330
|
-
}
|
|
1331
|
-
await self.ctx.set(self.workflow_key, new_state)
|
|
1332
|
-
|
|
1333
|
-
# Clear step results
|
|
1334
|
-
step_results_key = f"{self.workflow_key}_results"
|
|
1335
|
-
await self.ctx.set(step_results_key, {})
|
|
1336
|
-
|
|
1337
|
-
return new_state
|
|
545
|
+
logger.debug(f"Auto-registered Entity subclass: {cls.__name__}")
|