handoff-guard 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- handoff/__init__.py +19 -0
- handoff/core.py +95 -0
- handoff/guard.py +469 -0
- handoff/langgraph.py +76 -0
- handoff/retry.py +145 -0
- handoff/testing.py +39 -0
- handoff/utils.py +49 -0
- handoff_guard-0.1.0.dist-info/METADATA +233 -0
- handoff_guard-0.1.0.dist-info/RECORD +11 -0
- handoff_guard-0.1.0.dist-info/WHEEL +4 -0
- handoff_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
handoff/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from handoff.core import HandoffViolation, ViolationContext
|
|
2
|
+
from handoff.guard import guard, GuardConfig
|
|
3
|
+
from handoff.retry import retry, RetryState, Diagnostic, AttemptRecord
|
|
4
|
+
from handoff.utils import ParseError, parse_json
|
|
5
|
+
|
|
6
|
+
__all__ = [
    "guard",
    "GuardConfig",
    "HandoffViolation",
    "ViolationContext",
    "retry",
    "RetryState",
    "Diagnostic",
    "AttemptRecord",
    "ParseError",
    "parse_json",
]

# Must match the version recorded in the distribution metadata; this wheel is
# published as handoff-guard 0.1.0, so reporting "0.2.0" here was inconsistent.
__version__ = "0.1.0"
|
handoff/core.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Core types for handoff validation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, TYPE_CHECKING
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from handoff.retry import AttemptRecord
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class ViolationContext:
|
|
15
|
+
"""Rich context about where and why validation failed."""
|
|
16
|
+
|
|
17
|
+
node_name: str
|
|
18
|
+
contract_type: str # "input" | "output" | "invariant"
|
|
19
|
+
field_path: str # e.g., "response.refund_id"
|
|
20
|
+
expected: str # Human-readable expectation
|
|
21
|
+
received: Any # What we actually got
|
|
22
|
+
received_type: str # Type of what we got
|
|
23
|
+
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
24
|
+
upstream_node: str | None = None
|
|
25
|
+
suggestion: str | None = None
|
|
26
|
+
|
|
27
|
+
def __str__(self) -> str:
|
|
28
|
+
lines = [
|
|
29
|
+
f"HandoffViolation in '{self.node_name}':",
|
|
30
|
+
f" Contract: {self.contract_type}",
|
|
31
|
+
f" Field: {self.field_path}",
|
|
32
|
+
f" Expected: {self.expected}",
|
|
33
|
+
f" Received: {repr(self.received)[:100]} ({self.received_type})",
|
|
34
|
+
]
|
|
35
|
+
if self.upstream_node:
|
|
36
|
+
lines.append(f" Upstream: {self.upstream_node}")
|
|
37
|
+
if self.suggestion:
|
|
38
|
+
lines.append(f" Suggestion: {self.suggestion}")
|
|
39
|
+
return "\n".join(lines)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class HandoffViolation(Exception):
    """Exception signalling that data failed validation at an agent boundary.

    Carries the ViolationContext describing the failure and, when the guard
    retried, the list of per-attempt records. The exception message is
    ``str(context)``.
    """

    def __init__(
        self,
        context: ViolationContext,
        history: list[AttemptRecord] | None = None,
    ):
        self.context = context
        # A falsy history (None or empty) is normalised to a fresh empty list;
        # a non-empty list is stored as-is (not copied).
        self.history: list[AttemptRecord] = history if history else []
        super().__init__(str(context))

    @property
    def node_name(self) -> str:
        """Name of the node whose contract failed."""
        return self.context.node_name

    @property
    def field_path(self) -> str:
        """Dotted path of the offending field."""
        return self.context.field_path

    @property
    def total_attempts(self) -> int:
        """Number of recorded attempts; 1 when no history was recorded."""
        return len(self.history) or 1

    def to_dict(self) -> dict:
        """Serialize for logging/telemetry.

        ``received`` is reported as its repr, cut to 200 characters.
        ``total_attempts`` and ``history`` are only included when a history
        was recorded.
        """
        ctx = self.context
        payload = {
            "node_name": ctx.node_name,
            "contract_type": ctx.contract_type,
            "field_path": ctx.field_path,
            "expected": ctx.expected,
            "received": repr(ctx.received)[:200],
            "received_type": ctx.received_type,
            "timestamp": ctx.timestamp.isoformat(),
            "upstream_node": ctx.upstream_node,
            "suggestion": ctx.suggestion,
        }
        if not self.history:
            return payload
        payload["total_attempts"] = self.total_attempts
        payload["history"] = [self._attempt_to_dict(rec) for rec in self.history]
        return payload

    @staticmethod
    def _attempt_to_dict(rec: AttemptRecord) -> dict:
        """Serialize one attempt record; only cause/message of its diagnostic are kept."""
        diag = rec.diagnostic
        return {
            "attempt": rec.attempt,
            "timestamp": rec.timestamp.isoformat(),
            "duration_ms": rec.duration_ms,
            "diagnostic": {"cause": diag.cause, "message": diag.message} if diag else None,
        }
