kite-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kite/__init__.py +46 -0
- kite/ab_testing.py +384 -0
- kite/agent.py +556 -0
- kite/agents/__init__.py +3 -0
- kite/agents/plan_execute.py +191 -0
- kite/agents/react_agent.py +509 -0
- kite/agents/reflective_agent.py +90 -0
- kite/agents/rewoo.py +119 -0
- kite/agents/tot.py +151 -0
- kite/conversation.py +125 -0
- kite/core.py +974 -0
- kite/data_loaders.py +111 -0
- kite/embedding_providers.py +372 -0
- kite/llm_providers.py +1278 -0
- kite/memory/__init__.py +6 -0
- kite/memory/advanced_rag.py +333 -0
- kite/memory/graph_rag.py +719 -0
- kite/memory/session_memory.py +423 -0
- kite/memory/vector_memory.py +579 -0
- kite/monitoring.py +611 -0
- kite/observers.py +107 -0
- kite/optimization/__init__.py +9 -0
- kite/optimization/resource_router.py +80 -0
- kite/persistence.py +42 -0
- kite/pipeline/__init__.py +5 -0
- kite/pipeline/deterministic_pipeline.py +323 -0
- kite/pipeline/reactive_pipeline.py +171 -0
- kite/pipeline_manager.py +15 -0
- kite/routing/__init__.py +6 -0
- kite/routing/aggregator_router.py +325 -0
- kite/routing/llm_router.py +149 -0
- kite/routing/semantic_router.py +228 -0
- kite/safety/__init__.py +6 -0
- kite/safety/circuit_breaker.py +360 -0
- kite/safety/guardrails.py +82 -0
- kite/safety/idempotency_manager.py +304 -0
- kite/safety/kill_switch.py +75 -0
- kite/tool.py +183 -0
- kite/tool_registry.py +87 -0
- kite/tools/__init__.py +21 -0
- kite/tools/code_execution.py +53 -0
- kite/tools/contrib/__init__.py +19 -0
- kite/tools/contrib/calculator.py +26 -0
- kite/tools/contrib/datetime_utils.py +20 -0
- kite/tools/contrib/linkedin.py +428 -0
- kite/tools/contrib/web_search.py +30 -0
- kite/tools/mcp/__init__.py +31 -0
- kite/tools/mcp/database_mcp.py +267 -0
- kite/tools/mcp/gdrive_mcp_server.py +503 -0
- kite/tools/mcp/gmail_mcp_server.py +601 -0
- kite/tools/mcp/postgres_mcp_server.py +490 -0
- kite/tools/mcp/slack_mcp_server.py +538 -0
- kite/tools/mcp/stripe_mcp_server.py +219 -0
- kite/tools/search.py +90 -0
- kite/tools/system_tools.py +54 -0
- kite/tools_manager.py +27 -0
- kite_agent-0.1.0.dist-info/METADATA +621 -0
- kite_agent-0.1.0.dist-info/RECORD +61 -0
- kite_agent-0.1.0.dist-info/WHEEL +5 -0
- kite_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- kite_agent-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Semantic Router Implementation
|
|
3
|
+
Based on Chapter 6: The Semantic Router
|
|
4
|
+
|
|
5
|
+
A router that classifies user intent and routes to specialists.
|
|
6
|
+
Uses embeddings for fast, cached classification.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import hashlib
|
|
11
|
+
from typing import Dict, List, Optional, Callable
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from functools import lru_cache
|
|
14
|
+
import numpy as np
|
|
15
|
+
from dotenv import load_dotenv
|
|
16
|
+
|
|
17
|
+
load_dotenv()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# ============================================================================
|
|
21
|
+
# ROUTING CONFIGURATION
|
|
22
|
+
# ============================================================================
|
|
23
|
+
|
|
24
|
+
@dataclass
class Route:
    """A named routing target: example utterances plus the callable that serves them."""

    # Identifier used to refer to this route.
    name: str
    # Free-form, human-readable summary of the route.
    description: str
    # Sample utterances that characterise the route.
    examples: List[str]
    # Callable invoked with the query when this route is selected.
    handler: Callable
    # Embedding vector for the route; remains None until one is assigned.
    embedding: Optional[np.ndarray] = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ============================================================================
