handoff-guard 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
handoff/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from handoff.core import HandoffViolation, ViolationContext
2
+ from handoff.guard import guard, GuardConfig
3
+ from handoff.retry import retry, RetryState, Diagnostic, AttemptRecord
4
+ from handoff.utils import ParseError, parse_json
5
+
6
+ __all__ = [
7
+ "guard",
8
+ "GuardConfig",
9
+ "HandoffViolation",
10
+ "ViolationContext",
11
+ "retry",
12
+ "RetryState",
13
+ "Diagnostic",
14
+ "AttemptRecord",
15
+ "ParseError",
16
+ "parse_json",
17
+ ]
18
+
19
# Keep in step with the distribution metadata: this wheel is published as
# handoff-guard 0.1.0 (see METADATA / RECORD), so the module must report the same.
__version__ = "0.1.0"
handoff/core.py ADDED
@@ -0,0 +1,95 @@
1
+ """Core types for handoff validation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, TYPE_CHECKING
7
+ from datetime import datetime, timezone
8
+
9
+ if TYPE_CHECKING:
10
+ from handoff.retry import AttemptRecord
11
+
12
+
13
@dataclass
class ViolationContext:
    """Describe one contract violation: the node, the field, and what was seen.

    Instances are plain data holders; ``str()`` renders them as a multi-line
    report. ``timestamp`` defaults to the current time in UTC.
    """

    node_name: str  # node whose contract was violated
    contract_type: str  # "input" | "output" | "invariant"
    field_path: str  # dotted path, e.g. "response.refund_id"
    expected: str  # human-readable description of the requirement
    received: Any  # the offending value
    received_type: str  # type name of the offending value
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    upstream_node: str | None = None
    suggestion: str | None = None

    def __str__(self) -> str:
        """Return the report text; optional lines appear only when set."""
        report = (
            f"HandoffViolation in '{self.node_name}':\n"
            f" Contract: {self.contract_type}\n"
            f" Field: {self.field_path}\n"
            f" Expected: {self.expected}\n"
            # The repr of the received value is capped at 100 characters.
            f" Received: {repr(self.received)[:100]} ({self.received_type})"
        )
        if self.upstream_node:
            report += f"\n Upstream: {self.upstream_node}"
        if self.suggestion:
            report += f"\n Suggestion: {self.suggestion}"
        return report
40
+
41
+
42
class HandoffViolation(Exception):
    """Exception signalling that data failed validation at an agent boundary.

    Attributes:
        context: The ViolationContext describing the failure; its ``str()``
            form is used as the exception message.
        history: The attempt records supplied by the caller, or an empty list
            when none (or an empty one) was given.
    """

    def __init__(
        self,
        context: ViolationContext,
        history: list[AttemptRecord] | None = None,
    ):
        self.context = context
        self.history: list[AttemptRecord] = history if history else []
        super().__init__(str(context))

    @property
    def node_name(self) -> str:
        """Name of the node recorded in the context."""
        return self.context.node_name

    @property
    def field_path(self) -> str:
        """Field path recorded in the context."""
        return self.context.field_path

    @property
    def total_attempts(self) -> int:
        """Number of recorded attempts; 1 when no history is attached."""
        return len(self.history) or 1

    def to_dict(self) -> dict:
        """Return a dict for logging/telemetry.

        The received value is stored as its ``repr`` capped at 200 characters
        and timestamps as ISO-8601 strings. ``total_attempts`` and ``history``
        are included only when the history is non-empty.
        """
        ctx = self.context
        payload = {
            "node_name": ctx.node_name,
            "contract_type": ctx.contract_type,
            "field_path": ctx.field_path,
            "expected": ctx.expected,
            "received": repr(ctx.received)[:200],
            "received_type": ctx.received_type,
            "timestamp": ctx.timestamp.isoformat(),
            "upstream_node": ctx.upstream_node,
            "suggestion": ctx.suggestion,
        }
        if not self.history:
            return payload

        payload["total_attempts"] = self.total_attempts
        entries = []
        for rec in self.history:
            diag = rec.diagnostic
            entries.append(
                {
                    "attempt": rec.attempt,
                    "timestamp": rec.timestamp.isoformat(),
                    "duration_ms": rec.duration_ms,
                    "diagnostic": (
                        {"cause": diag.cause, "message": diag.message}
                        if diag
                        else None
                    ),
                }
            )
        payload["history"] = entries
        return payload
handoff/guard.py ADDED
@@ -0,0 +1,469 @@
1
+ """The @guard decorator for validating agent boundaries."""
2
+
3
+ import json
4
+ import time
5
+ from dataclasses import dataclass
6
+ from functools import wraps
7
+ from typing import Any, Callable, TypeVar, Type, Literal
8
+ import inspect
9
+
10
+ from pydantic import BaseModel, ValidationError
11
+
12
+ from handoff.core import HandoffViolation, ViolationContext
13
+ from handoff.retry import (
14
+ _retry_context,
15
+ RetryState,
16
+ Diagnostic,
17
+ AttemptRecord,
18
+ )
19
+ from handoff.utils import ParseError
20
+
21
+
22
+ T = TypeVar("T")
23
+ OnFailAction = Literal["raise", "return_none", "return_input"]
24
+
25
+
26
@dataclass
class GuardConfig:
    """Settings bundle describing how a guard should behave.

    Every field has a default, so ``GuardConfig()`` is valid.
    """

    # Pydantic models for each side of the boundary; None means "no schema".
    input_schema: Type[BaseModel] | None = None
    output_schema: Type[BaseModel] | None = None
    # Auto-detected from function name if not provided.
    node_name: str | None = None
    # Failure policy: one of the OnFailAction literals, or a callable that
    # receives the HandoffViolation.
    on_fail: OnFailAction | Callable[[HandoffViolation], Any] = "raise"
    # Per-side switches.
    validate_input: bool = True
    validate_output: bool = True
36
+
37
+
38
+ def _generate_suggestion(
39
+ err_type: str, field_path: str, contract_type: str
40
+ ) -> str | None:
41
+ """Generate a helpful suggestion based on Pydantic error type."""
42
+ if err_type == "missing":
43
+ return f"Add '{field_path}' to the {contract_type} data"
44
+ elif err_type == "string_type":
45
+ return f"Convert '{field_path}' to string"
46
+ elif err_type == "int_type":
47
+ return f"Convert '{field_path}' to integer"
48
+ elif err_type == "string_too_short":
49
+ return f"Increase the length of '{field_path}'"
50
+ elif err_type == "string_too_long":
51
+ return f"Reduce the length of '{field_path}'"
52
+ elif err_type == "too_short":
53
+ return f"Add more items to '{field_path}'"
54
+ elif err_type == "too_long":
55
+ return f"Reduce the number of items in '{field_path}'"
56
+ elif err_type == "greater_than_equal":
57
+ return f"Increase the value of '{field_path}'"
58
+ elif err_type == "less_than_equal":
59
+ return f"Decrease the value of '{field_path}'"
60
+ elif err_type == "string_pattern_mismatch":
61
+ return f"'{field_path}' does not match the required pattern"
62
+ return None
63
+
64
+
65
def _resolve_loc(data: Any, loc: tuple) -> Any:
    """Follow an error location into ``data``; return "<missing>" on the first miss."""
    current = data
    for part in loc:
        if isinstance(current, dict):
            if part not in current:
                return "<missing>"
            current = current[part]
        elif isinstance(current, (list, tuple)) and isinstance(part, int):
            # Item-level errors carry the index of the offending element.
            if not 0 <= part < len(current):
                return "<missing>"
            current = current[part]
        elif hasattr(current, str(part)):
            current = getattr(current, str(part))
        else:
            return "<missing>"
    return current


def _extract_violations(
    error: ValidationError,
    node_name: str,
    contract_type: str,
    raw_data: Any,
) -> list[ViolationContext]:
    """Convert a Pydantic ValidationError into ViolationContext objects.

    One ViolationContext is produced per entry of ``error.errors()``. The
    received value is looked up in ``raw_data`` by following the entry's
    ``loc``. The walk stops at the first step that cannot be resolved, so the
    "<missing>" placeholder is never itself probed for attributes, and integer
    steps index into lists and tuples.

    Args:
        error: The validation error to unpack.
        node_name: Node name stored on every violation.
        contract_type: Contract label stored on every violation.
        raw_data: The data that was validated, used to recover received values.

    Returns:
        The violations, in the order reported by ``error.errors()``.
    """
    violations = []
    for err in error.errors():
        loc = err["loc"]
        field_path = ".".join(str(part) for part in loc)
        received = _resolve_loc(raw_data, loc)

        violations.append(
            ViolationContext(
                node_name=node_name,
                contract_type=contract_type,
                # An empty loc means the error applies to the whole payload.
                field_path=field_path or "<root>",
                expected=err["msg"],
                received=received,
                received_type=type(received).__name__,
                suggestion=_generate_suggestion(
                    err["type"], field_path, contract_type
                ),
            )
        )

    return violations
103
+
104
+
105
def _validate_data(
    data: Any,
    schema: Type[BaseModel],
    node_name: str,
    contract_type: str,
) -> tuple[bool, BaseModel | None, list[ViolationContext]]:
    """Validate ``data`` against ``schema``.

    The payload handed to ``schema.model_validate`` depends on the input:
    a BaseModel is dumped to a dict, a dict is used as is, any other object
    contributes its ``__dict__`` when it has one, and anything else is
    wrapped as ``{"value": data}``.

    Returns:
        ``(True, model, [])`` on success, otherwise ``(False, None, violations)``.
    """
    try:
        if isinstance(data, BaseModel):
            payload = data.model_dump()
        elif isinstance(data, dict):
            payload = data
        elif hasattr(data, "__dict__"):
            payload = data.__dict__
        else:
            payload = {"value": data}
        model = schema.model_validate(payload)
    except ValidationError as exc:
        # Violations are looked up against the original, unconverted data.
        return False, None, _extract_violations(exc, node_name, contract_type, data)
    return True, model, []
129
+
130
+
131
def _build_validation_diagnostic(
    violations: list[ViolationContext],
    result: Any,
) -> Diagnostic:
    """Summarise output-validation violations as a "validation" Diagnostic.

    Every violation contributes one line to ``errors``. ``field_path`` and
    ``suggestion`` are taken from the first violation, or None when the list
    is empty. ``raw_output`` is the repr of ``result`` capped at 500
    characters, or None when ``result`` is None.
    """
    error_lines = []
    for item in violations:
        error_lines.append(
            f"{item.field_path}: expected {item.expected}, got {repr(item.received)[:100]}"
        )

    if violations:
        lead_path = violations[0].field_path
        lead_hint = violations[0].suggestion
    else:
        lead_path = None
        lead_hint = None

    return Diagnostic(
        cause="validation",
        message="Output validation failed",
        errors=error_lines,
        raw_output=None if result is None else repr(result)[:500],
        field_path=lead_path,
        suggestion=lead_hint,
    )
148
+
149
+
150
def _build_parse_diagnostic(exc: Exception, raw: Any = None) -> Diagnostic:
    """Wrap a parse failure in a "parse" Diagnostic.

    ``raw_output`` comes from the exception's own ``raw_output`` when ``exc``
    is a ParseError. Otherwise it is the repr of ``raw`` capped at 500
    characters, or None when ``raw`` is None.
    """
    if isinstance(exc, ParseError):
        captured = exc.raw_output
    elif raw is None:
        captured = None
    else:
        captured = repr(raw)[:500]
    return Diagnostic(cause="parse", message=str(exc), raw_output=captured)
162
+
163
+
164
def guard(
    input: Type[BaseModel] | None = None,
    output: Type[BaseModel] | None = None,
    *,
    node_name: str | None = None,
    max_attempts: int = 1,
    retry_on: tuple[str, ...] = ("validation", "parse"),
    on_fail: OnFailAction | Callable[[HandoffViolation], Any] = "raise",
    input_param: str | None = "state",
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to validate input/output at agent boundaries.

    Input is validated once, before the first attempt. The wrapped function is
    then called up to ``max_attempts`` times. An attempt fails when the output
    does not validate, or when the function raises ParseError /
    json.JSONDecodeError while "parse" is in ``retry_on`` and
    ``max_attempts > 1``. Any other exception propagates unchanged. If the
    function declares a ``retry`` parameter and the caller did not pass one,
    the current RetryState is injected as that keyword argument.

    Args:
        input: Pydantic model to validate input against
        output: Pydantic model to validate output against
        node_name: Override the node name (defaults to function name)
        max_attempts: Maximum number of attempts (1 = no retry, default)
        retry_on: Tuple of error types to retry on ("validation", "parse")
        on_fail: What to do on validation failure:
            - "raise": Raise HandoffViolation (default)
            - "return_none": Return None
            - "return_input": Return input unchanged
            - callable: Call with HandoffViolation, return its result
        input_param: Name of the input argument to validate (default: "state")

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Previously such a value
            skipped the wrapped function entirely and returned None.

    Example:
        @guard(input=RequestSchema, output=ResponseSchema)
        def my_agent_node(state: dict) -> dict:
            ...
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts!r}")

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        _node_name = node_name or func.__name__
        sig = inspect.signature(func)
        _accepts_retry = "retry" in sig.parameters
        # Parse errors are only intercepted when a retry is actually possible.
        _retry_parse = "parse" in retry_on and max_attempts > 1
        _retry_validation = "validation" in retry_on

        def _bind_input(args: tuple, kwargs: dict) -> Any:
            """Extract the input argument using signature binding."""
            if input_param is None:
                if args:
                    return args[0]
                return kwargs.get("state")
            try:
                bound = sig.bind_partial(*args, **kwargs)
            except TypeError:
                return kwargs.get(input_param)
            if input_param in bound.arguments:
                return bound.arguments[input_param]
            return kwargs.get(input_param)

        def _handle_violation(violation: HandoffViolation, input_data: Any) -> Any:
            """Apply the ``on_fail`` policy; unknown policies raise."""
            if on_fail == "raise":
                raise violation
            if on_fail == "return_none":
                return None
            if on_fail == "return_input":
                return input_data
            if callable(on_fail):
                return on_fail(violation)
            raise violation

        def _validate_input(
            args: tuple, kwargs: dict
        ) -> tuple[Any, list[ViolationContext]]:
            """Return the bound input and any input-contract violations."""
            input_data = _bind_input(args, kwargs)
            if not input:
                return input_data, []
            _ok, _model, violations = _validate_data(
                input_data, input, _node_name, "input"
            )
            return input_data, violations

        def _validate_output_data(result: Any) -> list[ViolationContext]:
            """Return output-contract violations (empty without a schema)."""
            if not output:
                return []
            _ok, _model, violations = _validate_data(
                result, output, _node_name, "output"
            )
            return violations

        def _is_retryable_parse_error(exc: Exception) -> bool:
            """Check if exception is a parse-related error eligible for retry."""
            return _retry_parse and isinstance(
                exc, (ParseError, json.JSONDecodeError)
            )

        def _new_state(
            attempt_num: int,
            last_diagnostic: Diagnostic | None,
            history: list[AttemptRecord],
        ) -> RetryState:
            """Build the RetryState for one attempt from a history snapshot."""
            return RetryState(
                attempt=attempt_num,
                max_attempts=max_attempts,
                last_error=last_diagnostic,
                history=list(history),
            )

        def _call_kwargs(kwargs: dict, state: RetryState) -> dict:
            """Copy kwargs, injecting ``retry`` when the function accepts it."""
            call_kwargs = dict(kwargs)
            if _accepts_retry and "retry" not in call_kwargs:
                call_kwargs["retry"] = state
            return call_kwargs

        def _record_failure(
            history: list[AttemptRecord],
            attempt_num: int,
            diagnostic: Diagnostic,
            elapsed: float,
        ) -> None:
            """Append a failed attempt to the history."""
            history.append(
                AttemptRecord(
                    attempt=attempt_num,
                    diagnostic=diagnostic,
                    duration_ms=elapsed,
                )
            )

        def _parse_violation(
            exc: Exception, history: list[AttemptRecord]
        ) -> HandoffViolation:
            """Build the violation reported when the last attempt fails to parse."""
            violation_ctx = ViolationContext(
                node_name=_node_name,
                contract_type="output",
                field_path="<parse>",
                expected="Valid parseable output",
                received=str(exc),
                received_type=type(exc).__name__,
                suggestion="Return valid JSON or structured data",
            )
            return HandoffViolation(violation_ctx, history=history)

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs) -> T:
                # Validate input (outside retry loop — input doesn't change)
                input_data, input_violations = _validate_input(args, kwargs)
                if input_violations:
                    violation = HandoffViolation(input_violations[0])
                    return _handle_violation(violation, input_data)

                history: list[AttemptRecord] = []
                last_diagnostic: Diagnostic | None = None

                for attempt_num in range(1, max_attempts + 1):
                    state = _new_state(attempt_num, last_diagnostic, history)
                    token = _retry_context.set(state)
                    start_time = time.monotonic()
                    try:
                        call_kwargs = _call_kwargs(kwargs, state)
                        try:
                            result = await func(*args, **call_kwargs)
                        except Exception as exc:
                            if not _is_retryable_parse_error(exc):
                                raise
                            elapsed = (time.monotonic() - start_time) * 1000
                            last_diagnostic = _build_parse_diagnostic(
                                exc, getattr(exc, "raw_output", None)
                            )
                            _record_failure(
                                history, attempt_num, last_diagnostic, elapsed
                            )
                            if attempt_num < max_attempts:
                                continue
                            # Final attempt — report through on_fail
                            return _handle_violation(
                                _parse_violation(exc, history), input_data
                            )

                        output_violations = _validate_output_data(result)
                        elapsed = (time.monotonic() - start_time) * 1000
                        if output_violations:
                            last_diagnostic = _build_validation_diagnostic(
                                output_violations, result
                            )
                            _record_failure(
                                history, attempt_num, last_diagnostic, elapsed
                            )
                            if _retry_validation and attempt_num < max_attempts:
                                continue
                            # Final attempt or validation not retried
                            violation = HandoffViolation(
                                output_violations[0], history=history
                            )
                            return _handle_violation(violation, input_data)

                        # Success
                        history.append(
                            AttemptRecord(attempt=attempt_num, duration_ms=elapsed)
                        )
                        return result
                    finally:
                        # Restore the previous retry context after every
                        # attempt, including retries and raised exceptions.
                        _retry_context.reset(token)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs) -> T:
            # Validate input (outside retry loop — input doesn't change)
            input_data, input_violations = _validate_input(args, kwargs)
            if input_violations:
                violation = HandoffViolation(input_violations[0])
                return _handle_violation(violation, input_data)

            history: list[AttemptRecord] = []
            last_diagnostic: Diagnostic | None = None

            for attempt_num in range(1, max_attempts + 1):
                state = _new_state(attempt_num, last_diagnostic, history)
                token = _retry_context.set(state)
                start_time = time.monotonic()
                try:
                    call_kwargs = _call_kwargs(kwargs, state)
                    try:
                        result = func(*args, **call_kwargs)
                    except Exception as exc:
                        if not _is_retryable_parse_error(exc):
                            raise
                        elapsed = (time.monotonic() - start_time) * 1000
                        last_diagnostic = _build_parse_diagnostic(
                            exc, getattr(exc, "raw_output", None)
                        )
                        _record_failure(
                            history, attempt_num, last_diagnostic, elapsed
                        )
                        if attempt_num < max_attempts:
                            continue
                        # Final attempt — report through on_fail
                        return _handle_violation(
                            _parse_violation(exc, history), input_data
                        )

                    output_violations = _validate_output_data(result)
                    elapsed = (time.monotonic() - start_time) * 1000
                    if output_violations:
                        last_diagnostic = _build_validation_diagnostic(
                            output_violations, result
                        )
                        _record_failure(
                            history, attempt_num, last_diagnostic, elapsed
                        )
                        if _retry_validation and attempt_num < max_attempts:
                            continue
                        # Final attempt or validation not retried
                        violation = HandoffViolation(
                            output_violations[0], history=history
                        )
                        return _handle_violation(violation, input_data)

                    # Success
                    history.append(
                        AttemptRecord(attempt=attempt_num, duration_ms=elapsed)
                    )
                    return result
                finally:
                    # Restore the previous retry context after every attempt,
                    # including retries and raised exceptions.
                    _retry_context.reset(token)

        return sync_wrapper

    return decorator
handoff/langgraph.py ADDED
@@ -0,0 +1,76 @@
1
+ """LangGraph-specific utilities for handoff validation."""
2
+
3
+ from typing import Any, Type, Callable, TypeVar
4
+ from pydantic import BaseModel
5
+
6
+ from handoff.guard import guard
7
+ from handoff.core import HandoffViolation
8
+
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
def guarded_node(
    input: Type[BaseModel] | None = None,
    output: Type[BaseModel] | None = None,
    *,
    max_attempts: int = 1,
    retry_on: tuple[str, ...] = ("validation", "parse"),
    on_fail: str | Callable[[HandoffViolation], Any] = "raise",
    node_name: str | None = None,
    input_param: str | None = "state",
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    LangGraph-specific decorator for node validation.

    Thin wrapper around @guard: every argument is forwarded unchanged.

    Args:
        input: Pydantic model to validate the node input against
        output: Pydantic model to validate the node output against
        max_attempts: Maximum number of attempts (1 = no retry, default)
        retry_on: Tuple of error types to retry on ("validation", "parse")
        on_fail: Failure policy, as accepted by @guard
        node_name: Override the node name (defaults to the function name)
        input_param: Name of the input argument to validate (default: "state")

    Example:
        from handoff.langgraph import guarded_node

        class AgentInput(BaseModel):
            messages: list
            context: dict

        class AgentOutput(BaseModel):
            messages: list
            next_agent: str

        @guarded_node(input=AgentInput, output=AgentOutput)
        def my_agent(state: dict) -> dict:
            # Your agent logic
            return {"messages": [...], "next_agent": "reviewer"}
    """
    return guard(
        input=input,
        output=output,
        node_name=node_name,
        max_attempts=max_attempts,
        retry_on=retry_on,
        on_fail=on_fail,
        input_param=input_param,
    )
49
+
50
+
51
def validate_state(
    state: dict | BaseModel,
    schema: Type[BaseModel],
    node_name: str = "unknown",
) -> BaseModel:
    """
    Check ``state`` against ``schema`` on demand and return the validated model.

    Intended for manual validation points, e.g., before checkpointing. The
    violation is reported with contract type "state".

    Raises:
        HandoffViolation: If validation fails (built from the first violation)

    Example:
        validated = validate_state(state, MyStateSchema, node_name="pre_checkpoint")
    """
    from handoff.guard import _validate_data

    ok, model, problems = _validate_data(state, schema, node_name, "state")
    if ok:
        return model
    raise HandoffViolation(problems[0])
handoff/retry.py ADDED
@@ -0,0 +1,145 @@
1
+ """Retry-with-feedback data structures and async-safe proxy."""
2
+
3
+ from contextvars import ContextVar
4
+ from dataclasses import dataclass, field
5
+ from datetime import datetime, timezone
6
+ from typing import Literal
7
+
8
+ RetryCause = Literal["validation", "parse"]
9
+
10
+ _RAW_OUTPUT_MAX = 500
11
+ _FEEDBACK_MAX_DEFAULT = 2000
12
+
13
+
14
@dataclass
class Diagnostic:
    """Explanation of a failed attempt.

    ``raw_output`` is truncated to ``_RAW_OUTPUT_MAX`` characters on creation.
    """

    cause: RetryCause
    message: str
    errors: list[str] = field(default_factory=list)
    raw_output: str | None = None
    field_path: str | None = None
    suggestion: str | None = None

    def __post_init__(self):
        """Cap an over-long ``raw_output``; empty or None values are left alone."""
        raw = self.raw_output
        if not raw:
            return
        if len(raw) > _RAW_OUTPUT_MAX:
            self.raw_output = raw[:_RAW_OUTPUT_MAX]
28
+
29
+
30
@dataclass
class AttemptRecord:
    """Bookkeeping entry for one attempt."""

    attempt: int  # attempt number
    # Creation time of the record, timezone-aware in UTC.
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    diagnostic: Diagnostic | None = None  # failure details; defaults to None
    duration_ms: float | None = None  # attempt duration in milliseconds; defaults to None
38
+
39
+
40
@dataclass
class RetryState:
    """Snapshot of retry progress handed to a guarded function."""

    attempt: int = 1
    max_attempts: int = 1
    last_error: Diagnostic | None = None
    history: list[AttemptRecord] = field(default_factory=list)

    @property
    def remaining(self) -> int:
        """Attempts left after the current one; never negative."""
        left = self.max_attempts - self.attempt
        return max(0, left)

    @property
    def is_retry(self) -> bool:
        """True for every attempt after the first."""
        return self.attempt > 1

    @property
    def is_final_attempt(self) -> bool:
        """True once the attempt counter has reached ``max_attempts``."""
        return self.attempt >= self.max_attempts

    def feedback(self, max_chars: int = _FEEDBACK_MAX_DEFAULT) -> str | None:
        """Return the last error as text cut to ``max_chars``, or None if there is none."""
        diag = self.last_error
        if diag is None:
            return None
        rendered = _format_diagnostic(diag)
        return rendered if len(rendered) <= max_chars else rendered[:max_chars]
68
+
69
+
70
+ def _format_diagnostic(diag: Diagnostic) -> str:
71
+ """Render a Diagnostic as LLM-friendly text."""
72
+ lines = [
73
+ f"[Retry] Previous attempt failed ({diag.cause}):",
74
+ f" Message: {diag.message}",
75
+ ]
76
+ if diag.errors:
77
+ lines.append(" Errors:")
78
+ for err in diag.errors:
79
+ lines.append(f" - {err}")
80
+ if diag.field_path:
81
+ lines.append(f" Field: {diag.field_path}")
82
+ if diag.suggestion:
83
+ lines.append(f" Suggestion: {diag.suggestion}")
84
+ if diag.raw_output:
85
+ lines.append(f" Raw output: {diag.raw_output}")
86
+ return "\n".join(lines)
87
+
88
+
89
+ _retry_context: ContextVar[RetryState | None] = ContextVar(
90
+ "_retry_context", default=None
91
+ )
92
+
93
+
94
class _RetryProxy:
    """Read-only view over the RetryState stored in ``_retry_context``.

    Each accessor reads the context variable on every use. When no state is
    set, it returns the value of a single, non-retried attempt.
    """

    def _get(self) -> RetryState | None:
        """Return the state held by the context variable, or None."""
        return _retry_context.get(None)

    def get(self) -> RetryState | None:
        """Public alias of ``_get``."""
        return self._get()

    @property
    def attempt(self) -> int:
        """Current attempt number; 1 without state."""
        current = self._get()
        if current:
            return current.attempt
        return 1

    @property
    def max_attempts(self) -> int:
        """Configured attempt limit; 1 without state."""
        current = self._get()
        if current:
            return current.max_attempts
        return 1

    @property
    def remaining(self) -> int:
        """Attempts left; 0 without state."""
        current = self._get()
        if current:
            return current.remaining
        return 0

    @property
    def is_retry(self) -> bool:
        """Whether this is a retry; False without state."""
        current = self._get()
        if current:
            return current.is_retry
        return False

    @property
    def is_final_attempt(self) -> bool:
        """Whether this is the last attempt; True without state."""
        current = self._get()
        if current:
            return current.is_final_attempt
        return True

    @property
    def last_error(self) -> Diagnostic | None:
        """Diagnostic of the previous failure; None without state."""
        current = self._get()
        if current:
            return current.last_error
        return None

    @property
    def history(self) -> list[AttemptRecord]:
        """Attempt records held by the state; a new empty list without state."""
        current = self._get()
        if current:
            return current.history
        return []

    def feedback(self, max_chars: int = _FEEDBACK_MAX_DEFAULT) -> str | None:
        """Delegate to ``RetryState.feedback``; None without state."""
        current = self._get()
        return None if current is None else current.feedback(max_chars)
143
+
144
+
145
+ retry = _RetryProxy()
handoff/testing.py ADDED
@@ -0,0 +1,39 @@
1
+ """Test utilities for handoff retry."""
2
+
3
+ from contextlib import contextmanager
4
+
5
+ from handoff.retry import (
6
+ _retry_context,
7
+ RetryState,
8
+ Diagnostic,
9
+ )
10
+
11
+
12
@contextmanager
def mock_retry(
    attempt: int = 2,
    max_attempts: int = 3,
    last_error: Diagnostic | None = None,
    feedback_text: str | None = None,
):
    """Install a RetryState in the retry context for the duration of a ``with`` block.

    When ``feedback_text`` is given and ``last_error`` is not, a "validation"
    Diagnostic carrying that text as its message is used instead. The state is
    yielded to the caller, and the previous context value is restored on exit.
    """
    if last_error is None and feedback_text:
        last_error = Diagnostic(cause="validation", message=feedback_text)

    fake_state = RetryState(
        attempt=attempt,
        max_attempts=max_attempts,
        last_error=last_error,
    )
    token = _retry_context.set(fake_state)
    try:
        yield fake_state
    finally:
        _retry_context.reset(token)
handoff/utils.py ADDED
@@ -0,0 +1,49 @@
1
+ """Utility functions for handoff."""
2
+
3
+ import json
4
+
5
+
6
class ParseError(Exception):
    """Error raised when output cannot be parsed as JSON.

    Attributes:
        raw_output: The text that failed to parse, as supplied by the raiser,
            or None.
    """

    def __init__(self, message: str, raw_output: str | None = None):
        super().__init__(message)
        self.raw_output = raw_output
12
+
13
+
14
def parse_json(text: str) -> dict:
    """Parse JSON from text, handling common LLM output quirks.

    Strips a leading UTF-8 BOM and a surrounding markdown code fence before
    parsing. Both fence layouts are handled:

    - multi-line: the opening fence line (including any language tag) and a
      trailing fence are dropped;
    - single-line, e.g. "```{...}```" or "```json {...}```": the fences and a
      leading "json" tag are dropped. Previously this layout always failed.

    Returns:
        Whatever ``json.loads`` produces for the cleaned text. That is a dict
        for a JSON object; other JSON documents yield the matching Python value.

    Raises:
        ParseError: If text is not a string or cannot be parsed as JSON.
    """
    if not isinstance(text, str):
        raise ParseError(
            f"Expected string, got {type(text).__name__}",
            raw_output=repr(text)[:500],
        )

    fence = "```"

    # Strip UTF-8 BOM, then surrounding whitespace
    stripped = text.lstrip("\ufeff").strip()

    if stripped.startswith(fence):
        _opener, newline, remainder = stripped.partition("\n")
        if newline:
            # Multi-line fence: discard the whole opening line.
            body = remainder
        else:
            # Single-line fence: the payload shares the line with the fences.
            body = stripped[len(fence):]

        body = body.rstrip()
        if body.endswith(fence):
            body = body[: -len(fence)]
        body = body.strip()

        if not newline and body[:4].lower() == "json":
            # Drop the language tag that directly follows the opening fence.
            body = body[4:].lstrip()
        stripped = body

    try:
        return json.loads(stripped)
    except json.JSONDecodeError as e:
        raise ParseError(
            f"Failed to parse JSON: {e}",
            raw_output=text[:500],
        ) from e
@@ -0,0 +1,233 @@
1
+ Metadata-Version: 2.4
2
+ Name: handoff-guard
3
+ Version: 0.1.0
4
+ Summary: Lightweight validation at agent boundaries. Know what broke and where.
5
+ Project-URL: Homepage, https://github.com/acartag7/handoff-guard
6
+ Project-URL: Repository, https://github.com/acartag7/handoff-guard
7
+ Author-email: Arnold <cartagena.arnold@gmail.com>
8
+ License: MIT
9
+ License-File: LICENSE
10
+ Keywords: agents,handoff,langgraph,multi-agent,validation
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Requires-Python: >=3.10
18
+ Requires-Dist: pydantic>=2.0.0
19
+ Provides-Extra: dev
20
+ Requires-Dist: pre-commit>=3.0.0; extra == 'dev'
21
+ Requires-Dist: pytest-asyncio>=0.23.0; extra == 'dev'
22
+ Requires-Dist: pytest>=8.0.0; extra == 'dev'
23
+ Requires-Dist: ruff>=0.1.0; extra == 'dev'
24
+ Provides-Extra: langgraph
25
+ Requires-Dist: langgraph>=0.2.0; extra == 'langgraph'
26
+ Provides-Extra: llm
27
+ Requires-Dist: httpx>=0.25.0; extra == 'llm'
28
+ Requires-Dist: langgraph>=0.2.0; extra == 'llm'
29
+ Description-Content-Type: text/markdown
30
+
31
+ # handoff-guard
32
+
33
+ > Validation for LLM agents that retries with feedback.
34
+
35
+ [![PyPI version](https://badge.fury.io/py/handoff-guard.svg)](https://badge.fury.io/py/handoff-guard)
36
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
37
+
38
+ ## The Problem
39
+
40
+ When an LLM agent returns bad output, you get a generic error and no recovery path:
41
+
42
+ ```
43
+ ValidationError: 1 validation error for State
44
+ field required (type=value_error.missing)
45
+ ```
46
+
47
+ Which node? Which field? What was passed? Can the agent fix it?
48
+
49
+ ## The Solution
50
+
51
+ ```python
52
+ from handoff import guard, retry, parse_json
53
+ from pydantic import BaseModel, Field
54
+
55
+ class WriterOutput(BaseModel):
56
+ draft: str = Field(min_length=100)
57
+ word_count: int = Field(ge=50)
58
+ tone: str
59
+ title: str
60
+
61
+ @guard(output=WriterOutput, node_name="writer", max_attempts=3)
62
+ def writer_agent(state: dict) -> dict:
63
+ prompt = "Write a JSON response with: draft, word_count, tone, title."
64
+
65
+ if retry.is_retry:
66
+ prompt += f"\n\nYour previous attempt failed:\n{retry.feedback()}"
67
+
68
+ response = call_llm(prompt)
69
+ return parse_json(response)
70
+ ```
71
+
72
+ When validation fails, the agent retries with feedback about what went wrong. After all attempts are exhausted:
73
+
74
+ ```
75
+ HandoffViolation in 'writer' (attempt 3/3):
76
+ Contract: output
77
+ Field: draft
78
+ Expected: String should have at least 100 characters
79
+ Suggestion: Increase the length of 'draft'
80
+ History: 3 failed attempts
81
+ ```
82
+
83
+ ## Quick Start
84
+
85
+ ```bash
86
+ pip install handoff-guard
87
+ ```
88
+
89
+ ```bash
90
+ # See retry-with-feedback in action (no API key needed)
91
+ python -m examples.llm_demo.run_demo
92
+
93
+ # Run with real LLM calls
94
+ export OPENROUTER_API_KEY=your_key
95
+ python -m examples.llm_demo.run_demo --pipeline --api
96
+ ```
97
+
98
+ ## Features
99
+
100
+ - **Retry with feedback** — Failed outputs are fed back to the agent as context
101
+ - **Know which node failed** — No more guessing from stack traces
102
+ - **Know which field failed** — Exact path to the problem
103
+ - **Get fix suggestions** — Actionable error messages
104
+ - **`parse_json`** — Strips code fences, handles BOM, raises `ParseError` on failure
105
+ - **Framework agnostic** — Works with LangGraph, CrewAI, or plain Python
106
+ - **Lightweight** — Just Pydantic, no Docker, no telemetry servers
107
+
108
+ ## API
109
+
110
+ ### `@guard` decorator
111
+
112
+ ```python
113
+ @guard(
114
+ input=InputSchema, # Pydantic model for input validation
115
+ output=OutputSchema, # Pydantic model for output validation
116
+ node_name="my_node", # Identifies the node in errors (default: function name)
117
+ max_attempts=3, # Up to 3 attempts in total, i.e. at most 2 retries (default: 1, no retry)
118
+ retry_on=("validation", "parse"), # What errors trigger retry (default)
119
+ on_fail="raise", # "raise" | "return_none" | "return_input" | callable
120
+ )
121
+ ```
122
+
123
+ ### `retry` proxy
124
+
125
+ Access retry state inside any guarded function:
126
+
127
+ ```python
128
+ from handoff import retry
129
+
130
+ retry.is_retry # True if attempt > 1
131
+ retry.attempt # Current attempt number
132
+ retry.max_attempts # Total allowed attempts
133
+ retry.remaining # Attempts left
134
+ retry.is_final_attempt
135
+ retry.feedback() # Formatted string describing last error, or None
136
+ retry.last_error # Diagnostic object, or None
137
+ retry.history # List of AttemptRecord objects
138
+ ```
139
+
140
+ ### `parse_json`
141
+
142
+ ```python
143
+ from handoff import parse_json
144
+
145
+ data = parse_json('```json\n{"key": "value"}\n```')
146
+ # Returns: {"key": "value"}
147
+ # Raises ParseError on failure (retryable by @guard)
148
+ ```
149
+
150
+ ### `HandoffViolation`
151
+
152
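+ Raised (with the default `on_fail="raise"`) when input validation fails, or when the output still fails validation or parsing after all attempts are exhausted: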
153
+
154
+ ```python
155
+ from handoff import HandoffViolation
156
+
157
+ try:
158
+ result = my_agent(state)
159
+ except HandoffViolation as e:
160
+ print(e.node_name) # "writer"
161
+ print(e.total_attempts) # 3
162
+ print(e.history) # List of AttemptRecord with diagnostics
163
+ print(e.to_dict()) # Serializable for logging
164
+ ```
165
+
166
+ ### Handle Failures
167
+
168
+ ```python
169
+ @guard(output=Schema, on_fail="raise") # Raise exception (default)
170
+ @guard(output=Schema, on_fail="return_none") # Return None on failure
171
+ @guard(output=Schema, on_fail="return_input") # Return input unchanged
172
+ @guard(output=Schema, on_fail=my_handler) # Custom handler
173
+ ```
174
+
175
+ ## Examples
176
+
177
+ | Demo | What it shows |
178
+ |------|---------------|
179
+ | [`examples/llm_demo`](examples/llm_demo/) | Retry-with-feedback: writer fails, gets feedback, self-corrects |
180
+ | [`examples/rag_demo`](examples/rag_demo/) | Multi-stage pipeline validation + hallucinated citation detection |
181
+
182
+ Both demos support `--api` for real LLM calls and run with mock data by default.
183
+
184
+ ## With LangGraph
185
+
186
+ ```python
187
+ from handoff.langgraph import guarded_node
188
+ from pydantic import BaseModel, Field
189
+
190
+ class RouterOutput(BaseModel):
191
+ next_agent: str = Field(pattern="^(writer|reviewer|done)$")
192
+ messages: list
193
+
194
+ @guarded_node(output=RouterOutput)
195
+ def router(state: dict) -> dict:
196
+ return {
197
+ "next_agent": "writer",
198
+ "messages": state["messages"]
199
+ }
200
+ ```
201
+
202
+ ## Why not just use Pydantic directly?
203
+
204
+ You should! Handoff uses Pydantic under the hood.
205
+
206
+ The difference:
207
+
208
+ | Pydantic alone | Handoff |
209
+ |----------------|---------|
210
+ | `ValidationError: 1 validation error` | `HandoffViolation in 'router_node'` |
211
+ | Generic stack trace | Exact node + field + suggestion |
212
+ | You wire up validation manually | One decorator |
213
+ | No retry | Automatic retry with feedback |
214
+ | Errors are for developers | Errors are actionable for agents |
215
+
216
+ ## Roadmap
217
+
218
+ - [ ] Invariant contracts (input/output relationships)
219
+ - [ ] CrewAI adapter
220
+ - [x] Retry with feedback loop
221
+ - [ ] VS Code extension for violation inspection
222
+
223
+ ## Contributing
224
+
225
+ Contributions welcome! Please open an issue first to discuss what you'd like to change.
226
+
227
+ ## License
228
+
229
+ MIT
230
+
231
+ ---
232
+
233
+ Built for developers who are tired of debugging agent handoffs.
@@ -0,0 +1,11 @@
1
+ handoff/__init__.py,sha256=RDO2wqA1zROrtM2GE_B_1pLA4OOClCzzjMARaorB-Oo,449
2
+ handoff/core.py,sha256=zVGKyLO7TxwoJS_J5n-c47OdQGUmrQJjIvRW97tInBs,3173
3
+ handoff/guard.py,sha256=Mf3Nj7b8MC_WHcvMhzaVWFTSTGjTPHu64Jz2AeTnAvI,18752
4
+ handoff/langgraph.py,sha256=MLm4G2uZWiE30lFKU99DEA_LjlSHzTvb5Ci7urFlyF8,1931
5
+ handoff/retry.py,sha256=QVuPAygS3hEEXKS9h073uW4N0erURXiyIYFllsaR1YY,3891
6
+ handoff/testing.py,sha256=WlsEKyOX8_M8cfUc3xbcoMFU4Pyn-j1ffD8xPigAQ_Q,915
7
+ handoff/utils.py,sha256=c0nN5AAXgx3mbVKa8jHFMqkor0qtchpC0Jc_e_yjzTE,1302
8
+ handoff_guard-0.1.0.dist-info/METADATA,sha256=g8B526tyQGMVMRcA3isJGJwTu2PkSw9PKMFjdmF45xo,6969
9
+ handoff_guard-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
+ handoff_guard-0.1.0.dist-info/licenses/LICENSE,sha256=HroDuVS_iz3RlUAJhww1Ma4olb4onXHg1MXaIZCb06s,1063
11
+ handoff_guard-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Arnold
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.