|
handoff/guard.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
"""The @guard decorator for validating agent boundaries."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from functools import wraps
|
|
7
|
+
from typing import Any, Callable, TypeVar, Type, Literal
|
|
8
|
+
import inspect
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, ValidationError
|
|
11
|
+
|
|
12
|
+
from handoff.core import HandoffViolation, ViolationContext
|
|
13
|
+
from handoff.retry import (
|
|
14
|
+
_retry_context,
|
|
15
|
+
RetryState,
|
|
16
|
+
Diagnostic,
|
|
17
|
+
AttemptRecord,
|
|
18
|
+
)
|
|
19
|
+
from handoff.utils import ParseError
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Return type of the function being guarded.
T = TypeVar("T")

# The built-in policies accepted by ``on_fail`` (a callable is also allowed).
OnFailAction = Literal["raise", "return_none", "return_input"]


@dataclass
class GuardConfig:
    """Declarative bundle of guard settings.

    NOTE(review): ``guard()`` in this module takes its options as keyword
    arguments and does not read a GuardConfig; ``validate_input`` and
    ``validate_output`` have no counterpart there.
    """

    # Pydantic model for data entering the node (None = no input check).
    input_schema: Type[BaseModel] | None = None
    # Pydantic model for data leaving the node (None = no output check).
    output_schema: Type[BaseModel] | None = None
    # Auto-detected from function name if not provided.
    node_name: str | None = None
    # Failure policy: one of OnFailAction, or a callable taking the violation.
    on_fail: OnFailAction | Callable[[HandoffViolation], Any] = "raise"
    validate_input: bool = True
    validate_output: bool = True
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _generate_suggestion(
|
|
39
|
+
err_type: str, field_path: str, contract_type: str
|
|
40
|
+
) -> str | None:
|
|
41
|
+
"""Generate a helpful suggestion based on Pydantic error type."""
|
|
42
|
+
if err_type == "missing":
|
|
43
|
+
return f"Add '{field_path}' to the {contract_type} data"
|
|
44
|
+
elif err_type == "string_type":
|
|
45
|
+
return f"Convert '{field_path}' to string"
|
|
46
|
+
elif err_type == "int_type":
|
|
47
|
+
return f"Convert '{field_path}' to integer"
|
|
48
|
+
elif err_type == "string_too_short":
|
|
49
|
+
return f"Increase the length of '{field_path}'"
|
|
50
|
+
elif err_type == "string_too_long":
|
|
51
|
+
return f"Reduce the length of '{field_path}'"
|
|
52
|
+
elif err_type == "too_short":
|
|
53
|
+
return f"Add more items to '{field_path}'"
|
|
54
|
+
elif err_type == "too_long":
|
|
55
|
+
return f"Reduce the number of items in '{field_path}'"
|
|
56
|
+
elif err_type == "greater_than_equal":
|
|
57
|
+
return f"Increase the value of '{field_path}'"
|
|
58
|
+
elif err_type == "less_than_equal":
|
|
59
|
+
return f"Decrease the value of '{field_path}'"
|
|
60
|
+
elif err_type == "string_pattern_mismatch":
|
|
61
|
+
return f"'{field_path}' does not match the required pattern"
|
|
62
|
+
return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _extract_violations(
    error: ValidationError,
    node_name: str,
    contract_type: str,
    raw_data: Any,
) -> list[ViolationContext]:
    """Convert Pydantic ValidationError to rich ViolationContext objects.

    One ViolationContext is produced per entry of ``error.errors()``. The
    offending value is recovered from ``raw_data`` by following the error's
    ``loc`` path: dict keys, integer indices into lists/tuples, and attribute
    names on other objects. The placeholder ``"<missing>"`` is reported as
    soon as any step of the path cannot be resolved.

    Args:
        error: The ValidationError raised by ``model_validate``.
        node_name: Name of the guarded node, copied into every context.
        contract_type: Contract that failed ("input", "output", "state", ...).
        raw_data: The data that was validated; used to look up received values.

    Returns:
        One ViolationContext per Pydantic error, in Pydantic's order.
    """

    def _resolve(loc: tuple) -> Any:
        current = raw_data
        for key in loc:
            if isinstance(current, dict):
                if key not in current:
                    return "<missing>"
                current = current[key]
            elif isinstance(current, (list, tuple)) and isinstance(key, int):
                # Pydantic reports sequence items by integer index.
                if not 0 <= key < len(current):
                    return "<missing>"
                current = current[key]
            elif hasattr(current, str(key)):
                current = getattr(current, str(key))
            else:
                return "<missing>"
            # Returning early (instead of carrying "<missing>" forward) stops
            # later path parts from matching attributes of the placeholder
            # string itself, e.g. "title" or "count".
        return current

    violations: list[ViolationContext] = []
    for err in error.errors():
        loc = tuple(err["loc"])
        # Root-level errors have an empty loc; compute the display path once so
        # the suggestion and the context both say "<root>".
        field_path = ".".join(str(part) for part in loc) or "<root>"
        received = _resolve(loc)
        violations.append(
            ViolationContext(
                node_name=node_name,
                contract_type=contract_type,
                field_path=field_path,
                expected=err["msg"],
                received=received,
                received_type=type(received).__name__,
                suggestion=_generate_suggestion(err["type"], field_path, contract_type),
            )
        )

    return violations
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _validate_data(
    data: Any,
    schema: Type[BaseModel],
    node_name: str,
    contract_type: str,
) -> tuple[bool, BaseModel | None, list[ViolationContext]]:
    """Validate data against schema, return (success, validated_model, violations).

    Accepted inputs:
      - a Pydantic model instance: re-validated from ``model_dump()``;
      - a dict: validated directly;
      - any object with ``__dict__``: validated from that dict;
      - anything else: validated as ``{"value": data}``.

    Returns:
        ``(True, model, [])`` on success, or ``(False, None, violations)``
        with one ViolationContext per Pydantic error.
    """
    if isinstance(data, BaseModel):
        payload: Any = data.model_dump()
    elif isinstance(data, dict):
        payload = data
    elif hasattr(data, "__dict__"):
        payload = data.__dict__
    else:
        payload = {"value": data}

    try:
        validated = schema.model_validate(payload)
    except ValidationError as e:
        # Resolve error locations against what Pydantic actually saw, so that
        # e.g. ("value",) resolves for wrapped scalars. Model instances are kept
        # as-is so reported received values remain the original model objects.
        lookup_source = data if isinstance(data, BaseModel) else payload
        violations = _extract_violations(e, node_name, contract_type, lookup_source)
        return False, None, violations

    return True, validated, []
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _build_validation_diagnostic(
    violations: list[ViolationContext],
    result: Any,
) -> Diagnostic:
    """Summarise output-validation failures as a ``cause="validation"`` Diagnostic.

    Every violation becomes one line in ``errors`` (received value repr cut
    to 100 chars). The first violation supplies ``field_path`` and
    ``suggestion``. ``raw_output`` is the result's repr cut to 500 chars, or
    None when the result itself is None.
    """
    first = violations[0] if violations else None
    error_lines = [
        f"{v.field_path}: expected {v.expected}, got {repr(v.received)[:100]}"
        for v in violations
    ]
    return Diagnostic(
        cause="validation",
        message="Output validation failed",
        errors=error_lines,
        raw_output=None if result is None else repr(result)[:500],
        field_path=None if first is None else first.field_path,
        suggestion=None if first is None else first.suggestion,
    )
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _build_parse_diagnostic(exc: Exception, raw: Any = None) -> Diagnostic:
    """Build a ``cause="parse"`` Diagnostic from a parse error.

    The raw output attached to the diagnostic is taken, in order of
    preference, from:
      1. ``exc.raw_output`` when ``exc`` is a ParseError;
      2. ``repr(raw)`` when the caller supplied ``raw``;
      3. ``exc.doc`` (the text that failed to decode) when ``exc`` is a
         ``json.JSONDecodeError``.
    Options 2 and 3 are truncated to 500 characters here.

    Args:
        exc: The exception raised while producing/parsing output.
        raw: Optional raw value to show when the exception carries none.
    """
    raw_output: str | None = None
    if isinstance(exc, ParseError):
        raw_output = exc.raw_output
    if raw_output is None and raw is not None:
        raw_output = repr(raw)[:500]
    if raw_output is None and isinstance(exc, json.JSONDecodeError):
        # json.loads() records the input it failed on as ``doc``; without this
        # a bare JSONDecodeError gave the retry feedback no raw output at all.
        raw_output = exc.doc[:500]
    return Diagnostic(
        cause="parse",
        message=str(exc),
        raw_output=raw_output,
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def guard(
    input: Type[BaseModel] | None = None,
    output: Type[BaseModel] | None = None,
    *,
    node_name: str | None = None,
    max_attempts: int = 1,
    retry_on: tuple[str, ...] = ("validation", "parse"),
    on_fail: OnFailAction | Callable[[HandoffViolation], Any] = "raise",
    input_param: str | None = "state",
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to validate input/output at agent boundaries.

    Args:
        input: Pydantic model to validate input against
        output: Pydantic model to validate output against
        node_name: Override the node name (defaults to function name)
        max_attempts: Maximum number of attempts (1 = no retry, default).
            Must be at least 1.
        retry_on: Tuple of error types to retry on ("validation", "parse")
        on_fail: What to do on validation failure:
            - "raise": Raise HandoffViolation (default)
            - "return_none": Return None
            - "return_input": Return input unchanged
            - callable: Call with HandoffViolation, return its result
        input_param: Name of the input argument to validate (default: "state")

    Raises:
        ValueError: If ``max_attempts`` is less than 1.

    Example:
        @guard(input=RequestSchema, output=ResponseSchema)
        def my_agent_node(state: dict) -> dict:
            ...
    """
    # Without this check range(1, max_attempts + 1) is empty and the wrapper
    # would silently return None without ever calling the decorated function.
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts!r}")

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        _node_name = node_name or func.__name__
        is_async = inspect.iscoroutinefunction(func)

        # Check if function accepts a 'retry' parameter
        sig = inspect.signature(func)
        _accepts_retry = "retry" in sig.parameters
        # Position at which 'retry' can be filled positionally (None when it is
        # keyword-only/variadic/absent). Used so a caller-supplied positional
        # 'retry' is not duplicated by injecting a 'retry' kwarg.
        _retry_index: int | None = None
        if _accepts_retry and sig.parameters["retry"].kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            _retry_index = list(sig.parameters).index("retry")

        def _bind_input(args: tuple, kwargs: dict) -> Any:
            """Extract the input argument using signature binding."""
            if input_param is None:
                if args:
                    return args[0]
                return kwargs.get("state")
            try:
                bound = sig.bind_partial(*args, **kwargs)
            except TypeError:
                return kwargs.get(input_param)
            if input_param in bound.arguments:
                return bound.arguments[input_param]
            return kwargs.get(input_param)

        def _handle_violation(violation: HandoffViolation, input_data: Any) -> Any:
            """Apply the ``on_fail`` policy to ``violation``."""
            if on_fail == "raise":
                raise violation
            elif on_fail == "return_none":
                return None
            elif on_fail == "return_input":
                return input_data
            elif callable(on_fail):
                return on_fail(violation)
            else:
                raise violation

        def _validate_input(
            args: tuple, kwargs: dict
        ) -> tuple[Any, list[ViolationContext]]:
            """Return the bound input value and its violations (empty if valid)."""
            input_data = _bind_input(args, kwargs)
            if not input:
                return input_data, []
            _, _, violations = _validate_data(input_data, input, _node_name, "input")
            return input_data, violations

        def _validate_output_data(result: Any) -> list[ViolationContext]:
            """Return output violations (empty if valid or no output schema)."""
            if not output:
                return []
            _, _, violations = _validate_data(result, output, _node_name, "output")
            return violations

        def _is_retryable_parse_error(exc: Exception) -> bool:
            """Check if exception is a parse-related error eligible for retry."""
            return isinstance(exc, (ParseError, json.JSONDecodeError))

        def _new_state(
            attempt_num: int,
            last_diagnostic: Diagnostic | None,
            history: list[AttemptRecord],
        ) -> RetryState:
            """Snapshot of the loop for one attempt (history is copied)."""
            return RetryState(
                attempt=attempt_num,
                max_attempts=max_attempts,
                last_error=last_diagnostic,
                history=list(history),
            )

        def _call_kwargs(args: tuple, kwargs: dict, state: RetryState) -> dict:
            """Copy kwargs, injecting ``retry=state`` unless the caller supplied it."""
            call_kwargs = dict(kwargs)
            passed_positionally = _retry_index is not None and len(args) > _retry_index
            if _accepts_retry and "retry" not in call_kwargs and not passed_positionally:
                call_kwargs["retry"] = state
            return call_kwargs

        def _on_call_error(
            exc: Exception,
            attempt_num: int,
            start_time: float,
            history: list[AttemptRecord],
        ) -> tuple[Diagnostic, HandoffViolation | None] | None:
            """Handle an exception raised by ``func``.

            Returns None when ``exc`` must be re-raised unchanged. Otherwise the
            attempt is recorded and ``(diagnostic, violation)`` is returned,
            where ``violation`` is None if another attempt should be made.
            """
            if not (
                _is_retryable_parse_error(exc)
                and "parse" in retry_on
                and max_attempts > 1
            ):
                return None
            elapsed = (time.monotonic() - start_time) * 1000
            diag = _build_parse_diagnostic(exc, getattr(exc, "raw_output", None))
            history.append(
                AttemptRecord(attempt=attempt_num, diagnostic=diag, duration_ms=elapsed)
            )
            if attempt_num < max_attempts:
                return diag, None
            # Final attempt — build violation
            violation_ctx = ViolationContext(
                node_name=_node_name,
                contract_type="output",
                field_path="<parse>",
                expected="Valid parseable output",
                received=str(exc),
                received_type=type(exc).__name__,
                suggestion="Return valid JSON or structured data",
            )
            return diag, HandoffViolation(violation_ctx, history=history)

        def _on_result(
            result: Any,
            attempt_num: int,
            start_time: float,
            history: list[AttemptRecord],
        ) -> tuple[Diagnostic | None, HandoffViolation | None]:
            """Validate ``result`` and record the attempt.

            Returns ``(None, None)`` on success, ``(diagnostic, None)`` when the
            output is invalid and should be retried, and
            ``(diagnostic, violation)`` when it is invalid and final.
            """
            output_violations = _validate_output_data(result)
            elapsed = (time.monotonic() - start_time) * 1000
            if not output_violations:
                history.append(AttemptRecord(attempt=attempt_num, duration_ms=elapsed))
                return None, None
            diag = _build_validation_diagnostic(output_violations, result)
            history.append(
                AttemptRecord(attempt=attempt_num, diagnostic=diag, duration_ms=elapsed)
            )
            if "validation" in retry_on and attempt_num < max_attempts:
                return diag, None
            # Final attempt or validation not retried
            return diag, HandoffViolation(output_violations[0], history=history)

        if is_async:

            @wraps(func)
            async def async_wrapper(*args, **kwargs) -> T:
                # Validate input (outside retry loop — input doesn't change)
                input_data, input_violations = _validate_input(args, kwargs)
                if input_violations:
                    return _handle_violation(
                        HandoffViolation(input_violations[0]), input_data
                    )

                history: list[AttemptRecord] = []
                last_diagnostic: Diagnostic | None = None

                for attempt_num in range(1, max_attempts + 1):
                    state = _new_state(attempt_num, last_diagnostic, history)
                    token = _retry_context.set(state)
                    start_time = time.monotonic()
                    try:
                        call_kwargs = _call_kwargs(args, kwargs, state)
                        try:
                            result = await func(*args, **call_kwargs)
                        except Exception as exc:
                            outcome = _on_call_error(exc, attempt_num, start_time, history)
                            if outcome is None:
                                raise
                            last_diagnostic, violation = outcome
                            if violation is None:
                                continue
                            return _handle_violation(violation, input_data)

                        diag, violation = _on_result(result, attempt_num, start_time, history)
                        if violation is not None:
                            return _handle_violation(violation, input_data)
                        if diag is None:
                            return result
                        last_diagnostic = diag
                    finally:
                        # Always restore the previous retry context, including
                        # on `continue`, return, and raise.
                        _retry_context.reset(token)
                # Not reached: the final attempt always returns or raises above.

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs) -> T:
            # Validate input (outside retry loop — input doesn't change)
            input_data, input_violations = _validate_input(args, kwargs)
            if input_violations:
                return _handle_violation(
                    HandoffViolation(input_violations[0]), input_data
                )

            history: list[AttemptRecord] = []
            last_diagnostic: Diagnostic | None = None

            for attempt_num in range(1, max_attempts + 1):
                state = _new_state(attempt_num, last_diagnostic, history)
                token = _retry_context.set(state)
                start_time = time.monotonic()
                try:
                    call_kwargs = _call_kwargs(args, kwargs, state)
                    try:
                        result = func(*args, **call_kwargs)
                    except Exception as exc:
                        outcome = _on_call_error(exc, attempt_num, start_time, history)
                        if outcome is None:
                            raise
                        last_diagnostic, violation = outcome
                        if violation is None:
                            continue
                        return _handle_violation(violation, input_data)

                    diag, violation = _on_result(result, attempt_num, start_time, history)
                    if violation is not None:
                        return _handle_violation(violation, input_data)
                    if diag is None:
                        return result
                    last_diagnostic = diag
                finally:
                    # Always restore the previous retry context, including on
                    # `continue`, return, and raise.
                    _retry_context.reset(token)
            # Not reached: the final attempt always returns or raises above.

        return sync_wrapper

    return decorator