|
|
35
|
+
# SEMANTIC ROUTER
|
|
36
|
+
# ============================================================================
|
|
37
|
+
|
|
38
|
+
class SemanticRouter:
    """
    Routes user queries to specialist handlers based on semantic similarity.

    Every route is represented by one embedding of its example utterances
    joined together. A query is embedded, compared against each route with
    cosine similarity, and dispatched to the best match when the score
    reaches ``confidence_threshold``; otherwise a clarification response is
    returned.

    Features:
    - Embedding-based classification (fast, cheap)
    - LRU caching for repeated queries
    - Confidence thresholds
    - Ambiguity handling

    Example:
        router = SemanticRouter(embedding_provider=my_embedder)
        router.add_route(
            "technical",
            ["my internet is down", "can't connect to wifi", "server not responding"],
            handler=tech_agent.handle,
        )

        response = await router.route("my connection keeps dropping")
    """

    def __init__(self,
                 confidence_threshold: float = 0.75,
                 embedding_provider=None):
        """
        Args:
            confidence_threshold: Minimum similarity required to dispatch a
                query to a route.
            embedding_provider: Object exposing ``embed(text)``; its result
                is converted with ``np.array``. May be None, in which case
                embedding lookups raise ``RuntimeError``.
        """
        self.routes: List[Route] = []
        self.confidence_threshold = confidence_threshold
        self.embedding_provider = embedding_provider

        # Precompute route embeddings if routes exist (none initially)
        self._compute_route_embeddings()

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps router
    # instances referenced for as long as their entries stay cached. It is
    # kept as-is because get_stats() reads ``cache_info()`` from it.
    @lru_cache(maxsize=10000)
    def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding for text with LRU caching.

        Raises:
            RuntimeError: If no embedding provider is configured.
        """
        if self.embedding_provider:
            return np.array(self.embedding_provider.embed(text))

        raise RuntimeError("No embedding provider configured for SemanticRouter")

    def _compute_route_embeddings(self):
        """Compute embeddings for every route that does not have one yet."""
        pending = [route for route in self.routes if route.embedding is None]
        if not pending:
            return

        print("[CHART] Computing route embeddings...")
        for route in pending:
            # Combine all examples into one text for the route
            combined_text = " | ".join(route.examples)
            route.embedding = self._get_embedding(combined_text)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors.

        Returns 0.0 when either vector has zero norm instead of dividing by
        zero.
        """
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def _score_routes(self, query: str) -> List[tuple[Route, float]]:
        """Return ``(route, similarity)`` pairs sorted by descending similarity."""
        # Routes added while no provider was configured still lack an
        # embedding; fill those in before comparing.
        self._compute_route_embeddings()

        # Get query embedding (cached if seen before)
        query_embedding = self._get_embedding(query)

        scores = [
            (route, self._cosine_similarity(query_embedding, route.embedding))
            for route in self.routes
        ]
        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores

    def classify_intent(self, query: str) -> tuple[Route, float]:
        """
        Classify query intent using semantic similarity.

        Args:
            query: User query

        Returns:
            (best_route, confidence_score)

        Raises:
            RuntimeError: If no routes are configured, or no embedding
                provider is available.
        """
        if not self.routes:
            raise RuntimeError("No routes configured in SemanticRouter")

        best_route, best_score = self._score_routes(query)[0]

        print("\n Intent Classification:")
        print(f" Query: {query}")
        print(f" Best Match: {best_route.name} (confidence: {best_score:.2%})")

        return best_route, best_score

    def _invoke_handler(self, handler: Callable, query: str, context: Optional[Dict]):
        """Call ``handler`` with the argument shape its signature accepts.

        Prefers ``handler(query, context)`` when a non-empty context is given
        and the handler can take it, then ``handler(query)``, then
        ``handler(query, {})``. A context is not passed to a handler that
        only accepts the query. The shape is chosen by binding against the
        signature rather than by catching ``TypeError`` from the call, so a
        ``TypeError`` raised inside the handler propagates and the handler
        is never invoked twice.
        """
        import inspect  # local import; only needed here

        try:
            signature = inspect.signature(handler)
        except (TypeError, ValueError):
            signature = None

        if signature is None:
            # Signature not introspectable: fall back to trial and error.
            if context:
                return handler(query, context)
            try:
                return handler(query)
            except TypeError:
                return handler(query, {})

        def accepts(*args) -> bool:
            try:
                signature.bind(*args)
            except TypeError:
                return False
            return True

        if context and accepts(query, context):
            return handler(query, context)
        if accepts(query):
            return handler(query)
        # Handler requires a context argument (or accepts neither shape, in
        # which case this call raises the natural TypeError).
        return handler(query, context or {})

    async def route(self, query: str, context: Optional[Dict] = None) -> Dict:
        """
        Route query to appropriate specialist agent.

        Args:
            query: User query
            context: Optional context to pass to handler

        Returns:
            Response dictionary with keys ``route``, ``confidence``,
            ``response`` and ``needs_clarification``; low-confidence results
            additionally carry ``suggested_routes``.
        """
        # Classify intent
        route, confidence = self.classify_intent(query)

        # Handle low confidence (ambiguity)
        if confidence < self.confidence_threshold:
            print(f"[WARN] Low confidence ({confidence:.2%} < {self.confidence_threshold:.2%})")

            return {
                "route": "none",
                "confidence": confidence,
                "response": "I'm not sure how to help with that. Could you be more specific?",
                "needs_clarification": True,
                "suggested_routes": [r.name for r, _ in self._get_top_routes(query, n=3)]
            }

        # Route to specialist
        print(f"[OK] Routing to {route.name}")
        response = self._invoke_handler(route.handler, query, context)

        # Support both sync and async handlers
        import asyncio
        if asyncio.iscoroutine(response):
            response = await response

        # If response is a dict (from Agent.run), extract the 'response' text if available
        if isinstance(response, dict) and 'response' in response:
            response_text = response['response']
        else:
            response_text = str(response)

        return {
            "route": route.name,
            "confidence": confidence,
            "response": response_text,
            "needs_clarification": False
        }

    def _get_top_routes(self, query: str, n: int = 3) -> List[tuple[Route, float]]:
        """Get top N routes by similarity."""
        return self._score_routes(query)[:n]

    def add_route(self, name: str, examples: List[str] | str, description: str = "", handler: Callable = None):
        """Add a new route dynamically.

        Args:
            name: Route identifier.
            examples: One example utterance or a list of them.
            description: Optional human-readable summary.
            handler: Callable invoked for queries matched to this route;
                defaults to a function returning a placeholder string.
        """
        if isinstance(examples, str):
            examples = [examples]

        route = Route(
            name=name,
            description=description,
            examples=examples,
            handler=handler or (lambda x: f"Default response for {name}")
        )

        # Compute embedding if provider exists; otherwise it is filled in
        # lazily on the first classification.
        if self.embedding_provider:
            combined_text = " | ".join(examples)
            route.embedding = self._get_embedding(combined_text)

        self.routes.append(route)
        print(f"[OK] Added route: {name}")

    def get_stats(self) -> Dict:
        """Get router statistics (route count, threshold and embedding-cache counters)."""
        cache_info = self._get_embedding.cache_info()
        lookups = cache_info.hits + cache_info.misses

        return {
            "total_routes": len(self.routes),
            "confidence_threshold": self.confidence_threshold,
            "cache_size": cache_info.currsize,
            "cache_hits": cache_info.hits,
            "cache_misses": cache_info.misses,
            "cache_hit_rate": cache_info.hits / lookups if lookups > 0 else 0
        }
|
kite/safety/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
"""Safety patterns module."""
|
|
2
|
+
from .circuit_breaker import CircuitBreaker, CircuitBreakerConfig, CircuitState
|
|
3
|
+
from .idempotency_manager import IdempotencyManager, IdempotencyConfig
|
|
4
|
+
from .kill_switch import KillSwitch
|
|
5
|
+
|
|
6
|
+
__all__ = ['CircuitBreaker', 'CircuitBreakerConfig', 'CircuitState', 'IdempotencyManager', 'IdempotencyConfig', 'KillSwitch']
|
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Circuit Breaker Pattern for AI Agent Operations
|
|
3
|
+
|
|
4
|
+
Prevents cascading failures when an agent repeatedly attempts failed operations.
|
|
5
|
+
This is critical for write-access AI systems where retries can cause real damage.
|
|
6
|
+
|
|
7
|
+
Author: [Your Name]
|
|
8
|
+
License: MIT
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from datetime import datetime, timedelta
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Dict, Optional, Callable, Any
|
|
15
|
+
import logging
|
|
16
|
+
from functools import wraps
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CircuitState(Enum):
    """Lifecycle states a circuit breaker can be in."""

    # Normal operation
    CLOSED = "closed"
    # Blocking requests
    OPEN = "open"
    # Testing if service recovered
    HALF_OPEN = "half_open"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class CircuitBreakerConfig:
    """Thresholds and timings that tune a circuit breaker."""

    # Failures before opening
    failure_threshold: int = 3
    # Successes to close from half-open
    success_threshold: int = 2
    # Time before attempting recovery
    timeout_seconds: int = 60
    # Max concurrent calls in half-open
    half_open_max_calls: int = 1
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class CircuitStats:
    """Counters and timestamps used to monitor a circuit breaker."""

    # Call counters, all starting at zero.
    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    rejected_calls: int = 0

    # Timestamps of the latest failure / success; None until one is recorded.
    last_failure_time: Optional[datetime] = None
    last_success_time: Optional[datetime] = None

    # Number of state transitions recorded.
    state_changes: int = 0
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class CircuitBreaker:
    """
    Circuit breaker for AI agent operations.

    Use this to wrap any operation that:
    - Makes external API calls
    - Modifies state
    - Costs money
    - Could fail repeatedly

    The breaker starts CLOSED. After ``failure_threshold`` consecutive
    failures it becomes OPEN and rejects calls until ``timeout_seconds``
    have elapsed; the next call then runs in HALF_OPEN, and
    ``success_threshold`` consecutive successes close the circuit again.

    Example:
        circuit_breaker = CircuitBreaker(
            name="stripe_refunds",
            config=CircuitBreakerConfig(failure_threshold=3)
        )

        @circuit_breaker.protected
        def process_refund(order_id: str, amount: float):
            return stripe.Refund.create(...)
    """

    def __init__(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None,
        on_open: Optional[Callable] = None,
        on_close: Optional[Callable] = None
    ):
        """
        Args:
            name: Label used in log messages, errors and statistics.
            config: Thresholds and timings; defaults to ``CircuitBreakerConfig()``.
            on_open: Optional callback invoked as ``on_open(name, stats)``
                when the circuit transitions to OPEN.
            on_close: Optional callback invoked as ``on_close(name, stats)``
                when the circuit transitions to CLOSED.
        """
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.stats = CircuitStats()

        # Callbacks for state changes
        self.on_open = on_open
        self.on_close = on_close

        # Failure tracking
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.open_until: Optional[datetime] = None
        self.half_open_calls = 0

        # Log the effective configuration, not the (possibly None) argument.
        logger.info(f"Circuit breaker '{name}' initialized: {self.config}")

    def _change_state(self, new_state: CircuitState, reason: str):
        """Change circuit state and trigger callbacks (no-op if the state is unchanged)."""
        if new_state == self.state:
            return

        old_state = self.state
        self.state = new_state
        self.stats.state_changes += 1

        logger.warning(
            f"Circuit '{self.name}': {old_state.value} -> {new_state.value} "
            f"(Reason: {reason})"
        )

        if new_state == CircuitState.OPEN and self.on_open:
            self.on_open(self.name, self.stats)
        elif new_state == CircuitState.CLOSED and self.on_close:
            self.on_close(self.name, self.stats)

    def _should_attempt_reset(self) -> bool:
        """Check if enough time has passed to attempt recovery."""
        if self.state != CircuitState.OPEN:
            return False

        return bool(self.open_until and datetime.now() >= self.open_until)

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Function to call
            *args, **kwargs: Arguments to pass to function

        Returns:
            Result from function

        Raises:
            CircuitBreakerError: If circuit is open, or half-open and at capacity
            Original exception: If function fails
        """
        self.stats.total_calls += 1

        # Check if circuit should transition from OPEN to HALF_OPEN
        if self._should_attempt_reset():
            self._change_state(CircuitState.HALF_OPEN, "Timeout expired")
            self.half_open_calls = 0

        # Block calls if circuit is OPEN
        if self.state == CircuitState.OPEN:
            self.stats.rejected_calls += 1
            raise CircuitBreakerError(
                f"Circuit '{self.name}' is OPEN. "
                f"Reset at {self.open_until}"
            )

        # Limit concurrent calls in HALF_OPEN state
        holds_half_open_slot = False
        if self.state == CircuitState.HALF_OPEN:
            if self.half_open_calls >= self.config.half_open_max_calls:
                self.stats.rejected_calls += 1
                raise CircuitBreakerError(
                    f"Circuit '{self.name}' is HALF_OPEN and at capacity"
                )
            self.half_open_calls += 1
            holds_half_open_slot = True

        # Attempt the call
        try:
            result = func(*args, **kwargs)
            self._on_success()
            return result

        except Exception:
            self._on_failure()
            raise

        finally:
            # Release the slot only if this call actually took one; the
            # state may have changed while func was running.
            if holds_half_open_slot and self.half_open_calls > 0:
                self.half_open_calls -= 1

    def _on_success(self):
        """Handle successful function call."""
        self.stats.successful_calls += 1
        self.stats.last_success_time = datetime.now()
        self.consecutive_failures = 0

        if self.state == CircuitState.HALF_OPEN:
            self.consecutive_successes += 1

            if self.consecutive_successes >= self.config.success_threshold:
                self._change_state(
                    CircuitState.CLOSED,
                    f"{self.consecutive_successes} consecutive successes"
                )
                self.consecutive_successes = 0

        logger.debug(f"Circuit '{self.name}': Call succeeded")

    def _on_failure(self):
        """Handle failed function call."""
        self.stats.failed_calls += 1
        self.stats.last_failure_time = datetime.now()
        self.consecutive_successes = 0
        self.consecutive_failures += 1

        logger.warning(
            f"Circuit '{self.name}': Call failed "
            f"({self.consecutive_failures}/{self.config.failure_threshold})"
        )

        # Open circuit if threshold exceeded
        if self.consecutive_failures >= self.config.failure_threshold:
            self.open_until = (
                datetime.now() +
                timedelta(seconds=self.config.timeout_seconds)
            )
            self._change_state(
                CircuitState.OPEN,
                f"{self.consecutive_failures} consecutive failures"
            )

    def protected(self, func: Callable) -> Callable:
        """
        Decorator to protect a function with circuit breaker.

        The wrapped function is invoked directly through ``call``; whatever
        it returns without raising is counted as a success.

        Example:
            @circuit_breaker.protected
            def risky_operation():
                return external_api.call()
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            return self.call(func, *args, **kwargs)
        return wrapper

    def reset(self):
        """Manually reset circuit to closed state and clear failure tracking."""
        self._change_state(CircuitState.CLOSED, "Manual reset")
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.open_until = None
        self.half_open_calls = 0

    def get_stats(self) -> Dict[str, Any]:
        """Get current statistics.

        ``success_rate`` is ``successful_calls / total_calls`` (0 when no
        calls were made); ``total_calls`` includes rejected calls.
        """
        return {
            "name": self.name,
            "state": self.state.value,
            "total_calls": self.stats.total_calls,
            "successful_calls": self.stats.successful_calls,
            "failed_calls": self.stats.failed_calls,
            "rejected_calls": self.stats.rejected_calls,
            "success_rate": (
                self.stats.successful_calls / self.stats.total_calls
                if self.stats.total_calls > 0 else 0
            ),
            "consecutive_failures": self.consecutive_failures,
            "open_until": self.open_until.isoformat() if self.open_until else None,
            "last_failure": (
                self.stats.last_failure_time.isoformat()
                if self.stats.last_failure_time else None
            )
        }
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class CircuitBreakerError(Exception):
    """Signals that the circuit breaker refused to run an operation."""
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class CircuitBreakerRegistry:
    """
    Keeps named circuit breakers so separate operations get separate circuits.

    Breakers are created on first request and returned by name afterwards.
    """

    def __init__(self):
        # Maps breaker name -> breaker instance.
        self._breakers: Dict[str, CircuitBreaker] = {}

    def get_or_create(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None
    ) -> CircuitBreaker:
        """Return the breaker registered under ``name``, creating it if absent.

        ``config`` is only used when a new breaker is created.
        """
        if name in self._breakers:
            return self._breakers[name]

        created = CircuitBreaker(name, config)
        self._breakers[name] = created
        return created

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Collect ``get_stats()`` from every registered breaker, keyed by name."""
        collected = {}
        for breaker_name, breaker in self._breakers.items():
            collected[breaker_name] = breaker.get_stats()
        return collected

    def reset_all(self):
        """Call ``reset()`` on every registered breaker."""
        for registered in self._breakers.values():
            registered.reset()
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
# Module-level registry instance used by circuit_breaker() below.
registry = CircuitBreakerRegistry()


# Convenience function
def circuit_breaker(
    name: str,
    config: Optional[CircuitBreakerConfig] = None
) -> CircuitBreaker:
    """Look up ``name`` in the global registry, creating the breaker if needed."""
    breaker = registry.get_or_create(name, config)
    return breaker
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
if __name__ == "__main__":
    # Demo: drive a breaker with an API stub that fails its first calls.
    import time

    # Send log records to the console.
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    # Breaker under test: opens after 3 failures, retries after 5 seconds.
    demo_config = CircuitBreakerConfig(failure_threshold=3, timeout_seconds=5)
    breaker = CircuitBreaker(name="example_api", config=demo_config)

    # Number of times the stub has been invoked so far.
    call_count = 0

    def flaky_api_call():
        """Stub API: raises on the first 5 invocations, succeeds afterwards."""
        global call_count
        call_count += 1

        if call_count > 5:
            return {"status": "success", "data": "..."}
        raise Exception("API Error")

    print("\n=== Testing Circuit Breaker ===\n")

    for attempt in range(1, 11):
        print(f"\nAttempt {attempt}:")
        try:
            result = breaker.call(flaky_api_call)
        except CircuitBreakerError as e:
            print(f" Circuit Open: {e}")
        except Exception as e:
            print(f" API Error: {e}")
        else:
            print(f"[OK] Success: {result}")

        # Pause between attempts so the breaker timeout can elapse.
        time.sleep(1)

    # Show final stats
    print("\n=== Final Statistics ===")
    print(breaker.get_stats())
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Guardrails / Safety Patterns (Chapter 18)
|
|
3
|
+
Provides mechanisms to ensure agent inputs and outputs adhere to safety and structure guidelines.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Type, Optional, Any, Dict
|
|
7
|
+
from pydantic import BaseModel, ValidationError, Field
|
|
8
|
+
import re
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger("Guardrails")
|
|
12
|
+
|
|
13
|
+
class OutputGuardrail:
    """
    Enforces structured output from an LLM using Pydantic models.
    Chapter 18: Output Filtering / Post-processing.
    """
    def __init__(self, model: Type[BaseModel], fix_on_failure: bool = True):
        """
        Args:
            model: Pydantic model class the output must validate against.
            fix_on_failure: When True, a validation failure returns None
                instead of raising, so the caller can retry or repair.
        """
        self.model = model
        self.fix_on_failure = fix_on_failure

    @staticmethod
    def _extract_json(output: str) -> str:
        """Return the text of the first JSON object found in ``output``.

        Scans each ``{`` in order and returns the first span that decodes as
        JSON, so trailing prose containing braces or a second object no
        longer corrupts the candidate. If nothing decodes, falls back to the
        greedy brace match and then to stripping Markdown code fences.
        """
        import json  # local import; only needed here

        decoder = json.JSONDecoder()
        start = output.find("{")
        while start != -1:
            try:
                _, end = decoder.raw_decode(output, start)
            except ValueError:
                # Not valid JSON from this brace; try the next one.
                start = output.find("{", start + 1)
            else:
                return output[start:end]

        # Nothing decoded cleanly: hand the widest brace span to the model
        # validator so it can report the error.
        json_match = re.search(r"\{.*\}", output, re.DOTALL)
        if json_match:
            return json_match.group(0).strip()

        # No braces at all: strip surrounding code fences.
        clean_output = output.strip()
        if clean_output.startswith("```json"):
            clean_output = clean_output[7:]
        if clean_output.startswith("```"):
            clean_output = clean_output[3:]
        if clean_output.endswith("```"):
            clean_output = clean_output[:-3]
        return clean_output.strip()

    def validate(self, output: str) -> Optional[BaseModel]:
        """
        Validates and parses the LLM output using the first JSON object found in it.

        Returns:
            The parsed model instance, or None when validation fails with
            ``fix_on_failure`` enabled or any other parsing error occurs.

        Raises:
            ValidationError: If validation fails and ``fix_on_failure`` is False.
        """
        try:
            clean_output = self._extract_json(output)

            # Parse
            parsed = self.model.model_validate_json(clean_output)
            logger.info(f"Guardrail passed: {type(parsed).__name__}")
            return parsed

        except ValidationError as e:
            logger.warning(f"Guardrail validation failed: {e}")
            if self.fix_on_failure:
                return None  # Signal need for retry/fix
            raise
        except Exception as e:
            # Deliberate best-effort: any other failure is logged and
            # reported as "no result".
            logger.error(f"Guardrail parsing error: {e}")
            return None
|
|
58
|
+
|
|
59
|
+
class StandardEvaluation(BaseModel):
    """Schema for an agent critique/review result: a score, feedback text and an approval flag."""

    score: int = Field(
        description="Score from 1-10",
    )
    feedback: str = Field(
        description="Specific feedback for improvement",
    )
    approved: bool = Field(
        description="Whether the output is acceptable",
    )
|
|
64
|
+
|
|
65
|
+
class InputGuardrail:
    """
    Filters unsafe or irrelevant user inputs before they reach the agent.
    Chapter 18: Input Validation / Sanitization.
    """
    def __init__(self, forbidden_terms: Optional[list] = None):
        """
        Args:
            forbidden_terms: Substrings that make an input unsafe. When
                omitted or empty, a built-in default list is used.
        """
        self.forbidden_terms = forbidden_terms or ["ignore all instructions", "system prompt"]

    def check(self, user_input: str) -> bool:
        """
        Returns True if safe, False if unsafe.

        Matching is a case-insensitive substring test against each
        forbidden term.
        """
        content = user_input.lower()
        for term in self.forbidden_terms:
            # Lowercase the term too; the input has been lowercased, so a
            # term containing uppercase characters could otherwise never match.
            if term.lower() in content:
                logger.warning(f"Guardrail blocked input containing: '{term}'")
                return False
        return True
|