|
handoff/langgraph.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""LangGraph-specific utilities for handoff validation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Type, Callable, TypeVar
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from handoff.guard import guard
|
|
7
|
+
from handoff.core import HandoffViolation
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T")


def guarded_node(
    input: Type[BaseModel] | None = None,
    output: Type[BaseModel] | None = None,
    *,
    max_attempts: int = 1,
    retry_on: tuple[str, ...] = ("validation", "parse"),
    on_fail: str | Callable[[HandoffViolation], Any] = "raise",
    node_name: str | None = None,
    input_param: str | None = "state",
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    LangGraph-specific decorator for node validation.

    Thin wrapper around @guard. Every option is forwarded unchanged, and the
    defaults equal guard's. ``node_name`` is useful when the name the node is
    registered under in the graph differs from the Python function name.
    ``input_param`` selects which argument holds the graph state.

    Example:
        from handoff.langgraph import guarded_node

        class AgentInput(BaseModel):
            messages: list
            context: dict

        class AgentOutput(BaseModel):
            messages: list
            next_agent: str

        @guarded_node(input=AgentInput, output=AgentOutput)
        def my_agent(state: dict) -> dict:
            # Your agent logic
            return {"messages": [...], "next_agent": "reviewer"}
    """
    return guard(
        input=input,
        output=output,
        node_name=node_name,
        max_attempts=max_attempts,
        retry_on=retry_on,
        on_fail=on_fail,
        input_param=input_param,
    )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def validate_state(
    state: dict | BaseModel,
    schema: Type[BaseModel],
    node_name: str = "unknown",
) -> BaseModel:
    """
    Validate ``state`` against ``schema`` at an explicit checkpoint.

    Intended for manual validation points outside a guarded node, e.g. right
    before checkpointing. Failures are reported with contract type "state".

    Returns:
        The validated model instance.

    Raises:
        HandoffViolation: For the first validation error, if validation fails.

    Example:
        validated = validate_state(state, MyStateSchema, node_name="pre_checkpoint")
    """
    from handoff.guard import _validate_data

    ok, model, violations = _validate_data(state, schema, node_name, "state")
    if ok:
        return model
    raise HandoffViolation(violations[0])
|
handoff/retry.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Retry-with-feedback data structures and async-safe proxy."""
|
|
2
|
+
|
|
3
|
+
from contextvars import ContextVar
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
RetryCause = Literal["validation", "parse"]
|
|
9
|
+
|
|
10
|
+
_RAW_OUTPUT_MAX = 500
|
|
11
|
+
_FEEDBACK_MAX_DEFAULT = 2000
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class Diagnostic:
|
|
16
|
+
"""Describes why an attempt failed."""
|
|
17
|
+
|
|
18
|
+
cause: RetryCause
|
|
19
|
+
message: str
|
|
20
|
+
errors: list[str] = field(default_factory=list)
|
|
21
|
+
raw_output: str | None = None
|
|
22
|
+
field_path: str | None = None
|
|
23
|
+
suggestion: str | None = None
|
|
24
|
+
|
|
25
|
+
def __post_init__(self):
|
|
26
|
+
if self.raw_output and len(self.raw_output) > _RAW_OUTPUT_MAX:
|
|
27
|
+
self.raw_output = self.raw_output[:_RAW_OUTPUT_MAX]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class AttemptRecord:
    """Outcome of one attempt made by the retry loop."""

    attempt: int  # attempt number, starting at 1
    # Creation time of the record, timezone-aware (UTC).
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    diagnostic: Diagnostic | None = None  # None when the attempt succeeded
    duration_ms: float | None = None  # wall-clock duration of the attempt
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class RetryState:
    """Snapshot of the retry loop made available to the guarded function."""

    attempt: int = 1  # current attempt number, starting at 1
    max_attempts: int = 1
    last_error: Diagnostic | None = None  # why the previous attempt failed
    history: list[AttemptRecord] = field(default_factory=list)

    @property
    def remaining(self) -> int:
        """Attempts still available after this one (never negative)."""
        left = self.max_attempts - self.attempt
        return left if left > 0 else 0

    @property
    def is_retry(self) -> bool:
        """True for every attempt after the first."""
        return 1 < self.attempt

    @property
    def is_final_attempt(self) -> bool:
        """True when no further attempt will be made after this one."""
        return self.attempt >= self.max_attempts

    def feedback(self, max_chars: int = _FEEDBACK_MAX_DEFAULT) -> str | None:
        """Text describing the previous failure, cut to ``max_chars``.

        Returns None when there was no previous failure.
        """
        if self.last_error is None:
            return None
        rendered = _format_diagnostic(self.last_error)
        return rendered if len(rendered) <= max_chars else rendered[:max_chars]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _format_diagnostic(diag: Diagnostic) -> str:
|
|
71
|
+
"""Render a Diagnostic as LLM-friendly text."""
|
|
72
|
+
lines = [
|
|
73
|
+
f"[Retry] Previous attempt failed ({diag.cause}):",
|
|
74
|
+
f" Message: {diag.message}",
|
|
75
|
+
]
|
|
76
|
+
if diag.errors:
|
|
77
|
+
lines.append(" Errors:")
|
|
78
|
+
for err in diag.errors:
|
|
79
|
+
lines.append(f" - {err}")
|
|
80
|
+
if diag.field_path:
|
|
81
|
+
lines.append(f" Field: {diag.field_path}")
|
|
82
|
+
if diag.suggestion:
|
|
83
|
+
lines.append(f" Suggestion: {diag.suggestion}")
|
|
84
|
+
if diag.raw_output:
|
|
85
|
+
lines.append(f" Raw output: {diag.raw_output}")
|
|
86
|
+
return "\n".join(lines)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
_retry_context: ContextVar[RetryState | None] = ContextVar(
|
|
90
|
+
"_retry_context", default=None
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class _RetryProxy:
    """Module-level facade over ``_retry_context``.

    Each attribute forwards to the RetryState currently stored in the context
    variable. When none is set, each attribute falls back to the value of a
    single, non-retried attempt: attempt 1 of 1, nothing remaining, no error,
    and empty history.
    """

    def _get(self) -> RetryState | None:
        return _retry_context.get(None)

    def get(self) -> RetryState | None:
        """Return the active RetryState, or None when none is set."""
        return self._get()

    def _read(self, name: str, fallback):
        """Read ``name`` from the active state, or return ``fallback``."""
        state = self._get()
        return getattr(state, name) if state else fallback

    @property
    def attempt(self) -> int:
        return self._read("attempt", 1)

    @property
    def max_attempts(self) -> int:
        return self._read("max_attempts", 1)

    @property
    def remaining(self) -> int:
        return self._read("remaining", 0)

    @property
    def is_retry(self) -> bool:
        return self._read("is_retry", False)

    @property
    def is_final_attempt(self) -> bool:
        return self._read("is_final_attempt", True)

    @property
    def last_error(self) -> Diagnostic | None:
        return self._read("last_error", None)

    @property
    def history(self) -> list[AttemptRecord]:
        # A fresh empty list per call when no state is active.
        return self._read("history", [])

    def feedback(self, max_chars: int = _FEEDBACK_MAX_DEFAULT) -> str | None:
        """Feedback text of the active state, or None when none is set."""
        state = self._get()
        return None if state is None else state.feedback(max_chars)


retry = _RetryProxy()
|
handoff/testing.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Test utilities for handoff retry."""
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
|
|
5
|
+
from handoff.retry import (
|
|
6
|
+
_retry_context,
|
|
7
|
+
RetryState,
|
|
8
|
+
Diagnostic,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@contextmanager
def mock_retry(
    attempt: int = 2,
    max_attempts: int = 3,
    last_error: Diagnostic | None = None,
    feedback_text: str | None = None,
):
    """Temporarily install a RetryState so code reading ``retry`` sees it.

    The previous retry context is restored on exit, even if the body raises.
    When ``feedback_text`` is given and ``last_error`` is not, a minimal
    ``cause="validation"`` Diagnostic is created with ``feedback_text`` as
    its message. Yields the installed RetryState.
    """
    error = last_error
    if error is None and feedback_text:
        error = Diagnostic(cause="validation", message=feedback_text)

    state = RetryState(attempt=attempt, max_attempts=max_attempts, last_error=error)
    token = _retry_context.set(state)
    try:
        yield state
    finally:
        _retry_context.reset(token)
|
handoff/utils.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Utility functions for handoff."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ParseError(Exception):
|
|
7
|
+
"""Raised when output cannot be parsed as JSON."""
|
|
8
|
+
|
|
9
|
+
def __init__(self, message: str, raw_output: str | None = None):
|
|
10
|
+
self.raw_output = raw_output
|
|
11
|
+
super().__init__(message)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _strip_code_fence(text: str) -> str:
    """Return the content of a markdown code fence wrapping *text*.

    *text* is expected to be stripped of surrounding whitespace already.
    Text that does not start with a fence is returned unchanged.
    """
    if not text.startswith("```"):
        return text

    opening, newline, body = text.partition("\n")
    if not newline:
        # Single-line fence such as ```{"a": 1}``` or ```json {"a": 1}```:
        # the payload sits on the opening line itself, so keep what follows
        # the backticks instead of discarding the whole line.
        body = opening[len("```"):]

    body = body.rstrip()
    if body.endswith("```"):
        body = body[: -len("```")]
    body = body.strip()

    if not newline:
        # Drop a leading language tag ("json", "JSON", ...) on a single-line
        # fence, but only when something follows it and it is not itself a
        # JSON literal, so payloads like ```true``` survive intact.
        parts = body.split(None, 1)
        if (
            len(parts) == 2
            and parts[0].isalpha()
            and parts[0] not in ("true", "false", "null")
        ):
            body = parts[1]
    return body


def parse_json(text: str) -> dict:
    """Parse JSON from text, handling common LLM output quirks.

    Strips surrounding whitespace and UTF-8 BOMs, then removes a wrapping
    markdown code fence before decoding. The fence may be multi-line
    (```` ```json\\n{...}\\n``` ````) or single-line (```` ```{...}``` ````,
    optionally with a language tag).

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value. This is typically a dict, but ``json.loads``
        also returns lists or scalars unchanged when the text contains them.

    Raises:
        ParseError: If *text* is not a string or cannot be parsed as JSON.
            ``raw_output`` holds at most the first 500 characters of the
            input (or of its ``repr`` for non-strings).
    """
    if not isinstance(text, str):
        raise ParseError(
            f"Expected string, got {type(text).__name__}",
            raw_output=repr(text)[:500],
        )

    # Strip UTF-8 BOM. Also tolerate whitespace before it: str.strip() does
    # not treat U+FEFF as whitespace, so strip both in sequence.
    cleaned = text.lstrip().lstrip("\ufeff").strip()
    payload = _strip_code_fence(cleaned)

    try:
        return json.loads(payload)
    except json.JSONDecodeError as e:
        raise ParseError(
            f"Failed to parse JSON: {e}",
            raw_output=text[:500],
        ) from e
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: handoff-guard
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Lightweight validation at agent boundaries. Know what broke and where.
|
|
5
|
+
Project-URL: Homepage, https://github.com/acartag7/handoff-guard
|
|
6
|
+
Project-URL: Repository, https://github.com/acartag7/handoff-guard
|
|
7
|
+
Author-email: Arnold <cartagena.arnold@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: agents,handoff,langgraph,multi-agent,validation
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Requires-Python: >=3.10
|
|
18
|
+
Requires-Dist: pydantic>=2.0.0
|
|
19
|
+
Provides-Extra: dev
|
|
20
|
+
Requires-Dist: pre-commit>=3.0.0; extra == 'dev'
|
|
21
|
+
Requires-Dist: pytest-asyncio>=0.23.0; extra == 'dev'
|
|
22
|
+
Requires-Dist: pytest>=8.0.0; extra == 'dev'
|
|
23
|
+
Requires-Dist: ruff>=0.1.0; extra == 'dev'
|
|
24
|
+
Provides-Extra: langgraph
|
|
25
|
+
Requires-Dist: langgraph>=0.2.0; extra == 'langgraph'
|
|
26
|
+
Provides-Extra: llm
|
|
27
|
+
Requires-Dist: httpx>=0.25.0; extra == 'llm'
|
|
28
|
+
Requires-Dist: langgraph>=0.2.0; extra == 'llm'
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
|
|
31
|
+
# handoff-guard
|
|
32
|
+
|
|
33
|
+
> Validation for LLM agents that retries with feedback.
|
|
34
|
+
|
|
35
|
+
[](https://badge.fury.io/py/handoff-guard)
|
|
36
|
+
[](https://opensource.org/licenses/MIT)
|
|
37
|
+
|
|
38
|
+
## The Problem
|
|
39
|
+
|
|
40
|
+
When an LLM agent returns bad output, you get a generic error and no recovery path:
|
|
41
|
+
|
|
42
|
+
```
|
|
43
|
+
ValidationError: 1 validation error for State
|
|
44
|
+
field required (type=value_error.missing)
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Which node? Which field? What was passed? Can the agent fix it?
|
|
48
|
+
|
|
49
|
+
## The Solution
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
from handoff import guard, retry, parse_json
|
|
53
|
+
from pydantic import BaseModel, Field
|
|
54
|
+
|
|
55
|
+
class WriterOutput(BaseModel):
|
|
56
|
+
draft: str = Field(min_length=100)
|
|
57
|
+
word_count: int = Field(ge=50)
|
|
58
|
+
tone: str
|
|
59
|
+
title: str
|
|
60
|
+
|
|
61
|
+
@guard(output=WriterOutput, node_name="writer", max_attempts=3)
|
|
62
|
+
def writer_agent(state: dict) -> dict:
|
|
63
|
+
prompt = "Write a JSON response with: draft, word_count, tone, title."
|
|
64
|
+
|
|
65
|
+
if retry.is_retry:
|
|
66
|
+
prompt += f"\n\nYour previous attempt failed:\n{retry.feedback()}"
|
|
67
|
+
|
|
68
|
+
response = call_llm(prompt)
|
|
69
|
+
return parse_json(response)
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
When validation fails, the agent retries with feedback about what went wrong. After all attempts are exhausted:
|
|
73
|
+
|
|
74
|
+
```
|
|
75
|
+
HandoffViolation in 'writer' (attempt 3/3):
|
|
76
|
+
Contract: output
|
|
77
|
+
Field: draft
|
|
78
|
+
Expected: String should have at least 100 characters
|
|
79
|
+
Suggestion: Increase the length of 'draft'
|
|
80
|
+
History: 3 failed attempts
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## Quick Start
|
|
84
|
+
|
|
85
|
+
```bash
|
|
86
|
+
pip install handoff-guard
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
```bash
|
|
90
|
+
# See retry-with-feedback in action (no API key needed)
|
|
91
|
+
python -m examples.llm_demo.run_demo
|
|
92
|
+
|
|
93
|
+
# Run with real LLM calls
|
|
94
|
+
export OPENROUTER_API_KEY=your_key
|
|
95
|
+
python -m examples.llm_demo.run_demo --pipeline --api
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
## Features
|
|
99
|
+
|
|
100
|
+
- **Retry with feedback** — On a retry, the agent can read what went wrong via `retry.feedback()` and correct its output
|
|
101
|
+
- **Know which node failed** — No more guessing from stack traces
|
|
102
|
+
- **Know which field failed** — Exact path to the problem
|
|
103
|
+
- **Get fix suggestions** — Actionable error messages
|
|
104
|
+
- **`parse_json`** — Strips code fences, handles BOM, raises `ParseError` on failure
|
|
105
|
+
- **Framework agnostic** — Works with LangGraph, CrewAI, or plain Python
|
|
106
|
+
- **Lightweight** — Just Pydantic, no Docker, no telemetry servers
|
|
107
|
+
|
|
108
|
+
## API
|
|
109
|
+
|
|
110
|
+
### `@guard` decorator
|
|
111
|
+
|
|
112
|
+
```python
|
|
113
|
+
@guard(
|
|
114
|
+
input=InputSchema, # Pydantic model for input validation
|
|
115
|
+
output=OutputSchema, # Pydantic model for output validation
|
|
116
|
+
node_name="my_node", # Identifies the node in errors (default: function name)
|
|
117
|
+
max_attempts=3, # Up to 3 attempts in total, i.e. 2 retries (default: 1, no retry)
|
|
118
|
+
retry_on=("validation", "parse"), # What errors trigger retry (default)
|
|
119
|
+
on_fail="raise", # "raise" | "return_none" | "return_input" | callable
|
|
120
|
+
)
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
### `retry` proxy
|
|
124
|
+
|
|
125
|
+
Access retry state inside any guarded function:
|
|
126
|
+
|
|
127
|
+
```python
|
|
128
|
+
from handoff import retry
|
|
129
|
+
|
|
130
|
+
retry.is_retry # True if attempt > 1
|
|
131
|
+
retry.attempt # Current attempt number
|
|
132
|
+
retry.max_attempts # Total allowed attempts
|
|
133
|
+
retry.remaining # Attempts left
|
|
134
|
+
retry.is_final_attempt
|
|
135
|
+
retry.feedback() # Formatted string describing last error, or None
|
|
136
|
+
retry.last_error # Diagnostic object, or None
|
|
137
|
+
retry.history # List of AttemptRecord objects
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
### `parse_json`
|
|
141
|
+
|
|
142
|
+
```python
|
|
143
|
+
from handoff import parse_json
|
|
144
|
+
|
|
145
|
+
data = parse_json('```json\n{"key": "value"}\n```')
|
|
146
|
+
# Returns: {"key": "value"}
|
|
147
|
+
# Raises ParseError on failure (retryable by @guard)
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
### `HandoffViolation`
|
|
151
|
+
|
|
152
|
+
Raised when all retry attempts are exhausted:
|
|
153
|
+
|
|
154
|
+
```python
|
|
155
|
+
from handoff import HandoffViolation
|
|
156
|
+
|
|
157
|
+
try:
|
|
158
|
+
result = my_agent(state)
|
|
159
|
+
except HandoffViolation as e:
|
|
160
|
+
print(e.node_name) # "writer"
|
|
161
|
+
print(e.total_attempts) # 3
|
|
162
|
+
print(e.history) # List of AttemptRecord with diagnostics
|
|
163
|
+
print(e.to_dict()) # Serializable for logging
|
|
164
|
+
```
|
|
165
|
+
|
|
166
|
+
### Handle Failures
|
|
167
|
+
|
|
168
|
+
```python
|
|
169
|
+
@guard(output=Schema, on_fail="raise") # Raise exception (default)
|
|
170
|
+
@guard(output=Schema, on_fail="return_none") # Return None on failure
|
|
171
|
+
@guard(output=Schema, on_fail="return_input") # Return input unchanged
|
|
172
|
+
@guard(output=Schema, on_fail=my_handler) # Custom handler
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
## Examples
|
|
176
|
+
|
|
177
|
+
| Demo | What it shows |
|
|
178
|
+
|------|---------------|
|
|
179
|
+
| [`examples/llm_demo`](examples/llm_demo/) | Retry-with-feedback: writer fails, gets feedback, self-corrects |
|
|
180
|
+
| [`examples/rag_demo`](examples/rag_demo/) | Multi-stage pipeline validation + hallucinated citation detection |
|
|
181
|
+
|
|
182
|
+
Both demos support `--api` for real LLM calls and run with mock data by default.
|
|
183
|
+
|
|
184
|
+
## With LangGraph
|
|
185
|
+
|
|
186
|
+
```python
|
|
187
|
+
from handoff.langgraph import guarded_node
|
|
188
|
+
from pydantic import BaseModel, Field
|
|
189
|
+
|
|
190
|
+
class RouterOutput(BaseModel):
|
|
191
|
+
next_agent: str = Field(pattern="^(writer|reviewer|done)$")
|
|
192
|
+
messages: list
|
|
193
|
+
|
|
194
|
+
@guarded_node(output=RouterOutput)
|
|
195
|
+
def router(state: dict) -> dict:
|
|
196
|
+
return {
|
|
197
|
+
"next_agent": "writer",
|
|
198
|
+
"messages": state["messages"]
|
|
199
|
+
}
|
|
200
|
+
```
|
|
201
|
+
|
|
202
|
+
## Why not just use Pydantic directly?
|
|
203
|
+
|
|
204
|
+
You should! Handoff uses Pydantic under the hood.
|
|
205
|
+
|
|
206
|
+
The difference:
|
|
207
|
+
|
|
208
|
+
| Pydantic alone | Handoff |
|
|
209
|
+
|----------------|---------|
|
|
210
|
+
| `ValidationError: 1 validation error` | `HandoffViolation in 'router_node'` |
|
|
211
|
+
| Generic stack trace | Exact node + field + suggestion |
|
|
212
|
+
| You wire up validation manually | One decorator |
|
|
213
|
+
| No retry | Automatic retry with feedback |
|
|
214
|
+
| Errors are for developers | Errors are actionable for agents |
|
|
215
|
+
|
|
216
|
+
## Roadmap
|
|
217
|
+
|
|
218
|
+
- [ ] Invariant contracts (input/output relationships)
|
|
219
|
+
- [ ] CrewAI adapter
|
|
220
|
+
- [x] Retry with feedback loop
|
|
221
|
+
- [ ] VS Code extension for violation inspection
|
|
222
|
+
|
|
223
|
+
## Contributing
|
|
224
|
+
|
|
225
|
+
Contributions welcome! Please open an issue first to discuss what you'd like to change.
|
|
226
|
+
|
|
227
|
+
## License
|
|
228
|
+
|
|
229
|
+
MIT
|
|
230
|
+
|
|
231
|
+
---
|
|
232
|
+
|
|
233
|
+
Built for developers who are tired of debugging agent handoffs.
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
handoff/__init__.py,sha256=RDO2wqA1zROrtM2GE_B_1pLA4OOClCzzjMARaorB-Oo,449
|
|
2
|
+
handoff/core.py,sha256=zVGKyLO7TxwoJS_J5n-c47OdQGUmrQJjIvRW97tInBs,3173
|
|
3
|
+
handoff/guard.py,sha256=Mf3Nj7b8MC_WHcvMhzaVWFTSTGjTPHu64Jz2AeTnAvI,18752
|
|
4
|
+
handoff/langgraph.py,sha256=MLm4G2uZWiE30lFKU99DEA_LjlSHzTvb5Ci7urFlyF8,1931
|
|
5
|
+
handoff/retry.py,sha256=QVuPAygS3hEEXKS9h073uW4N0erURXiyIYFllsaR1YY,3891
|
|
6
|
+
handoff/testing.py,sha256=WlsEKyOX8_M8cfUc3xbcoMFU4Pyn-j1ffD8xPigAQ_Q,915
|
|
7
|
+
handoff/utils.py,sha256=c0nN5AAXgx3mbVKa8jHFMqkor0qtchpC0Jc_e_yjzTE,1302
|
|
8
|
+
handoff_guard-0.1.0.dist-info/METADATA,sha256=g8B526tyQGMVMRcA3isJGJwTu2PkSw9PKMFjdmF45xo,6969
|
|
9
|
+
handoff_guard-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
10
|
+
handoff_guard-0.1.0.dist-info/licenses/LICENSE,sha256=HroDuVS_iz3RlUAJhww1Ma4olb4onXHg1MXaIZCb06s,1063
|
|
11
|
+
handoff_guard-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Arnold
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|